
Commit 40996f2

jereliu authored and copybara-github committed
Removes unnecessary ViT-GP hyper-parameters.
PiperOrigin-RevId: 388484029
1 parent e11fa50 commit 40996f2

File tree: 3 files changed, +31 -44 lines changed

baselines/jft/experiments/jft300m_vit_base16_sngp.py

Lines changed: 11 additions & 15 deletions
@@ -40,17 +40,17 @@ def get_config():
 
   pp_common = '|value_range(-1, 1)'
   pp_common += f'|onehot({config.num_classes})'
-  # To use ancestor "smearing", use this line instead:
-  # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels") # pylint: disable=line-too-long
+  # To use ancestor 'smearing', use this line instead:
+  # pp_common += f'|onehot({config.num_classes}, key='labels_extended', key_result='labels') # pylint: disable=line-too-long
   pp_common += '|keep("image", "labels")'
   config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
   config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
   config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
 
   config.log_training_steps = 50
   config.log_eval_steps = 1000
-  # NOTE: eval is very fast O(seconds) so it's fine to run it often.
-  config.checkpoint_steps = 1000
+  # NOTE: For pretraining, save infrequently to prevent crowding diskspace.
+  config.checkpoint_steps = 517790
 
   # Model section
   config.model = ml_collections.ConfigDict()
@@ -66,11 +66,11 @@ def get_config():
   config.model.classifier = 'token'  # Or 'gap'
   config.model.representation_size = 768
 
-  # GP layer parameters.
+  # Gaussian process layer parameters.
   config.gp_layer = ml_collections.ConfigDict()
-  config.gp_layer.normalize_input = True
-  config.gp_layer.random_feature_scale = 1.  # 1. or None
-  config.gp_layer.random_feature_stddev = 0.025  # 1. or 0.025
+  # Use momentum for pre-training to prevent numeric error when inverting a
+  # precision matrix accumulated over 300M data.
+  config.gp_layer.covmat_momentum = .999
 
   # Optimizer section
   config.optim_name = 'Adam'
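Editor's note on the hunk above: the new comment carries the key numerical argument. Summing the random-feature precision matrix over roughly 300M examples makes its inversion fragile, while a momentum (exponential moving average) update keeps it on a bounded scale. Below is a minimal NumPy sketch of that argument only, not the actual GP-layer implementation; `update_precision` and its arguments are illustrative names.

# Illustrative sketch: exact accumulation of the random-feature precision
# matrix versus the momentum update selected by `covmat_momentum`.
import numpy as np

def update_precision(precision, features, momentum=None):
  """One covariance-matrix update for a batch of random features.

  Args:
    precision: [d, d] running precision matrix.
    features: [batch, d] random-feature activations for one batch.
    momentum: None for exact (summed) accumulation, else a float in (0, 1).
  """
  batch_outer = features.T @ features
  if momentum is None:
    return precision + batch_outer  # Grows without bound over many batches.
  return momentum * precision + (1. - momentum) * batch_outer

d, rng = 16, np.random.default_rng(0)
precision = np.eye(d)  # Ridge-style identity initialization.
for _ in range(10_000):  # Stand-in for the many pre-training batches.
  precision = update_precision(precision, rng.normal(size=(32, d)), momentum=.999)
covariance = np.linalg.inv(precision)  # Stays well-conditioned with momentum.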
@@ -82,7 +82,8 @@ def get_config():
 
   # TODO(lbeyer): make a mini-language like preprocessings.
   config.lr = ml_collections.ConfigDict()
-  config.lr.base = 8e-4  # LR has to be lower for larger models!
+  # LR has to be lower for GP layer and on larger models.
+  config.lr.base = 4e-4
   config.lr.warmup_steps = 10_000
   config.lr.decay_type = 'linear'
   config.lr.linear_end = 1e-5
@@ -96,9 +97,4 @@ def get_config():
 
 
 def get_sweep(hyper):
-  # lr_grid = [3e-4, 4e-4, 5e-4, 6e-4]
-  # stddev_grid = [0.01, 0.02, 0.03, 0.04, 0.05]
-  return hyper.product([
-      # hyper.sweep('config.lr.base', lr_grid),
-      # hyper.sweep('config.gp_layer.random_feature_stddev', stddev_grid)
-  ])
+  return hyper.product([])

baselines/jft/sngp.py

Lines changed: 16 additions & 25 deletions
@@ -65,7 +65,7 @@ def accumulate_gradient_with_states(
     accum_steps):
   """Improved version of `u.accumulate_gradient()` that allows for states."""
   # This function handles the `loss_and_grad_fn` function which takes a state
-  # arguement and returns ((losses, states), grads).
+  # argument and returns ((losses, states), grads).
   if accum_steps and accum_steps > 1:
     assert images.shape[0] % accum_steps == 0, (
         f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
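Editor's note: the corrected comment states the contract of `loss_and_grad_fn`: it takes a state argument and returns `((losses, states), grads)`. The sketch below shows one way accumulation over `accum_steps` microbatches can thread the states through; it is a simplified stand-in for the pmapped `accumulate_gradient_with_states` in this file, and the argument names are illustrative.

import jax
import jax.numpy as jnp

def accumulate_with_states(loss_and_grad_fn, params, states,
                           images, labels, accum_steps):
  """Averages losses/grads over microbatches while carrying `states` forward."""
  if not accum_steps or accum_steps <= 1:
    return loss_and_grad_fn(params, states, images, labels)
  assert images.shape[0] % accum_steps == 0, 'Batch must split evenly.'
  micro = images.shape[0] // accum_steps
  (loss, states), grads = loss_and_grad_fn(
      params, states, images[:micro], labels[:micro])
  for i in range(1, accum_steps):
    sl = slice(i * micro, (i + 1) * micro)
    (l, states), g = loss_and_grad_fn(params, states, images[sl], labels[sl])
    loss = loss + l
    grads = jax.tree_map(jnp.add, grads, g)
  scale = 1. / accum_steps
  return (loss * scale, states), jax.tree_map(lambda x: x * scale, grads)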
@@ -95,27 +95,16 @@ def acc_grad_and_loss(i, l_s_g):
 
 
 def get_gp_kwargs(gp_config):
-  """Extract keyword arguement parameters for the Gaussian process layer."""
-  normalize_input = gp_config.get('normalize_input', True)
-  kernel_stddev = gp_config.get('random_feature_stddev', 1.)
-  feature_scale = gp_config.get('random_feature_scale', -1.)
+  """Extract keyword argument parameters for the Gaussian process layer."""
   covmat_momentum = gp_config.get('covmat_momentum', 0.999)
 
-  logging.info('gp_config.normalize_input = %s', normalize_input)
-  logging.info('gp_config.random_feature_stddev = %s', kernel_stddev)
-  logging.info('gp_config.random_feature_scale = %s', feature_scale)
+  # Extracts model parameter.
   logging.info('gp_config.covmat_momentum = %s', covmat_momentum)
-
-  feature_scale = None if feature_scale < 0. else feature_scale
-  kernel_init = nn.initializers.normal(stddev=kernel_stddev)
-  hidden_kwargs = dict(feature_scale=feature_scale, kernel_init=kernel_init)
+  covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
   covmat_kwargs = dict(momentum=covmat_momentum)
 
-  # Assemble into kwargs dictionary.
-  gp_layer_kwargs = dict(
-      normalize_input=normalize_input,
-      hidden_kwargs=hidden_kwargs,
-      covmat_kwargs=covmat_kwargs)
+  # Assembles into kwargs dictionary.
+  gp_layer_kwargs = dict(covmat_kwargs=covmat_kwargs)
 
   return gp_layer_kwargs
 
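Editor's note: with the other hyper-parameters removed, `get_gp_kwargs` reduces to resolving `covmat_momentum`, where a negative value maps to `momentum=None` (presumably treated by the GP layer as exact, un-decayed accumulation). The following condensed, logging-free restatement with a usage check is for illustration only.

import ml_collections

def get_gp_kwargs(gp_config):
  """Condensed version of the function above, without the logging call."""
  covmat_momentum = gp_config.get('covmat_momentum', 0.999)
  covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
  return dict(covmat_kwargs=dict(momentum=covmat_momentum))

gp_config = ml_collections.ConfigDict()
gp_config.covmat_momentum = .999   # Pre-training setting from the config above.
assert get_gp_kwargs(gp_config) == {'covmat_kwargs': {'momentum': 0.999}}

gp_config.covmat_momentum = -1.    # Negative sentinel -> momentum=None.
assert get_gp_kwargs(gp_config) == {'covmat_kwargs': {'momentum': None}}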

@@ -327,7 +316,7 @@ def representation_fn(params, images, labels, mask, states):
   @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
   def update_fn(opt, states, lr, images, labels, rng):
     """Update step."""
-
+    # TODO(jereliu): Expand to allow precision matrix resetting.
     measurements = {}
 
     if config.get('mixup') and config.mixup.p:
@@ -413,17 +402,17 @@ def decay_fn(v, wd):
         checkpoint['states'],
         checkpoint['extra'])
   elif config.get('model_init'):
-    write_note(f'Initialize model from {config.model_init}...')
-    raise ValueError(
-        'Load from `config.model_init` checkpoint is currently not supported.')
+    # Load trainable parameters from the checkpoint.
+    # This does not cause issue for SNGP since all non-trainable parameters
+    # (random feature, precision matrix, etc) are last-layer parameters that
+    # should be re-trained during fine-tuning.
+    write_note(f'Initialize trainable parameters from {config.model_init}...')
     # TODO(dusenberrymw): Replace and test load function.
-    # pylint:disable=unreachable
     loaded = resformer.load(params_cpu, config.model_init, config.get('model'))
     opt_cpu = opt_cpu.replace(target=loaded)
     if jax.host_id() == 0:
       logging.info('Restored parameter overview:')
       parameter_overview.log_parameter_overview(loaded)
-    # pylint:enable=unreachable
 
   write_note('Kicking off misc stuff...')
   first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
@@ -472,6 +461,7 @@ def decay_fn(v, wd):
     mw.step_start(step)
 
     with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
+      # TODO(jereliu): Expand to allow precision matrix resetting.
       (opt_repl, states_repl, loss_value, rngs_loop,
        extra_measurements) = update_fn(
            opt_repl,
@@ -495,8 +485,9 @@ def decay_fn(v, wd):
       # alive while they'll be updated in a future step, creating hard to debug
       # memory errors (see b/160593526). Also, takes device 0's params only.
       # We will also do the same for untrainable parameters (`states`). This is
-      # ok since both `random features` and `predictive covariance` are frozen
-      # or task-specific parameters that are not important for pre-training.
+      # ok since `random features` are frozen throughout pre-training, and
+      # `predictive covariance` are irrelevant for downstream finetuning and
+      # will be discarded anyway.
       opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
       states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
 

baselines/jft/sngp_test.py

Lines changed: 4 additions & 4 deletions
@@ -116,10 +116,10 @@ def get_config(classifier, representation_size):
 class SNGPTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.parameters(
-      ('token', 2, 1111.4404296875, 16258.519965277777, 0.16999999806284904),
-      ('token', None, 13992.8515625, 3621.3713107638887, 0.20999999344348907),
-      ('gap', 2, 8779.61328125, 3998.798285590278, 0.12999999895691872),
-      ('gap', None, 11279.3515625, 3212.2536892361113, 0.2199999988079071),
+      ('token', 2, 916.2851, 1954.3369140625, 0.16999999806284904),
+      ('token', None, 290.0307, 915.987548828125, 0.20999999344348907),
+      ('gap', 2, 695.6460, 600.8613823784722, 0.12999999895691872),
+      ('gap', None, 192.9434, 341.7078450520833, 0.2199999988079071),
   )
   def test_sngp_script(self, classifier, representation_size,
                        correct_train_loss, correct_val_loss,
