@@ -65,7 +65,7 @@ def accumulate_gradient_with_states(
     accum_steps):
   """Improved version of `u.accumulate_gradient()` that allows for states."""
   # This function handles the `loss_and_grad_fn` function which takes a state
-  # arguement and returns ((losses, states), grads).
+  # argument and returns ((losses, states), grads).
   if accum_steps and accum_steps > 1:
     assert images.shape[0] % accum_steps == 0, (
         f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
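
For context, the contract described in the comment above — a `loss_and_grad_fn` that takes the model state and returns `((losses, states), grads)` — matches what `jax.value_and_grad` produces with `has_aux=True`. A minimal sketch (not part of the commit; the loss function below is a hypothetical stand-in):

import jax
import jax.numpy as jnp

def dummy_loss_fn(params, states, images, labels):
  # Hypothetical model: a single linear map; `states` is passed through as the
  # auxiliary output so it can be threaded across accumulation steps.
  logits = images @ params['w']
  loss = jnp.mean((logits - labels) ** 2)
  return loss, states

# With has_aux=True this returns ((loss, states), grads), the shape that
# `accumulate_gradient_with_states` expects from `loss_and_grad_fn`.
loss_and_grad_fn = jax.value_and_grad(dummy_loss_fn, has_aux=True)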
@@ -95,27 +95,16 @@ def acc_grad_and_loss(i, l_s_g):
 
 
 def get_gp_kwargs(gp_config):
-  """Extract keyword arguement parameters for the Gaussian process layer."""
-  normalize_input = gp_config.get('normalize_input', True)
-  kernel_stddev = gp_config.get('random_feature_stddev', 1.)
-  feature_scale = gp_config.get('random_feature_scale', -1.)
+  """Extract keyword argument parameters for the Gaussian process layer."""
   covmat_momentum = gp_config.get('covmat_momentum', 0.999)
 
-  logging.info('gp_config.normalize_input = %s', normalize_input)
-  logging.info('gp_config.random_feature_stddev = %s', kernel_stddev)
-  logging.info('gp_config.random_feature_scale = %s', feature_scale)
+  # Extracts model parameter.
   logging.info('gp_config.covmat_momentum = %s', covmat_momentum)
-
-  feature_scale = None if feature_scale < 0. else feature_scale
-  kernel_init = nn.initializers.normal(stddev=kernel_stddev)
-  hidden_kwargs = dict(feature_scale=feature_scale, kernel_init=kernel_init)
+  covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
   covmat_kwargs = dict(momentum=covmat_momentum)
 
-  # Assemble into kwargs dictionary.
-  gp_layer_kwargs = dict(
-      normalize_input=normalize_input,
-      hidden_kwargs=hidden_kwargs,
-      covmat_kwargs=covmat_kwargs)
+  # Assembles into kwargs dictionary.
+  gp_layer_kwargs = dict(covmat_kwargs=covmat_kwargs)
 
   return gp_layer_kwargs
 
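
The simplification above leaves `get_gp_kwargs` forwarding only the covariance momentum, mapping a negative value to `None`. A standalone sketch of the new behavior (not part of the commit; the toy config dicts are made up for illustration):

import logging

def get_gp_kwargs(gp_config):
  """Extract keyword argument parameters for the Gaussian process layer."""
  covmat_momentum = gp_config.get('covmat_momentum', 0.999)

  # Extracts model parameter.
  logging.info('gp_config.covmat_momentum = %s', covmat_momentum)
  covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
  covmat_kwargs = dict(momentum=covmat_momentum)

  # Assembles into kwargs dictionary.
  gp_layer_kwargs = dict(covmat_kwargs=covmat_kwargs)
  return gp_layer_kwargs

print(get_gp_kwargs({'covmat_momentum': -1.}))  # {'covmat_kwargs': {'momentum': None}}
print(get_gp_kwargs({}))                        # {'covmat_kwargs': {'momentum': 0.999}}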
@@ -327,7 +316,7 @@ def representation_fn(params, images, labels, mask, states):
   @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
   def update_fn(opt, states, lr, images, labels, rng):
     """Update step."""
-
+    # TODO(jereliu): Expand to allow precision matrix resetting.
     measurements = {}
 
     if config.get('mixup') and config.mixup.p:
@@ -413,17 +402,17 @@ def decay_fn(v, wd):
         checkpoint['states'],
         checkpoint['extra'])
   elif config.get('model_init'):
-    write_note(f'Initialize model from {config.model_init}...')
-    raise ValueError(
-        'Load from `config.model_init` checkpoint is currently not supported.')
+    # Load trainable parameters from the checkpoint.
+    # This does not cause issue for SNGP since all non-trainable parameters
+    # (random feature, precision matrix, etc) are last-layer parameters that
+    # should be re-trained during fine-tuning.
+    write_note(f'Initialize trainable parameters from {config.model_init}...')
     # TODO(dusenberrymw): Replace and test load function.
-    # pylint:disable=unreachable
     loaded = resformer.load(params_cpu, config.model_init, config.get('model'))
     opt_cpu = opt_cpu.replace(target=loaded)
     if jax.host_id() == 0:
       logging.info('Restored parameter overview:')
       parameter_overview.log_parameter_overview(loaded)
-    # pylint:enable=unreachable
 
   write_note('Kicking off misc stuff...')
   first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
@@ -472,6 +461,7 @@ def decay_fn(v, wd):
     mw.step_start(step)
 
     with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
+      # TODO(jereliu): Expand to allow precision matrix resetting.
       (opt_repl, states_repl, loss_value, rngs_loop,
        extra_measurements) = update_fn(
            opt_repl,
@@ -495,8 +485,9 @@ def decay_fn(v, wd):
       # alive while they'll be updated in a future step, creating hard to debug
       # memory errors (see b/160593526). Also, takes device 0's params only.
       # We will also do the same for untrainable parameters (`states`). This is
-      # ok since both `random features` and `predictive covariance` are frozen
-      # or task-specific parameters that are not important for pre-training.
+      # ok since `random features` are frozen throughout pre-training, and
+      # `predictive covariance` are irrelevant for downstream finetuning and
+      # will be discarded anyway.
       opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
       states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
 
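
In the two `jax.tree_map` context lines above, indexing `x[0]` selects device 0's copy of each pmap-replicated leaf, and `np.array` forces a device-to-host copy, which is what lets the old replicated buffers be released. A minimal standalone sketch of the same pattern (not part of the commit; the toy tree stands in for the real `opt_repl` / `states_repl`):

import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()
# A "replicated" pytree: every leaf has a leading axis of size n_devices.
tree_repl = {'w': jnp.ones((n_devices, 3)), 'b': jnp.zeros((n_devices,))}

# Take device 0's replica and materialize it on the host as numpy arrays.
tree_cpu = jax.tree_map(lambda x: np.array(x[0]), tree_repl)

print({k: v.shape for k, v in tree_cpu.items()})  # {'w': (3,), 'b': ()}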