Skip to content

Commit dbea356

Browse files
committed
test: change config and dummy dataset for E2E CoCa test
1 parent f90f3e0 commit dbea356

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

config_files/config_example_coca.yaml

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ collate_fn:
3434
sample_keys:
3535
- images
3636
- ${settings.referencing_keys.sample_key}
37+
- modality
3738
target_keys: []
3839
text_sample_key: ${settings.referencing_keys.sample_key}
3940
text_target_key: ${settings.referencing_keys.target_key}
@@ -50,6 +51,9 @@ train_dataset:
5051
- sample_key: input_ids
5152
sample_shape: [1024]
5253
sample_type: int
54+
- sample_key: modality
55+
sample_shape: [0]
56+
sample_type: const
5357

5458
val_dataset:
5559
component_key: dataset
@@ -63,6 +67,9 @@ val_dataset:
6367
- sample_key: input_ids
6468
sample_shape: [1024]
6569
sample_type: int
70+
- sample_key: modality
71+
sample_shape: [0]
72+
sample_type: const
6673

6774
train_dataloader:
6875
component_key: data_loader
@@ -174,13 +181,14 @@ model:
174181
variant_key: coca
175182
config:
176183
prediction_key: logits
177-
vision_embd_prediction_key: vision_embeddings
184+
modality_key: modality
185+
modality_embd_prediction_key: modality_embeddings
178186
text_embd_prediction_key: text_embeddings
179-
vision_cls_prediction_key: vision_cls
187+
modality_cls_prediction_key: modality_cls
180188
text_cls_prediction_key: text_cls
181-
vision_encoder_config:
189+
modality_encoder_config:
182190
sample_key: images
183-
prediction_key: vision_embeddings
191+
prediction_key: modality_embeddings
184192
img_size: 224
185193
n_classes: Null # Disable vision transformer head
186194
n_layer: 12
@@ -212,6 +220,7 @@ model:
212220
epsilon: 1e-5
213221
n_pool_head: 8
214222
n_vision_queries: 256
223+
n_audio_queries: Null
215224
bias_attn_pool: False
216225
epsilon_attn_pool: 1e-5
217226
weight_init:

src/modalities/dataloader/dataset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def _check_if_inbounds(self, idx: int):
2929
class DummySampleDataType(str, Enum):
3030
FLOAT = "float"
3131
INT = "int"
32+
CONSTANT = "const"
3233

3334

3435
class DummySampleConfig(BaseModel):
@@ -53,6 +54,8 @@ def __init__(self, num_samples: int, sample_definition: Tuple[DummySampleConfig]
5354
self.num_samples = num_samples
5455
self.sample_definition = sample_definition
5556

57+
self.VISION = 1
58+
5659
def __len__(self) -> int:
5760
return self.num_samples
5861

@@ -66,6 +69,8 @@ def _create_random_sample(self):
6669
data = np.random.randn(*s.sample_shape)
6770
elif s.sample_type == DummySampleDataType.INT:
6871
data = np.random.randint(low=0, high=512, size=s.sample_shape)
72+
elif s.sample_type == DummySampleDataType.CONSTANT:
73+
data = self.VISION
6974
else:
7075
raise NotImplementedError(f"DummyDataset does not support type { s.sample_type}")
7176
sample[s.sample_key] = data

0 commit comments

Comments
 (0)