
Commit cbc94e0: fixes

1 parent ad6c0c0

File tree: 4 files changed, +14 -19 lines

fast_llm/layers/vision_encoder/adapter.py
Lines changed: 4 additions & 4 deletions

@@ -18,18 +18,18 @@ class VisionAdapter(Layer):
 
     def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace):
         super().__init__()
-        input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)
+        input_dim = tensor_space[VisionEncoderDimNames.out_channels]
         self._activation_type = config.adapter_activation_type
         self.layer_1 = Linear(
             input_dim,
-            tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size),
+            tensor_space[VisionEncoderDimNames.adapter_size],
             bias=True,
             weight_init_method=init_normal_(std=config.adapter_init_method_std),
             bias_init_method=init_normal_(std=config.adapter_init_method_std),
         )
         self.layer_2 = Linear(
-            tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size),
-            tensor_space.get_tensor_dim(TransformerDimNames.hidden),
+            tensor_space[VisionEncoderDimNames.adapter_size],
+            tensor_space[TransformerDimNames.hidden],
             bias=True,
             weight_init_method=init_normal_(std=config.adapter_init_method_std),
             bias_init_method=init_normal_(std=config.adapter_init_method_std),
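
Every hunk in this commit makes the same substitution: tensor_space.get_tensor_dim(name) becomes tensor_space[name]. A minimal sketch of the indexing this relies on, assuming TensorSpace stores named TensorDim entries in an internal dict (the _tensor_dims field and this TensorDim shape are illustrative, not Fast-LLM's actual internals):

    import dataclasses


    @dataclasses.dataclass
    class TensorDim:
        name: str
        size: int


    class TensorSpace:
        def __init__(self) -> None:
            self._tensor_dims: dict[str, TensorDim] = {}

        def add_tensor_dim(self, dim: TensorDim) -> None:
            self._tensor_dims[dim.name] = dim

        def __getitem__(self, name: str) -> TensorDim:
            # Subscript lookup replacing the older get_tensor_dim(name) calls.
            return self._tensor_dims[name]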

fast_llm/layers/vision_encoder/patch_conv.py
Lines changed: 7 additions & 7 deletions

@@ -19,23 +19,23 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace):
         self._lr_scale = config.adapter_lr_scale
         self.weight = ParameterMeta.from_dims(
             (
-                self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),
-                self._tensor_space.get_tensor_dim(VisionEncoderDimNames.in_channels),
-                self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size),
-                self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size),
+                self._tensor_space[VisionEncoderDimNames.out_channels],
+                self._tensor_space[VisionEncoderDimNames.in_channels],
+                self._tensor_space[VisionEncoderDimNames.patch_size],
+                self._tensor_space[VisionEncoderDimNames.patch_size],
             ),
             init_method=init_normal_(),
             lr_scale=self._lr_scale,
         )
         if config.conv_bias:
             self.bias = ParameterMeta.from_dims(
-                (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),),
+                (self._tensor_space[VisionEncoderDimNames.out_channels],),
                 init_method=init_normal_(),
-                lr_sclae=self._lr_scale,
+                lr_scale=self._lr_scale,
             )
         else:
             self.bias = None
-        self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels))
+        self.norm = config.patch_norm.get_layer(tensor_space[VisionEncoderDimNames.out_channels])
         self.stride = config.patch_size
 
 
     def forward(
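
Beyond the lookup change, this hunk fixes a genuine bug: the bias parameter was constructed with lr_sclae=, a misspelling of lr_scale. The exact failure mode depends on the signature of ParameterMeta.from_dims, which this diff does not show; if lr_scale is an explicit keyword parameter rather than something collected through **kwargs, the misspelling fails fast, as this standalone sketch with an assumed signature illustrates:

    def from_dims(dims, *, init_method=None, lr_scale=None):
        # Stand-in for ParameterMeta.from_dims; the signature is an assumption.
        return {"dims": dims, "init_method": init_method, "lr_scale": lr_scale}


    from_dims((), lr_scale=0.5)  # accepted
    try:
        from_dims((), lr_sclae=0.5)  # the pre-fix spelling
    except TypeError as exc:
        print(exc)  # from_dims() got an unexpected keyword argument 'lr_sclae'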

fast_llm/models/gpt/model.py
Lines changed: 3 additions & 7 deletions

@@ -173,12 +173,8 @@ def preprocess_meta(
                 VisionEncoderKwargs.image_std: image_std,
                 VisionEncoderKwargs.image_rescale_factor: image_rescale_factor,
                 VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta,
-                VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim(
-                    VisionTransformerDimNames.kv_channels
-                ).size,
-                VisionEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim(
-                    VisionEncoderDimNames.out_channels
-                ).size,
+                VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size,
+                VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size,
             }
         else:
             vision_kwargs = {}
@@ -226,7 +222,7 @@ def preprocess_meta(
             else (batch_dim, hidden_sequence_q_dim, hidden_dim)
         )
         if self._config.vision_encoder.enabled:
-            vision_hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden)
+            vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden]
             vision_hidden_dims = (
                 (hidden_sequence_q_dim, batch_dim, vision_hidden_dim)
                 if sequence_first
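
The first hunk also collapses two three-line get_tensor_dim(...).size chains into single subscript expressions that fit on one dict line each. Reusing the illustrative TensorSpace/TensorDim sketch from above (the dim name and size here are made up):

    space = TensorSpace()
    space.add_tensor_dim(TensorDim("kv_channels", 128))

    # Old form, wrapped across three lines in the original:
    #     space.get_tensor_dim("kv_channels").size
    # New form, short enough for a one-line dict entry:
    kv_channels = space["kv_channels"].size
    assert kv_channels == 128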

fast_llm/models/ssm/config.py
Lines changed: 0 additions & 1 deletion

@@ -133,7 +133,6 @@ def get_handler_class(cls) -> type[CheckpointHandler]:
 
 
 class LlavaHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
-    support_optimizer: typing.ClassVar[bool] = False
     name: typing.ClassVar[str] = "llava_hybrid"
     vision_name: typing.ClassVar[str] = "pixtral"
     text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid"
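
Deleting the class-level override means LlavaHybridHuggingfaceCheckpointFormat now inherits support_optimizer from GPTHuggingfaceCheckpointFormat instead of pinning it to False. The base class's value is not shown in this diff; assuming it defaults to True, the attribute-lookup mechanics are plain Python:

    import typing


    class GPTHuggingfaceCheckpointFormatSketch:
        # Stand-in for the real base class; the True default is an assumption.
        support_optimizer: typing.ClassVar[bool] = True


    class LlavaHybridSketch(GPTHuggingfaceCheckpointFormatSketch):
        # With the False override removed, lookup falls through to the base class.
        pass


    assert LlavaHybridSketch.support_optimizer is True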
