From d909ae0875bdb9833d18bcbd423a49a95c9be1e4 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 21 Mar 2024 20:57:01 +0100 Subject: [PATCH 001/161] feat: add basic webdataset --- pyproject.toml | 4 ++- src/modalities/dataloader/dataset.py | 52 +++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e9121daa7..475ca6771 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,9 @@ dependencies = [ "class_resolver", "wandb", "einops>=0.7.0", - "flash-attn", # install this directly via `pip install flash-attn --no-build-isolation` + "webdataset>=0.2.86", + "timm>=0.9.16", + "flash-attn", # install this directly via `pip install flash-attn --no-build-isolation` ] [project.optional-dependencies] diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 1d9518a34..2bb09d88c 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -2,11 +2,14 @@ from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import jq import numpy as np +import webdataset as wds from pydantic import BaseModel +from timm.data import create_transform +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torch.utils.data.dataset import Dataset as TorchdataSet from tqdm import tqdm from transformers import BatchEncoding @@ -207,3 +210,50 @@ def _generate_packing_index(self) -> List[Tuple[int, int]]: curr_offset = segment_offset curr_len = segment_len return index + + +class ImageTransformConfig(BaseModel): + input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224 + is_training: bool = False + no_aug: bool = False + train_crop_mode: Optional[str] = None + scale: Optional[Tuple[float, float]] = None + ratio: Optional[Tuple[float, float]] = None + hflip: float = 0.5 + vflip: float = 0.0 + color_jitter: Union[float, Tuple[float, ...]] = 0.4 + color_jitter_prob: Optional[float] = None + grayscale_prob: float = 0.0 + gaussian_blur_prob: float = 0.0 + auto_augment: Optional[str] = None + interpolation: str = "bilinear" + mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN + std: Tuple[float, ...] 
= IMAGENET_DEFAULT_STD + re_prob: float = 0.0 + re_mode: str = "const" + re_count: int = 1 + re_num_splits: int = 0 + crop_pct: Optional[float] = None + crop_mode: Optional[str] = None + crop_border_pixels: Optional[int] = None + tf_preprocessing: bool = False + use_prefetcher: bool = False + separate: bool = False + + +class WebDatasetConfig(BaseModel): + urls: Union[List[str], str] + key_mapping: Optional[Dict[str, str]] = None + image_preprocessing: ImageTransformConfig = ImageTransformConfig() + + +class WebDataset(wds.WebDataset): + def __init__( + self, urls: Union[List[str], str], key_mapping: Dict[str, str], image_transform_config: ImageTransformConfig + ): + super().__init__(urls=urls) + if key_mapping is not None: + self.append(wds.filters.map(lambda x: {key_mapping[k]: v for k, v in x.items() if k in key_mapping.keys()})) + if image_transform_config is not None: + transform = create_transform(**image_transform_config.model_dump()) + self.append(wds.filters.map(lambda x: transform(x))) From e2336767ac8a147795275d4fea361a4b76dfd2c3 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 22 Apr 2024 09:01:01 +0200 Subject: [PATCH 002/161] fix: dim of cls token --- src/modalities/models/coca/coca_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index b531cf219..0b9125984 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -160,12 +160,12 @@ def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch vision_embd = self.vision_encoder(inputs)[self.vision_embd_prediction_key] queries = repeat(self.vision_queries, "n d -> b n d", b=vision_embd.shape[0]) vision_embd = self.attn_pool(queries, context=vision_embd) - vision_embd, vision_cls_token = vision_embd[:, :-1, :], vision_embd[:, -1:, :] + vision_embd, vision_cls_token = vision_embd[:, :-1, :], vision_embd[:, -1, :] return vision_embd, vision_cls_token def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: text_embd = self.text_decoder(inputs)[self.text_embd_prediction_key] - text_embd, text_cls_token = text_embd[:, :-1, :], text_embd[:, -1:, :] + text_embd, text_cls_token = text_embd[:, :-1, :], text_embd[:, -1, :] return text_embd, text_cls_token def _forward_decode(self, text_embd: torch.Tensor, vision_embd: torch.Tensor) -> torch.Tensor: From 9986691824e29190be5dcadceffd4745c8fd1b6a Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 22 Apr 2024 09:05:29 +0200 Subject: [PATCH 003/161] feat: simple console logging --- src/modalities/config/config.py | 8 +++++ .../batch_progress_subscriber.py | 36 +++++++++++++++++++ .../subscriber_impl/subscriber_factory.py | 27 ++++++++++++++ src/modalities/registry/components.py | 7 ++++ 4 files changed, 78 insertions(+) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index e24ecc055..a8c573958 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -314,6 +314,14 @@ class DummyProgressSubscriberConfig(BaseModel): pass +class SimpleProgressSubscriberConfig(BaseModel): + train_dataloader: PydanticLLMDataLoaderIFType + eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + world_size: int + global_num_seen_samples: int + local_rank: int + + class RichProgressSubscriberConfig(BaseModel): train_dataloader: PydanticLLMDataLoaderIFType eval_dataloaders: 
Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) diff --git a/src/modalities/logging_broker/subscriber_impl/batch_progress_subscriber.py b/src/modalities/logging_broker/subscriber_impl/batch_progress_subscriber.py index 074726bc3..f20adcf75 100644 --- a/src/modalities/logging_broker/subscriber_impl/batch_progress_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/batch_progress_subscriber.py @@ -15,6 +15,42 @@ def consume_message(self, message: Message[BatchProgressUpdate]): pass +class SimpleProgressSubscriber(MessageSubscriberIF[BatchProgressUpdate]): + def __init__( + self, + train_split_num_samples: Dict[str, int], + eval_splits_num_samples: Dict[str, int], + ) -> None: + self.train_split_num_samples = train_split_num_samples + self.eval_splits_num_samples = eval_splits_num_samples + + def consume_message(self, message: Message[BatchProgressUpdate]): + if not isinstance(message.payload, BatchProgressUpdate): + return + + batch_progress = message.payload + completed_samples = 0 + total_samples = 0 + + [batch_progress.dataloader_tag] + + prefix = "" + if message.payload.experiment_status == ExperimentStatus.TRAIN: + prefix = "Train" + completed_samples = batch_progress.global_train_sample_id + 1 + total_samples = self.train_split_num_samples[batch_progress.dataloader_tag] + + elif message.payload.experiment_status == ExperimentStatus.EVALUATION: + prefix = "Evaluation" + completed_samples = batch_progress.global_dataset_sample_id + 1 + total_samples = self.eval_splits_num_samples[batch_progress.dataloader_tag] + + print( + f"{prefix}[{batch_progress.dataloader_tag}] " + f"[{completed_samples}/{total_samples} ({completed_samples/total_samples:.01f}%)]" + ) + + class RichProgressSubscriber(MessageSubscriberIF[BatchProgressUpdate]): """A subscriber object for the RichProgress observable.""" diff --git a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py index 7c8ad59ba..12d8d79dc 100644 --- a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py +++ b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py @@ -6,6 +6,7 @@ from modalities.logging_broker.subscriber_impl.batch_progress_subscriber import ( DummyProgressSubscriber, RichProgressSubscriber, + SimpleProgressSubscriber, ) from modalities.logging_broker.subscriber_impl.results_subscriber import ( DummyResultSubscriber, @@ -34,6 +35,32 @@ def get_rich_progress_subscriber( subscriber = ProgressSubscriberFactory.get_dummy_progress_subscriber() return subscriber + @staticmethod + def get_simple_progress_subscriber( + train_dataloader: LLMDataLoader, + eval_dataloaders: List[LLMDataLoader], + world_size: int, + global_num_seen_samples: int, + local_rank: int, + ) -> SimpleProgressSubscriber: + if local_rank == 0: + skip_num_local_train_batches = global_num_seen_samples // world_size // train_dataloader.batch_size + train_split_num_samples = { + train_dataloader.dataloader_tag: (len(train_dataloader) + skip_num_local_train_batches) + * world_size + * train_dataloader.batch_size + } + + eval_splits_num_samples = { + dataloader.dataloader_tag: len(dataloader) * world_size * dataloader.batch_size + for dataloader in eval_dataloaders + } + + subscriber = SimpleProgressSubscriber(train_split_num_samples, eval_splits_num_samples) + else: + subscriber = ProgressSubscriberFactory.get_dummy_progress_subscriber() + return subscriber + @staticmethod def get_dummy_progress_subscriber() -> 
DummyProgressSubscriber: return DummyProgressSubscriber() diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index a92580087..3e0b98168 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -45,6 +45,7 @@ RichResultSubscriberConfig, SaveEveryKStepsCheckpointingStrategyConfig, SaveKMostRecentCheckpointsStrategyConfig, + SimpleProgressSubscriberConfig, StepLRSchedulerConfig, TorchCheckpointLoadingConfig, WandBEvaluationResultSubscriberConfig, @@ -176,6 +177,12 @@ class ComponentEntity: ProgressSubscriberFactory.get_dummy_progress_subscriber, DummyProgressSubscriberConfig, ), + ComponentEntity( + "progress_subscriber", + "simple", + ProgressSubscriberFactory.get_simple_progress_subscriber, + SimpleProgressSubscriberConfig, + ), ComponentEntity( "progress_subscriber", "rich", From c47b6c1aa222f0a5189b02394e89a7dc6debfa61 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 22 Apr 2024 09:09:23 +0200 Subject: [PATCH 004/161] fix: add attention mask to cross entropy loss --- src/modalities/config/config.py | 5 ----- src/modalities/loss_functions.py | 26 +++++++++++++++++++++++--- src/modalities/models/coca/collator.py | 14 ++++++++++---- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index a8c573958..057095be0 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -59,11 +59,6 @@ class ReferenceConfig(BaseModel): pass_type: PassType -class CLMCrossEntropyLossConfig(BaseModel): - target_key: str - prediction_key: str - - # Checkpointing class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): k: PositiveInt diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 54d8de36b..0f3cd0c7a 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod import torch -from torch.nn import CrossEntropyLoss +from pydantic import BaseModel +from torch.nn import CrossEntropyLoss as TorchCrossEntropyLoss from modalities.batch import InferenceResultBatch @@ -23,16 +24,27 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: raise NotImplementedError -class CLMCrossEntropyLoss(Loss): +class CrossEntropyLossConfig(BaseModel): + target_key: str + prediction_key: str + tag: str = "CLMCrossEntropyLoss" + + +class CrossEntropyLoss(Loss): def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEntropyLoss"): super().__init__(tag) self.target_key = target_key self.prediction_key = prediction_key # Mean over the tokens in the local-batch (batch per rank) - self.loss_fun = CrossEntropyLoss(reduction="mean") + self.loss_fun = TorchCrossEntropyLoss(reduction="mean") def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: labels = forward_batch.get_targets(self.target_key) + + if "attention_mask" in forward_batch.targets: + attention_mask = forward_batch.get_targets("attention_mask") + labels[attention_mask == 0] = -100 + lm_logits = forward_batch.get_predictions(self.prediction_key) # move labels to correct device to enable model parallelism @@ -79,6 +91,14 @@ def nce_loss( return torch.mean(denominator - numerator) # calculated in log space +class NCELossConfig(BaseModel): + prediction_key1: str + prediction_key2: str + is_asymmetric: bool = True + temperature: float = 1.0 + tag: str = "NCELoss" + + class NCELoss(Loss): def __init__( self, diff --git 
a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index 0c9584ca9..42b40ee03 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -34,11 +34,17 @@ def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: samples = { sample_key: torch.stack([torch.tensor(d[sample_key]) for d in batch]) for sample_key in self.sample_keys } - targets = { - target_key: torch.stack([torch.tensor(d[target_key]) for d in batch]) for target_key in self.target_keys - } + if "attention_mask" in batch[0]: + samples["attention_mask"] = torch.stack([torch.tensor(d["attention_mask"]) for d in batch]) + + targets = {target_key: torch.stack([d[target_key] for d in batch]) for target_key in self.target_keys} # Create target for text input targets[self.text_target_key] = samples[self.text_sample_key][:, 1:].clone().detach() - samples[self.text_sample_key] = samples[self.text_sample_key][:, :-1].clone().detach() + samples[self.text_sample_key] = samples[self.text_sample_key][:, :-1] + + if "attention_mask" in batch[0]: + targets["attention_mask"] = samples["attention_mask"][:, 1:].clone().detach() + samples["attention_mask"] = samples["attention_mask"][:, :-1] + return DatasetBatch(targets=targets, samples=samples) From 70823e10764c5d6dcc92d17329ccf21926884add Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 22 Apr 2024 09:10:53 +0200 Subject: [PATCH 005/161] feat: allow multiple loss functions --- src/modalities/config/config.py | 14 ++--- src/modalities/config/instantiation_models.py | 4 +- src/modalities/evaluator.py | 61 +++++++++++++------ src/modalities/loss_functions.py | 3 +- src/modalities/trainer.py | 61 +++++++++++++------ 5 files changed, 97 insertions(+), 46 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 057095be0..d13a1b41a 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Annotated, Dict, List, Literal, Optional, Tuple +from typing import Annotated, Dict, List, Literal, Optional, Tuple, Union import torch from omegaconf import OmegaConf @@ -148,12 +148,12 @@ class OneCycleLRSchedulerConfig(BaseModel): pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)] anneal_strategy: str cycle_momentum: bool = True - base_momentum: Annotated[float, Field(strict=True, gt=0)] | List[ - Annotated[float, Field(strict=True, gt=0.0)] - ] = 0.85 - max_momentum: Annotated[float, Field(strict=True, gt=0.0)] | List[ - Annotated[float, Field(strict=True, gt=0.0)] - ] = 0.95 + base_momentum: Annotated[float, Field(strict=True, gt=0)] | List[Annotated[float, Field(strict=True, gt=0.0)]] = ( + 0.85 + ) + max_momentum: Annotated[float, Field(strict=True, gt=0.0)] | List[Annotated[float, Field(strict=True, gt=0.0)]] = ( + 0.95 + ) div_factor: Annotated[float, Field(strict=True, gt=0.0)] final_div_factor: Annotated[float, Field(strict=True, gt=0.0)] three_phase: bool = False diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index a77e626a0..bd2c99b85 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Annotated, Dict, List, Optional +from typing import Annotated, Dict, List, Optional, Union from pydantic import BaseModel, Field, FilePath, field_validator @@ -49,7 +49,7 @@ class 
Paths(BaseModel): wrapped_model: PydanticPytorchModuleType optimizer: PydanticOptimizerIFType scheduler: PydanticLRSchedulerIFType - loss_fn: PydanticLossIFType + loss_fn: Union[PydanticLossIFType, List[PydanticLossIFType]] train_dataloader: PydanticLLMDataLoaderIFType eval_dataloaders: List[PydanticLLMDataLoaderIFType] batch_progress_subscriber: PydanticMessageSubscriberIFType diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 0c6b31ed8..300820625 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -4,11 +4,12 @@ import torch.distributed as dist import torch.nn as nn -from modalities.batch import DatasetBatch, EvaluationResultBatch, InferenceResultBatch +from modalities.batch import DatasetBatch, EvaluationResultBatch from modalities.dataloader.dataloader import LLMDataLoader from modalities.logging_broker.messages import BatchProgressUpdate, ExperimentStatus, MessageTypes from modalities.logging_broker.publisher import MessagePublisher -from modalities.models.model import model_predict_batch +from modalities.loss_functions import Loss +from modalities.models.model import NNModel, model_predict_batch from modalities.running_env.fsdp.reducer import Reducer from modalities.trainer import ThroughputAggregationKeys from modalities.util import Aggregator, TimeRecorder @@ -29,18 +30,34 @@ def evaluate_batch( self, batch: DatasetBatch, model: nn.Module, - loss_fun: Callable[[InferenceResultBatch], torch.Tensor], + loss_fun: List[Loss], ): with torch.no_grad(): result_batch = model_predict_batch(model=model, batch=batch) - loss = loss_fun(result_batch) - return loss + + total_loss = None + losses = [] + for lfn in loss_fun: + # Calculate loss + loss = lfn(result_batch) + + # Add loss to total loss + weighted_loss = loss * lfn.weight # / self.gradient_acc_steps + if total_loss is None: + total_loss = weighted_loss + else: + total_loss += weighted_loss + + # Append individual losses (for logging) + losses.append(loss) + + return total_loss, *losses def evaluate( self, model: nn.Module, data_loaders: List[LLMDataLoader], - loss_fun: Callable[[InferenceResultBatch], torch.Tensor], + loss_fun: List[Loss], train_step_id: int, ) -> Dict[str, EvaluationResultBatch]: result_dict: Dict[str, EvaluationResultBatch] = {} @@ -49,7 +66,7 @@ def evaluate( device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu") for data_loader in data_loaders: - cumulated_loss = torch.zeros(3).to(device) + cumulated_loss = torch.zeros(len(loss_fun) + 1 + 1).to(device) # total loss, indidual losses, count Evaluator._publish_progress( batch_progress_publisher=self.batch_progress_publisher, @@ -59,14 +76,17 @@ def evaluate( thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() with TimeRecorder() as forward_backward_timer_recorder: for batch_id, batch in enumerate(data_loader): - batch_loss = self.evaluate_batch( + batch_losses = self.evaluate_batch( batch=batch, model=model, loss_fun=loss_fun, ) - cumulated_loss[0] += batch_loss.item() # sum up batch loss - cumulated_loss[1] += 1 + # Accumulate losses + for i, batch_loss in enumerate(batch_losses): + cumulated_loss[i] += batch_loss.item() + cumulated_loss[-1] += 1 + batch_length_tensor = torch.tensor(len(batch)).to(device) thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) @@ -75,13 +95,8 @@ def evaluate( eval_step_id=batch_id, dataloader_tag=data_loader.dataloader_tag, ) - # TODO: insert reducer from outside so Evaluator is independent of FSDP - 
total_loss = Reducer.reduce( - tensor=cumulated_loss, - operation=dist.ReduceOp.SUM, - post_processing_fun=lambda t: t[0] / t[1], - ) + # TODO: insert reducer from outside so Evaluator is independent of FSDP forward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to(device) thoughput_aggregator.add_value( key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time @@ -92,8 +107,20 @@ def evaluate( ) num_samples_per_second = synced_num_samples / synced_forward_backward_time + # Agreggate loss from all ranks + total_losses = Reducer.reduce( + tensor=cumulated_loss, + operation=dist.ReduceOp.SUM, + post_processing_fun=lambda t: t[:-1] / t[-1], + ) + + # Fill logging dict with total loss and the individual losses + losses = {"total_loss": total_losses[0]} + for i, lfn in enumerate(loss_fun): + losses[lfn.tag] = total_losses[i + 1] + evaluation_result = EvaluationResultBatch( - losses={loss_fun.tag: total_loss}, + losses=losses, # TODO: hardcoded metric key throughput_metrics={"evaluation_num_samples_per_second": num_samples_per_second}, dataloader_tag=data_loader.dataloader_tag, diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 0f3cd0c7a..923cf9ae2 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -8,8 +8,9 @@ class Loss(ABC): - def __init__(self, tag: str): + def __init__(self, tag: str, weight: float = 1.0): self._tag = tag + self.weight = weight @property def tag(self) -> str: diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index c77f45ba2..4a2bdb3de 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Callable, Tuple +from typing import Callable, List, Tuple import torch import torch.distributed as dist @@ -45,22 +45,38 @@ def _train_batch( model: FSDP, optimizer: Optimizer, scheduler: LRScheduler, - loss_fun: Loss, + loss_fun: List[Loss], train_step_id: int, data_loader: LLMDataLoader, ) -> Tuple[torch.Tensor, torch.Tensor]: result_batch = model_predict_batch(model=model, batch=batch) - loss = loss_fun(result_batch) - (loss / self.gradient_acc_steps).backward() + + total_loss = None + losses = [] + for lfn in loss_fun: + # Calculate loss + loss = lfn(result_batch) + + # Add loss to total loss + weighted_loss = (loss * lfn.weight) / self.gradient_acc_steps + if total_loss is None: + total_loss = weighted_loss + else: + total_loss += weighted_loss + + # Append individual losses (for logging) + losses.append(loss) + + (total_loss / self.gradient_acc_steps).backward() + self.gradient_clipper(model) if (train_step_id + 1) % self.gradient_acc_steps == 0 or (train_step_id + 1) == len(data_loader): gradient_norm_score = self.gradient_clipper.clip_gradients().sum() optimizer.step() scheduler.step() optimizer.zero_grad() - return loss, gradient_norm_score - else: - return loss, None + return total_loss, *losses, gradient_norm_score + return total_loss, *losses, None def train( self, @@ -80,6 +96,8 @@ def train( device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu") + cumulated_loss = torch.zeros(len(loss_fun) + 1 + 1).to(device) + # batch loop batch: DatasetBatch # TODO: why do we need a barrier here? 
@@ -91,7 +109,7 @@ def train( # Because we might resume training, we add the starting batch id of the data loader train_step_id = batch_id + train_loader.fast_forward_batch_id # Train single batch - batch_loss, gradient_norm_score = self._train_batch( + *batch_losses, gradient_norm_score = self._train_batch( batch=batch, model=model, optimizer=optimizer, @@ -102,7 +120,8 @@ def train( ) forward_backward_time_recorder.stop() # Save the batch loss - cumulated_losses[0] += batch_loss.item() + for i, batch_loss in enumerate(batch_losses): + cumulated_loss[i] += batch_loss.item() # This works, because we always drop the last batch in case it has less samples than the batch size cumulated_losses[-1] += 1 # number of local batches @@ -141,17 +160,21 @@ def train( operation=dist.ReduceOp.SUM, # 1.) summed batch loss / (num batches * world size) # 2.) last batch loss / world size - post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]), + post_processing_fun=lambda t: torch.stack([t[:-1] / t[-1], t[-1] / dist.get_world_size()]), ) train_loss_avg, train_loss_last_batch = ( reduced_losses[0], - reduced_losses[1], + reduced_losses[-1], ) + losses = { - f"{loss_fun.tag} average": train_loss_avg, - f"{loss_fun.tag} last step": train_loss_last_batch, + f"total_loss average": train_loss_avg, + f"total_loss last step": train_loss_last_batch, } + for i, lfn in enumerate(loss_fun): + losses[lfn.tag] = reduced_losses[i + 1] + if len(gradient_norm_scores) > 0: metrics = { "grad_norm_avg": torch.mean(torch.Tensor(gradient_norm_scores)), @@ -186,19 +209,19 @@ def train( evaluation_callback(train_step_id=train_step_id) checkpointing_callback(train_step_id=train_step_id) + cumulated_loss = torch.zeros(len(loss_fun) + 1 + 1).to(device) # we start the time recoder here again to also capture the time spend loading # via the dataloader. forward_backward_time_recorder.start() - def _reset_tracked_losses(self): + def _reset_loss(self): # TODO: we should handle the device assignment more centrally. 
- # summed lcoal losses, loss of last local batch, number of local batches (i.e., number of steps) - cumulated_loss_and_gradient_norm = torch.zeros(3) + cumulated_loss = torch.zeros(2) if torch.cuda.is_available(): - cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to(torch.device(self.local_rank)) + cumulated_loss = cumulated_loss.to(torch.device(self.local_rank)) else: - cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to("cpu") - return cumulated_loss_and_gradient_norm + cumulated_loss = cumulated_loss.to("cpu") + return cumulated_loss @staticmethod def _publish_progress( From 0c87d91e961b2a6e305204d7152f2b438ada8470 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 22 Apr 2024 09:11:53 +0200 Subject: [PATCH 006/161] fix: register nce loss --- src/modalities/registry/components.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 3e0b98168..b158e24e8 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -21,7 +21,6 @@ CheckpointedModelConfig, CheckpointedOptimizerConfig, CheckpointSavingConfig, - CLMCrossEntropyLossConfig, ConstantLRSchedulerConfig, CosineAnnealingLRSchedulerConfig, DistributedSamplerConfig, @@ -57,7 +56,7 @@ ProgressSubscriberFactory, ResultsSubscriberFactory, ) -from modalities.loss_functions import CLMCrossEntropyLoss +from modalities.loss_functions import CrossEntropyLoss, CrossEntropyLossConfig, NCELoss, NCELossConfig from modalities.models.coca.coca_model import CoCa, CoCaConfig from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig @@ -101,7 +100,8 @@ class ComponentEntity: ComponentEntity("model", "fsdp_wrapped", ModelFactory.get_fsdp_wrapped_model, FSDPWrappedModelConfig), ComponentEntity("model", "coca", CoCa, CoCaConfig), # losses - ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), + ComponentEntity("loss", "cross_entropy_loss", CrossEntropyLoss, CrossEntropyLossConfig), + ComponentEntity("loss", "nce_loss", NCELoss, NCELossConfig), # optmizers ComponentEntity("optimizer", "adam", OptimizerFactory.get_adam, AdamOptimizerConfig), ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig), From b652a7d2f901b6f9369e28ec01c8d1f9c1fc1599 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 22 Apr 2024 09:13:14 +0200 Subject: [PATCH 007/161] feat: add dataloader for webdataset --- src/modalities/dataloader/dataloader.py | 49 ++++++++++++++++- .../dataloader/dataloader_factory.py | 16 +++++- src/modalities/dataloader/dataset.py | 55 ++++++++++++++++--- src/modalities/dataloader/dataset_factory.py | 33 ++++++++++- src/modalities/registry/components.py | 5 +- 5 files changed, 146 insertions(+), 12 deletions(-) diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index 43f7ab33a..471bc464a 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -1,12 +1,17 @@ from typing import Iterable, Optional, Union +import webdataset as wd from torch.utils.data import Dataset, DistributedSampler, Sampler from torch.utils.data.dataloader import DataLoader, T_co, _collate_fn_t, _worker_init_fn_t from modalities.dataloader.samplers import ResumableBatchSampler -class 
LLMDataLoader(DataLoader[T_co]): +class DataLoaderIF: + pass + + +class LLMDataLoader(DataLoader[T_co], DataLoaderIF): def __init__( self, dataloader_tag: str, @@ -141,3 +146,45 @@ def fast_forward_batch_id(self) -> int: def __len__(self) -> int: return self.num_epochs * len(self.dataloader) + + +class WebLoader(DataLoaderIF): + def __init__(self, dataloader_tag: str, dataset: Dataset[T_co], batch_size: Optional[int] = 1, *args, **kwargs): + self.num_batches = len(dataset) // batch_size + self.webloader = wd.WebLoader(dataset=dataset, batch_size=None) + # self.webloader = self.webloader.unbatched().shuffle(1000).batched(batch_size) + self.webloader = self.webloader.with_epoch(1282 * 100 // batch_size) + self.dataloader_tag = dataloader_tag + self.batch_size = batch_size + + def __len__(self): + return self.num_batches + + def __iter__(self): + return iter(self.webloader) + + @property + def batch_size(self) -> int: + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int): + self._batch_size = value + + @property + def fast_forward_sample_id(self) -> int: + """The sample id until which we fast-forward, as specified in the ResumableBatchSampler. + + Returns: + int: fast forward sample id + """ + return 0 # self.batch_size * self.batch_sampler.start_index + + @property + def fast_forward_batch_id(self) -> int: + """The batch id until which we fast-forward, as specified in the ResumableBatchSampler. + + Returns: + int: fast forward batch id + """ + return 0 # self.batch_sampler.start_index diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index bbb3d798c..0b533aabe 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -3,7 +3,7 @@ from torch.utils.data import BatchSampler from torch.utils.data.dataset import Dataset -from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader +from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader, WebLoader from modalities.dataloader.samplers import ResumableBatchSampler @@ -38,3 +38,17 @@ def get_repeating_dataloader( ) -> RepeatingDataLoader: dataloader = RepeatingDataLoader(dataloader, num_epochs, reshuffle_after_epoch) return dataloader + + @staticmethod + def get_web_loader( + dataloader_tag: str, + dataset: Dataset, + batch_size: int, + collate_fn: Callable, + num_workers: int, + ) -> WebLoader: + dataset = dataset.batched(batch_size, collation_fn=collate_fn) + dataloader = WebLoader( + dataloader_tag=dataloader_tag, dataset=dataset, batch_size=batch_size, num_workers=num_workers + ) + return dataloader diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 2bb09d88c..19eae2289 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -16,6 +16,8 @@ from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper +from modalities.config.config import PydanticTokenizerIFType + from ..dataloader.large_file_lines_reader import LargeFileLinesReader from .create_packed_data import EmbeddedStreamData @@ -243,17 +245,54 @@ class ImageTransformConfig(BaseModel): class WebDatasetConfig(BaseModel): urls: Union[List[str], str] - key_mapping: Optional[Dict[str, str]] = None - image_preprocessing: ImageTransformConfig = ImageTransformConfig() + source_image_key: str + image_key: str + source_text_key: str + text_key: str + tokenizer: PydanticTokenizerIFType + block_size: int + num_samples: 
int + image_transform_config: Optional[ImageTransformConfig] = None class WebDataset(wds.WebDataset): def __init__( - self, urls: Union[List[str], str], key_mapping: Dict[str, str], image_transform_config: ImageTransformConfig + self, + urls: Union[List[str], str], + source_image_key: str, + image_key: str, + source_text_key: str, + text_key: str, + tokenizer: PreTrainedTokenizer, + block_size: int, + num_samples: int, + image_transform_config: ImageTransformConfig, ): super().__init__(urls=urls) - if key_mapping is not None: - self.append(wds.filters.map(lambda x: {key_mapping[k]: v for k, v in x.items() if k in key_mapping.keys()})) - if image_transform_config is not None: - transform = create_transform(**image_transform_config.model_dump()) - self.append(wds.filters.map(lambda x: transform(x))) + self.num_samples = num_samples + + self.append(wds.filters.shuffle(1000)) + self.append(wds.filters.decode("pil")) + + transform = create_transform(**image_transform_config.model_dump()) + + def make_sample(sample): + # print(sample["json"]) + batch_encoding: BatchEncoding = tokenizer( + sample["json"]["text0"], # [source_text_key], + max_length=block_size, + padding="max_length", + truncation=True, + return_attention_mask=True, + ) + + return { + image_key: transform(sample[source_image_key]), + text_key: batch_encoding.input_ids, + "attention_mask": batch_encoding.attention_mask, + } + + self.append(wds.filters.map(make_sample)) + + def __len__(self): + return self.num_samples diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py index 1d31e27cf..9a4a1fe48 100644 --- a/src/modalities/dataloader/dataset_factory.py +++ b/src/modalities/dataloader/dataset_factory.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union from pydantic import FilePath from torch.utils.data.dataset import Dataset @@ -8,9 +8,11 @@ from modalities.dataloader.dataset import ( DummyDataset, DummySampleConfig, + ImageTransformConfig, MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron, + WebDataset, ) from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset @@ -90,3 +92,32 @@ def get_open_gptx_mmap_dataset( # TODO: Fix the OpenGPTX implementation and get rid of this hack. dataset_wrapped = OpenGPTXDatasetWrapper(open_gptx_dataset=dataset, num_samples=num_samples) return dataset_wrapped + + @staticmethod + def get_web_dataset( + urls: Union[List[str], str], + source_image_key: str, + image_key: str, + source_text_key: str, + text_key: str, + tokenizer: PreTrainedTokenizer, + block_size: int, + num_samples: int, + image_transform_config: Optional[ImageTransformConfig] = None, + ) -> WebDataset: + # TODO this was part of the old Dataloader implementation. + # we need to check if this is actually wanted generally. 
+ tokenizer.pad_token = tokenizer.eos_token + + dataset = WebDataset( + urls=urls, + source_image_key=source_image_key, + image_key=image_key, + source_text_key=source_text_key, + text_key=text_key, + tokenizer=tokenizer, + block_size=block_size, + num_samples=num_samples, + image_transform_config=image_transform_config, + ) + return dataset diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index b158e24e8..34450d74a 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -48,9 +48,10 @@ StepLRSchedulerConfig, TorchCheckpointLoadingConfig, WandBEvaluationResultSubscriberConfig, + WebLoaderConfig, ) from modalities.dataloader.dataloader_factory import DataloaderFactory -from modalities.dataloader.dataset import DummyDatasetConfig +from modalities.dataloader.dataset import DummyDatasetConfig, WebDatasetConfig from modalities.dataloader.dataset_factory import DatasetFactory from modalities.logging_broker.subscriber_impl.subscriber_factory import ( ProgressSubscriberFactory, @@ -138,6 +139,7 @@ class ComponentEntity: "dataset", "open_gptx_mmap_dataset", DatasetFactory.get_open_gptx_mmap_dataset, OpenGPTXMMapDatasetConfig ), ComponentEntity("dataset", "dummy_dataset", DatasetFactory.get_dummy_dataset, DummyDatasetConfig), + ComponentEntity("dataset", "web_dataset", DatasetFactory.get_web_dataset, WebDatasetConfig), # samplers ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), # batch samplers @@ -147,6 +149,7 @@ class ComponentEntity: ComponentEntity("collate_fn", "coca_collator", CoCaCollatorFn, CoCaCollateFnConfig), # data loaders ComponentEntity("data_loader", "default", DataloaderFactory.get_dataloader, LLMDataLoaderConfig), + ComponentEntity("data_loader", "web_loader", DataloaderFactory.get_web_loader, WebLoaderConfig), ComponentEntity( "data_loader", "repeating_data_loader", DataloaderFactory.get_repeating_dataloader, RepeatingDataLoaderConfig ), From b0e933a1764273381b143a5950fa66d642b2e6bf Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 22 Apr 2024 09:15:41 +0200 Subject: [PATCH 008/161] chore: add config --- .../config_example_coca_webdataset.yaml | 273 ++++++++++++++++++ start.sh | 3 + 2 files changed, 276 insertions(+) create mode 100644 config_files/config_example_coca_webdataset.yaml create mode 100644 start.sh diff --git a/config_files/config_example_coca_webdataset.yaml b/config_files/config_example_coca_webdataset.yaml new file mode 100644 index 000000000..b2bd0c60a --- /dev/null +++ b/config_files/config_example_coca_webdataset.yaml @@ -0,0 +1,273 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + callback_interval_in_samples: 32 + global_num_training_samples: 1281990 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 4 # This is the batch size per rank? 
+ sequence_length: 256 + gradient_clipping: + mode: p2_norm + threshold: 1.0 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +tokenizer: + component_key: tokenizer + variant_key: gpt2_tokenizer_fast + config: + tokenizer_file: data/tokenizer/tokenizer_gpt2.json + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - images + - ${settings.referencing_keys.sample_key} + target_keys: [] + text_sample_key: ${settings.referencing_keys.sample_key} + text_target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: web_dataset + config: + urls: "coco_captions/data/train/{000000..000011}.tar" + source_image_key: jpg + image_key: images + source_text_key: txt + text_key: input_ids + block_size: ${settings.training.sequence_length} + num_samples: 1_281_000 + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + image_transform_config: + is_training: True + +val_dataset: + component_key: dataset + variant_key: dummy_dataset + config: + num_samples: 4 + sample_definition: + - sample_key: images + sample_shape: [3, 224, 224] + sample_type: float + - sample_key: input_ids + sample_shape: [1024] + sample_type: int + +train_dataloader: + component_key: data_loader + variant_key: web_loader + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_size: ${settings.training.local_train_micro_batch_size} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, VisionTransformerBlock] + +captioning_loss: + component_key: loss + variant_key: cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${model.config.prediction_key} + tag: captioning_loss + +contrastive_loss: + component_key: loss + variant_key: nce_loss + config: + prediction_key1: ${model.config.vision_cls_prediction_key} + prediction_key2: ${model.config.text_cls_prediction_key} + tag: 
contrastive_loss + +loss_fn: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss + pass_type: BY_REFERENCE + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, VisionTransformerBlock] + +model: + component_key: model + variant_key: coca + config: + prediction_key: logits + vision_embd_prediction_key: vision_embeddings + text_embd_prediction_key: text_embeddings + vision_cls_prediction_key: vision_cls + text_cls_prediction_key: text_cls + vision_encoder_config: + sample_key: images + prediction_key: vision_embeddings + img_size: 224 + n_classes: Null # Disable vision transformer head + n_layer: 4 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + text_decoder_config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${model.config.prediction_key} + block_size: 1024 + vocab_size: 50304 + n_layer_text: 4 + n_layer_multimodal_text: 4 + attention_config: + attention_engine_type: default_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 + n_pool_head: 8 + n_vision_queries: 256 + bias_attn_pool: False + epsilon_attn_pool: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 + +scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: 4000000 + pct_start: 0.01 + anneal_strategy: cos + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: dummy + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: ONLINE + experiment_id: ${settings.experiment_id} + directory: "." 
diff --git a/start.sh b/start.sh new file mode 100644 index 000000000..1d8f5b66f --- /dev/null +++ b/start.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 1 --nproc_per_node 1 --rdzv-endpoint=0.0.0.0:29502 src/modalities/__main__.py run --config_file_path config_files/config_example_coca_webdataset.yaml \ No newline at end of file From d8d5a5f8936dcd8a86e0640c8636e087706312f3 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 7 May 2024 12:41:08 +0200 Subject: [PATCH 009/161] feat: add nicer logging to wandb --- .../subscriber_impl/results_subscriber.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index 74f5797a2..8a5d8dc4a 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -89,6 +89,15 @@ def consume_message(self, message: Message[EvaluationResultBatch]): wandb.log(data=throughput_metrics, step=eval_result.train_step_id + 1) - # wandb.log({"tokens_loss": wandb.plot.scatter("num_tokens", "loss", title="Tokens vs Loss")}) - # wandb.log({"steps_loss": wandb.plot.scatter("steps_loss", "loss", title="Steps vs Loss")}) - # wandb.log({"samples_loss": wandb.plot.scatter("samples_loss", "loss", title="Samples vs Loss")}) + num_samples = eval_result.train_step_id + 1 + group_content = [f"Train [{num_samples}]:"] + + losses = [f"{k}: {v}" for k, v in losses.items()] + metrics = [f"{k}: {v}" for k, v in metrics.items()] + + if losses: + group_content.append(" ".join(losses)) + if metrics: + group_content.append(" ".join(metrics)) + + print(" ".join(group_content)) From 4ea65c8384fda04ffd8cbd3567e1c2e180e5b44e Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 7 May 2024 12:41:49 +0200 Subject: [PATCH 010/161] fix: hardcoded batches in web loader --- src/modalities/dataloader/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index 471bc464a..0a192553e 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -153,7 +153,7 @@ def __init__(self, dataloader_tag: str, dataset: Dataset[T_co], batch_size: Opti self.num_batches = len(dataset) // batch_size self.webloader = wd.WebLoader(dataset=dataset, batch_size=None) # self.webloader = self.webloader.unbatched().shuffle(1000).batched(batch_size) - self.webloader = self.webloader.with_epoch(1282 * 100 // batch_size) + self.webloader = self.webloader.with_epoch(self.num_batches) self.dataloader_tag = dataloader_tag self.batch_size = batch_size From dfe88c945233f548cc79b5200783557ec0597455 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 7 May 2024 12:43:25 +0200 Subject: [PATCH 011/161] chore: update coca config --- .../training/config_example_coca.yaml | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index fcc886f9d..946a340ff 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -5,14 +5,14 @@ settings: sample_key: input_ids target_key: target_ids training: - global_training_log_interval_in_steps: 2 - global_checkpointing_interval_in_steps: 2 - global_evaluation_interval_in_steps: 2 - 
global_num_training_samples: 12 + global_training_log_interval_in_steps: 64 + global_checkpointing_interval_in_steps: 5_000 + global_evaluation_interval_in_steps: 5_000 + global_num_training_samples: 10_000 global_num_seen_steps: 0 do_apply_activation_checkpointing: true gradient_acc_steps: 1 - local_train_micro_batch_size: 3 + local_train_micro_batch_size: 16 sequence_length: 256 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} @@ -42,26 +42,26 @@ train_dataset: component_key: dataset variant_key: dummy_dataset config: - num_samples: 4 + num_samples: 10_000 sample_definition: - sample_key: images sample_shape: [3, 224, 224] sample_type: float - sample_key: input_ids - sample_shape: [1024] + sample_shape: [256] sample_type: int val_dataset: component_key: dataset variant_key: dummy_dataset config: - num_samples: 4 + num_samples: 1_000 sample_definition: - sample_key: images sample_shape: [3, 224, 224] sample_type: float - sample_key: input_ids - sample_shape: [1024] + sample_shape: [256] sample_type: int train_dataloader: @@ -251,7 +251,6 @@ gradient_clipper: pass_type: BY_REFERENCE norm_type: P2_NORM - batch_progress_subscriber: component_key: progress_subscriber variant_key: rich @@ -275,4 +274,4 @@ evaluation_subscriber: mode: OFFLINE experiment_id: ${settings.experiment_id} directory: "." - config_file_path: ${settings.config_file_path} \ No newline at end of file + config_file_path: ${settings.config_file_path} From 5a3e84454da2953592ae6cb548935a6b306d3d40 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 7 May 2024 14:45:28 +0200 Subject: [PATCH 012/161] fix: rebase --- .../training/config_example_coca.yaml | 33 +++++++--- .../config_example_coca_webdataset.yaml | 61 +++++++++++-------- src/modalities/config/config.py | 22 ++++--- src/modalities/config/pydanctic_if_types.py | 4 +- src/modalities/dataloader/dataset.py | 9 +-- src/modalities/dataloader/dataset_factory.py | 4 -- src/modalities/trainer.py | 24 ++------ 7 files changed, 88 insertions(+), 69 deletions(-) rename config_files/{ => training}/config_example_coca_webdataset.yaml (84%) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index 946a340ff..ee61b77cd 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -5,7 +5,7 @@ settings: sample_key: input_ids target_key: target_ids training: - global_training_log_interval_in_steps: 64 + global_training_log_interval_in_steps: 4 global_checkpointing_interval_in_steps: 5_000 global_evaluation_interval_in_steps: 5_000 global_num_training_samples: 10_000 @@ -148,12 +148,27 @@ checkpoint_saving: sharding_strategy: FULL_SHARD block_names: [TransformerBlock, VisionTransformerBlock] -loss_fn: +captioning_loss: component_key: loss - variant_key: clm_cross_entropy_loss + variant_key: cross_entropy_loss config: target_key: ${settings.referencing_keys.target_key} - prediction_key: logits + prediction_key: ${model.config.prediction_key} + tag: captioning_loss + +contrastive_loss: + component_key: loss + variant_key: nce_loss + config: + prediction_key1: ${model.config.vision_cls_prediction_key} + prediction_key2: ${model.config.text_cls_prediction_key} + tag: contrastive_loss + +loss_fn: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss + pass_type: BY_REFERENCE wrapped_model: component_key: model @@ -181,7 +196,7 @@ model: prediction_key: vision_embeddings img_size: 224 n_classes: Null # Disable vision transformer head - 
n_layer: 12 + n_layer: 4 attention_config: attention_engine_type: default_attention n_head: 12 @@ -194,11 +209,11 @@ model: bias: True text_decoder_config: sample_key: ${settings.referencing_keys.sample_key} - prediction_key: ${loss_fn.config.prediction_key} + prediction_key: ${model.config.prediction_key} block_size: 1024 vocab_size: 50304 - n_layer_text: 12 - n_layer_multimodal_text: 12 + n_layer_text: 4 + n_layer_multimodal_text: 4 attention_config: attention_engine_type: default_attention n_head: 12 @@ -226,7 +241,7 @@ scheduler: max_lr: 6e-4 div_factor: 10 final_div_factor: 1 - total_steps: 4 + total_steps: 625 pct_start: 0.01 anneal_strategy: cos diff --git a/config_files/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml similarity index 84% rename from config_files/config_example_coca_webdataset.yaml rename to config_files/training/config_example_coca_webdataset.yaml index b2bd0c60a..d2fd7032a 100644 --- a/config_files/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -1,19 +1,19 @@ settings: experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} referencing_keys: sample_key: input_ids target_key: target_ids training: - callback_interval_in_samples: 32 - global_num_training_samples: 1281990 - global_num_seen_samples: 0 + global_training_log_interval_in_steps: 4 + global_checkpointing_interval_in_steps: 5_000 + global_evaluation_interval_in_steps: 5_000 + global_num_training_samples: 10_000 + global_num_seen_steps: 0 do_apply_activation_checkpointing: true gradient_acc_steps: 1 - local_train_micro_batch_size: 4 # This is the batch size per rank? + local_train_micro_batch_size: 16 sequence_length: 256 - gradient_clipping: - mode: p2_norm - threshold: 1.0 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} @@ -23,9 +23,11 @@ settings: tokenizer: component_key: tokenizer - variant_key: gpt2_tokenizer_fast + variant_key: pretrained_hf_tokenizer config: - tokenizer_file: data/tokenizer/tokenizer_gpt2.json + pretrained_model_name_or_path: data/tokenizer/hf_gpt2 + padding: false + max_length: 256 collate_fn: component_key: collate_fn @@ -48,7 +50,7 @@ train_dataset: source_text_key: txt text_key: input_ids block_size: ${settings.training.sequence_length} - num_samples: 1_281_000 + num_samples: 100_000 tokenizer: instance_key: tokenizer pass_type: BY_REFERENCE @@ -59,13 +61,13 @@ val_dataset: component_key: dataset variant_key: dummy_dataset config: - num_samples: 4 + num_samples: 1_000 sample_definition: - sample_key: images sample_shape: [3, 224, 224] sample_type: float - sample_key: input_ids - sample_shape: [1024] + sample_shape: [256] sample_type: int train_dataloader: @@ -100,7 +102,6 @@ val_dataloader: variant_key: default config: batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: false sampler: component_key: sampler variant_key: distributed_sampler @@ -109,7 +110,7 @@ val_dataloader: num_replicas: ${settings.cuda_env.world_size} shuffle: false dataset: - instance_key: val_dataset + instance_key: train_dataset pass_type: BY_REFERENCE collate_fn: instance_key: collate_fn @@ -119,18 +120,18 @@ eval_dataloaders: - instance_key: val_dataloader pass_type: BY_REFERENCE -checkpointing: - component_key: checkpointing +checkpoint_saving: + component_key: checkpoint_saving variant_key: default config: - checkpointing_strategy: - component_key: checkpointing_strategy + checkpoint_saving_strategy: + 
component_key: checkpoint_saving_strategy variant_key: save_k_most_recent_checkpoints_strategy config: k: -1 # -1 to save all checkpoints - checkpointing_execution: - component_key: checkpointing_execution - variant_key: fsdp_to_disc_checkpointing + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: fsdp config: checkpoint_path: ${settings.paths.checkpointing_path} global_rank: ${settings.cuda_env.global_rank} @@ -232,7 +233,7 @@ scheduler: max_lr: 6e-4 div_factor: 10 final_div_factor: 1 - total_steps: 4000000 + total_steps: 625 pct_start: 0.01 anneal_strategy: cos @@ -248,13 +249,22 @@ optimizer: instance_key: wrapped_model pass_type: BY_REFERENCE +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp_logging_only + config: + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + batch_progress_subscriber: component_key: progress_subscriber - variant_key: dummy + variant_key: rich config: local_rank: ${settings.cuda_env.local_rank} world_size: ${settings.cuda_env.world_size} - global_num_seen_samples: ${settings.training.global_num_seen_samples} + global_num_seen_steps: ${settings.training.global_num_seen_steps} train_dataloader: instance_key: train_dataloader pass_type: BY_REFERENCE @@ -268,6 +278,7 @@ evaluation_subscriber: config: local_rank: ${settings.cuda_env.local_rank} project: modalities - mode: ONLINE + mode: OFFLINE experiment_id: ${settings.experiment_id} directory: "." + config_file_path: ${settings.config_file_path} diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index d13a1b41a..5079ffa93 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Annotated, Dict, List, Literal, Optional, Tuple, Union +from typing import Annotated, Dict, List, Literal, Optional, Tuple import torch from omegaconf import OmegaConf @@ -148,12 +148,12 @@ class OneCycleLRSchedulerConfig(BaseModel): pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)] anneal_strategy: str cycle_momentum: bool = True - base_momentum: Annotated[float, Field(strict=True, gt=0)] | List[Annotated[float, Field(strict=True, gt=0.0)]] = ( - 0.85 - ) - max_momentum: Annotated[float, Field(strict=True, gt=0.0)] | List[Annotated[float, Field(strict=True, gt=0.0)]] = ( - 0.95 - ) + base_momentum: Annotated[float, Field(strict=True, gt=0)] | List[ + Annotated[float, Field(strict=True, gt=0.0)] + ] = 0.85 + max_momentum: Annotated[float, Field(strict=True, gt=0.0)] | List[ + Annotated[float, Field(strict=True, gt=0.0)] + ] = 0.95 div_factor: Annotated[float, Field(strict=True, gt=0.0)] final_div_factor: Annotated[float, Field(strict=True, gt=0.0)] three_phase: bool = False @@ -299,6 +299,14 @@ class LLMDataLoaderConfig(BaseModel): skip_num_steps: Optional[int] = 0 +class WebLoaderConfig(BaseModel): + dataloader_tag: str + dataset: PydanticDatasetIFType + batch_size: int + collate_fn: PydanticCollateFnIFType + num_workers: Annotated[int, Field(strict=True, ge=0)] + + class RepeatingDataLoaderConfig(BaseModel): dataloader: PydanticLLMDataLoaderIFType reshuffle_after_epoch: Optional[bool] = False diff --git a/src/modalities/config/pydanctic_if_types.py b/src/modalities/config/pydanctic_if_types.py index 880e2217f..8f9d96864 100644 --- a/src/modalities/config/pydanctic_if_types.py +++ b/src/modalities/config/pydanctic_if_types.py @@ -12,7 +12,7 @@ from modalities.checkpointing.checkpoint_loading import 
CheckpointLoadingIF from modalities.checkpointing.checkpoint_saving import CheckpointSaving, CheckpointSavingExecutionABC from modalities.checkpointing.checkpoint_saving_strategies import CheckpointSavingStrategyIF -from modalities.dataloader.dataloader import LLMDataLoader +from modalities.dataloader.dataloader import DataLoaderIF from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss @@ -53,7 +53,7 @@ def __get_pydantic_core_schema__( PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)] PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)] PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)] -PydanticLLMDataLoaderIFType = Annotated[LLMDataLoader, PydanticThirdPartyTypeIF(LLMDataLoader)] +PydanticLLMDataLoaderIFType = Annotated[DataLoaderIF, PydanticThirdPartyTypeIF(DataLoaderIF)] PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)] PydanticLRSchedulerIFType = Annotated[LRScheduler, PydanticThirdPartyTypeIF(LRScheduler)] PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)] diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 19eae2289..e4f9c2a1f 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -14,9 +14,8 @@ from tqdm import tqdm from transformers import BatchEncoding -from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper - from modalities.config.config import PydanticTokenizerIFType +from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper from ..dataloader.large_file_lines_reader import LargeFileLinesReader from .create_packed_data import EmbeddedStreamData @@ -263,7 +262,7 @@ def __init__( image_key: str, source_text_key: str, text_key: str, - tokenizer: PreTrainedTokenizer, + tokenizer: TokenizerWrapper, block_size: int, num_samples: int, image_transform_config: ImageTransformConfig, @@ -274,11 +273,13 @@ def __init__( self.append(wds.filters.shuffle(1000)) self.append(wds.filters.decode("pil")) + tokenizer.tokenizer.pad_token = tokenizer.tokenizer.eos_token + transform = create_transform(**image_transform_config.model_dump()) def make_sample(sample): # print(sample["json"]) - batch_encoding: BatchEncoding = tokenizer( + batch_encoding: BatchEncoding = tokenizer.tokenizer( sample["json"]["text0"], # [source_text_key], max_length=block_size, padding="max_length", diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py index 9a4a1fe48..1aea2c5e0 100644 --- a/src/modalities/dataloader/dataset_factory.py +++ b/src/modalities/dataloader/dataset_factory.py @@ -105,10 +105,6 @@ def get_web_dataset( num_samples: int, image_transform_config: Optional[ImageTransformConfig] = None, ) -> WebDataset: - # TODO this was part of the old Dataloader implementation. - # we need to check if this is actually wanted generally. 
- tokenizer.pad_token = tokenizer.eos_token - dataset = WebDataset( urls=urls, source_image_key=source_image_key, diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 4a2bdb3de..5c7feb9b0 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -68,7 +68,6 @@ def _train_batch( losses.append(loss) (total_loss / self.gradient_acc_steps).backward() - self.gradient_clipper(model) if (train_step_id + 1) % self.gradient_acc_steps == 0 or (train_step_id + 1) == len(data_loader): gradient_norm_score = self.gradient_clipper.clip_gradients().sum() @@ -90,13 +89,12 @@ def train( checkpointing_callback: Callable[[int], None], ): model.train() - cumulated_losses = self._reset_tracked_losses() thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu") - cumulated_loss = torch.zeros(len(loss_fun) + 1 + 1).to(device) + cumulated_losses = torch.zeros(len(loss_fun) + 1 + 1).to(device) # batch loop batch: DatasetBatch @@ -121,7 +119,7 @@ def train( forward_backward_time_recorder.stop() # Save the batch loss for i, batch_loss in enumerate(batch_losses): - cumulated_loss[i] += batch_loss.item() + cumulated_losses[i] += batch_loss.item() # This works, because we always drop the last batch in case it has less samples than the batch size cumulated_losses[-1] += 1 # number of local batches @@ -160,7 +158,7 @@ def train( operation=dist.ReduceOp.SUM, # 1.) summed batch loss / (num batches * world size) # 2.) last batch loss / world size - post_processing_fun=lambda t: torch.stack([t[:-1] / t[-1], t[-1] / dist.get_world_size()]), + post_processing_fun=lambda t: torch.cat([t[:-1] / t[-1], t[-1:] / dist.get_world_size()]), ) train_loss_avg, train_loss_last_batch = ( @@ -169,8 +167,8 @@ def train( ) losses = { - f"total_loss average": train_loss_avg, - f"total_loss last step": train_loss_last_batch, + "total_loss average": train_loss_avg, + "total_loss last step": train_loss_last_batch, } for i, lfn in enumerate(loss_fun): losses[lfn.tag] = reduced_losses[i + 1] @@ -205,24 +203,14 @@ def train( thoughput_aggregator.remove_keys() model.train() - cumulated_losses = self._reset_tracked_losses() + cumulated_losses = torch.zeros(len(loss_fun) + 1 + 1).to(device) evaluation_callback(train_step_id=train_step_id) checkpointing_callback(train_step_id=train_step_id) - cumulated_loss = torch.zeros(len(loss_fun) + 1 + 1).to(device) # we start the time recoder here again to also capture the time spend loading # via the dataloader. forward_backward_time_recorder.start() - def _reset_loss(self): - # TODO: we should handle the device assignment more centrally. 
- cumulated_loss = torch.zeros(2) - if torch.cuda.is_available(): - cumulated_loss = cumulated_loss.to(torch.device(self.local_rank)) - else: - cumulated_loss = cumulated_loss.to("cpu") - return cumulated_loss - @staticmethod def _publish_progress( batch_progress_publisher: MessagePublisher[BatchProgressUpdate], From 043384dac6ee05b9c1b77a1c416cdd026399bbda Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 12:29:27 +0200 Subject: [PATCH 013/161] fix: print only on main rank in component factory --- src/modalities/config/component_factory.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/modalities/config/component_factory.py b/src/modalities/config/component_factory.py index c3a3dfd5a..8c3fd87f3 100644 --- a/src/modalities/config/component_factory.py +++ b/src/modalities/config/component_factory.py @@ -1,3 +1,4 @@ +import os from typing import Any, Dict, List, Type, TypeVar, Union from pydantic import BaseModel @@ -14,7 +15,6 @@ def __init__(self, registry: Registry) -> None: def build_components(self, config_dict: Dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild: component_names = list(components_model_type.model_fields.keys()) component_dict = self._build_config(config_dict=config_dict, component_names=component_names) - print(component_dict) components = components_model_type(**component_dict) return components @@ -67,7 +67,9 @@ def _build_component( component = self._instantiate_component( component_key=component_key, variant_key=variant_key, component_config=current_component_config ) - print(" -> ".join(traversal_path) + ":", component) + + if os.environ["RANK"] == "0": + print(" -> ".join(traversal_path) + ":", component) # if the component is a top level component, then we add it to the top level components dictionary # to make sure that we don't build it again.
Building it again would mean that we work by-value @@ -91,7 +93,8 @@ def _build_component( # so that we don't instantiate it again when we reach the respective component config # in the subsequent config traversal top_level_components[referenced_entity_key] = materialized_referenced_component - print(" -> ".join(traversal_path) + ": ", f"--ref--> {top_level_components[referenced_entity_key]}") + if os.environ["RANK"] == "0": + print(" -> ".join(traversal_path) + ": ", f"--ref--> {top_level_components[referenced_entity_key]}") return top_level_components[referenced_entity_key], top_level_components return materialized_component_config, top_level_components From 3cd924475558d22216f916178d07ba6eb75c7d1c Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 12:30:07 +0200 Subject: [PATCH 014/161] fix: total loss average logging --- src/modalities/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 5c7feb9b0..3db46a837 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -167,7 +167,7 @@ def train( ) losses = { - "total_loss average": train_loss_avg, + "total_loss average": train_loss_avg / train_loss_last_batch, "total_loss last step": train_loss_last_batch, } for i, lfn in enumerate(loss_fun): losses[lfn.tag] = reduced_losses[i + 1] From b09d20e8c5ecf1bfefc01c11dc4dc1ae7e20f04c Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 12:33:07 +0200 Subject: [PATCH 015/161] fix: cuda env and run script --- src/modalities/__main__.py | 6 ++++-- src/modalities/running_env/cuda_env.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 7914cb034..777c06fab 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -8,6 +8,7 @@ import click import click_pathlib +import torch.distributed as dist from pydantic import BaseModel, FilePath from modalities.activation_checkpointing import apply_activation_checkpointing_inplace @@ -48,11 +49,12 @@ def main() -> None: help="Path to a file with the YAML config file.", ) def entry_point_run_modalities(config_file_path: Path): - config_dict = load_app_config_dict(config_file_path) - main_obj = Main(config_dict, config_file_path) with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + config_dict = load_app_config_dict(config_file_path) + main_obj = Main(config_dict, config_file_path) components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel) main_obj.run(components) + dist.barrier() @main.command(name="generate_text") diff --git a/src/modalities/running_env/cuda_env.py b/src/modalities/running_env/cuda_env.py index 3faac6e2b..c7a92d0a7 100644 --- a/src/modalities/running_env/cuda_env.py +++ b/src/modalities/running_env/cuda_env.py @@ -13,7 +13,7 @@ def __init__( ) -> None: self.process_group_backend = process_group_backend # TODO we might want to set this from outside via the config - self.local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.local_rank = int(os.environ["LOCAL_RANK"]) def __enter__(self) -> "CudaEnv": dist.init_process_group(self.process_group_backend.value) @@ -21,5 +21,4 @@ def __enter__(self) -> "CudaEnv": return self def __exit__(self, type, value, traceback): - dist.barrier() dist.destroy_process_group() From e09745db2ad2db5e8e5c2ccc988eb357ea212964 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 12:58:17 +0200 Subject: [PATCH 016/161] chore: update coca config ---
.../training/config_example_coca.yaml | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index ee61b77cd..65f5604d4 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -8,11 +8,11 @@ settings: global_training_log_interval_in_steps: 4 global_checkpointing_interval_in_steps: 5_000 global_evaluation_interval_in_steps: 5_000 - global_num_training_samples: 10_000 + global_num_training_samples: 102_400 global_num_seen_steps: 0 do_apply_activation_checkpointing: true gradient_acc_steps: 1 - local_train_micro_batch_size: 16 + local_train_micro_batch_size: 144 sequence_length: 256 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} @@ -45,7 +45,7 @@ train_dataset: num_samples: 10_000 sample_definition: - sample_key: images - sample_shape: [3, 224, 224] + sample_shape: [3, 256, 256] sample_type: float - sample_key: input_ids sample_shape: [256] @@ -58,7 +58,7 @@ val_dataset: num_samples: 1_000 sample_definition: - sample_key: images - sample_shape: [3, 224, 224] + sample_shape: [3, 256, 256] sample_type: float - sample_key: input_ids sample_shape: [256] @@ -145,7 +145,7 @@ checkpoint_saving: global_rank: ${settings.cuda_env.global_rank} experiment_id: ${settings.experiment_id} mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD + sharding_strategy: HYBRID_SHARD block_names: [TransformerBlock, VisionTransformerBlock] captioning_loss: @@ -179,7 +179,7 @@ wrapped_model: pass_type: BY_REFERENCE sync_module_states: true mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD + sharding_strategy: HYBRID_SHARD block_names: [TransformerBlock, VisionTransformerBlock] model: @@ -194,16 +194,16 @@ model: vision_encoder_config: sample_key: images prediction_key: vision_embeddings - img_size: 224 + img_size: 256 # 288 in the original coca n_classes: Null # Disable vision transformer head - n_layer: 4 + n_layer: 12 attention_config: attention_engine_type: default_attention n_head: 12 n_embd: 768 dropout: 0.0 - patch_size: 16 - patch_stride: 16 + patch_size: 16 # 18 in the original coca + patch_stride: 16 # 18 in the original coca n_img_channels: 3 add_cls_token: False bias: True @@ -211,9 +211,9 @@ model: sample_key: ${settings.referencing_keys.sample_key} prediction_key: ${model.config.prediction_key} block_size: 1024 - vocab_size: 50304 - n_layer_text: 4 - n_layer_multimodal_text: 4 + vocab_size: 50304 # 64k in the original coca + n_layer_text: 12 + n_layer_multimodal_text: 12 attention_config: attention_engine_type: default_attention n_head: 12 @@ -223,7 +223,7 @@ model: bias: true activation: fused_swiglu epsilon: 1e-5 - n_pool_head: 8 + n_pool_head: 12 n_vision_queries: 256 bias_attn_pool: False epsilon_attn_pool: 1e-5 @@ -233,26 +233,26 @@ model: scheduler: component_key: scheduler - variant_key: onecycle_lr + variant_key: onecycle_lr # COCA uses linear decay config: optimizer: instance_key: optimizer pass_type: BY_REFERENCE - max_lr: 6e-4 + max_lr: 8e-4 div_factor: 10 final_div_factor: 1 - total_steps: 625 - pct_start: 0.01 + total_steps: 500_000 # depends on 500.000 iterations on 65,536 image-text pairs -> 5 epochs on JFT -> 32.7B image-text pairs + pct_start: 0.02 anneal_strategy: cos optimizer: component_key: optimizer variant_key: adam_w config: - lr: 0.0001 - betas: [0.9, 0.95] + lr: 8e-4 + betas: [0.9, 0.999] eps: 1e-8 - weight_decay: 1e-1 + weight_decay: 0.01 wrapped_model: instance_key: 
wrapped_model pass_type: BY_REFERENCE @@ -270,7 +270,7 @@ batch_progress_subscriber: component_key: progress_subscriber variant_key: rich config: - local_rank: ${settings.cuda_env.local_rank} + local_rank: ${settings.cuda_env.global_rank} world_size: ${settings.cuda_env.world_size} global_num_seen_steps: ${settings.training.global_num_seen_steps} train_dataloader: instance_key: train_dataloader pass_type: BY_REFERENCE @@ -284,9 +284,9 @@ evaluation_subscriber: component_key: results_subscriber variant_key: wandb config: - local_rank: ${settings.cuda_env.local_rank} + local_rank: ${settings.cuda_env.global_rank} project: modalities - mode: OFFLINE + mode: ONLINE experiment_id: ${settings.experiment_id} directory: "." config_file_path: ${settings.config_file_path} From 5a74deefec1122aaa1e6c96831fc79b17af54839 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 13:11:57 +0200 Subject: [PATCH 017/161] fix: print parameters and done only on main rank --- src/modalities/__main__.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 777c06fab..bb4addff8 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -1,6 +1,5 @@ #!/usr/bin/env python -import logging import os import shutil from pathlib import Path @@ -221,7 +220,12 @@ def run(self, components: TrainingComponentsInstantiationModel): num_ranks=components.settings.cuda_env.world_size, ) wrapped_model = components.wrapped_model - logging.info(f"Training model with {compute_number_of_trainable_parameters(wrapped_model)} parameters.") + + if os.environ["RANK"] == "0": + # TODO calculate parameters for full model + print( + f"Training model with {compute_number_of_trainable_parameters(wrapped_model)} parameters (per process)." + ) if components.settings.training.do_apply_activation_checkpointing: apply_activation_checkpointing_inplace(wrapped_model) @@ -237,7 +241,10 @@ def run(self, components: TrainingComponentsInstantiationModel): global_evaluation_interval_in_steps=components.settings.training.global_evaluation_interval_in_steps, global_training_log_interval_in_steps=components.settings.training.global_training_log_interval_in_steps, ) - print("done") + + dist.barrier() + if os.environ["RANK"] == "0": + print("done") def get_logging_publishers( self, From f7b725c2165b068832ef5ef2543a14f84d9b886b Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 15:43:00 +0200 Subject: [PATCH 018/161] chore: update coca wds config --- .../config_example_coca_webdataset.yaml | 104 ++++++++---------- 1 file changed, 46 insertions(+), 58 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index d2fd7032a..2bee396cf 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -6,13 +6,13 @@ settings: target_key: target_ids training: global_training_log_interval_in_steps: 4 - global_checkpointing_interval_in_steps: 5_000 - global_evaluation_interval_in_steps: 5_000 - global_num_training_samples: 10_000 + global_checkpointing_interval_in_steps: 100 + global_evaluation_interval_in_steps: 100 + global_num_training_samples: ${train_dataset.config.num_samples} # 491 steps with 8 gpus and global bs of 1152 global_num_seen_steps: 0 do_apply_activation_checkpointing: true gradient_acc_steps: 1 - local_train_micro_batch_size: 16 + local_train_micro_batch_size: 144 sequence_length: 256 cuda_env: local_rank:
${cuda_env:LOCAL_RANK} @@ -23,11 +23,9 @@ settings: tokenizer: component_key: tokenizer - variant_key: pretrained_hf_tokenizer + variant_key: gpt2_tokenizer_fast config: - pretrained_model_name_or_path: data/tokenizer/hf_gpt2 - padding: false - max_length: 256 + tokenizer_file: data/tokenizer/tokenizer_gpt2.json collate_fn: component_key: collate_fn @@ -44,13 +42,13 @@ train_dataset: component_key: dataset variant_key: web_dataset config: - urls: "coco_captions/data/train/{000000..000011}.tar" + urls: "coco/train/{000000..000566}.tar" source_image_key: jpg image_key: images source_text_key: txt text_key: input_ids block_size: ${settings.training.sequence_length} - num_samples: 100_000 + num_samples: 566_747 tokenizer: instance_key: tokenizer pass_type: BY_REFERENCE @@ -59,24 +57,28 @@ train_dataset: val_dataset: component_key: dataset - variant_key: dummy_dataset + variant_key: web_dataset config: - num_samples: 1_000 - sample_definition: - - sample_key: images - sample_shape: [3, 224, 224] - sample_type: float - - sample_key: input_ids - sample_shape: [256] - sample_type: int + urls: "coco/val/{000000..000025}.tar" + source_image_key: jpg + image_key: images + source_text_key: txt + text_key: input_ids + block_size: ${settings.training.sequence_length} + num_samples: 25_010 + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + image_transform_config: + is_training: False train_dataloader: component_key: data_loader variant_key: web_loader config: - num_workers: 2 + num_workers: 16 pin_memory: true - shuffle: false + shuffle: true dataloader_tag: "train" dataset: instance_key: train_dataset @@ -88,30 +90,16 @@ train_dataloader: val_dataloader: component_key: data_loader - variant_key: default + variant_key: web_loader config: - num_workers: 2 + num_workers: 16 pin_memory: true shuffle: false dataloader_tag: "val" dataset: instance_key: val_dataset pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.training.local_train_micro_batch_size} - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: false - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE + batch_size: ${settings.training.local_train_micro_batch_size} collate_fn: instance_key: collate_fn pass_type: BY_REFERENCE @@ -137,7 +125,7 @@ checkpoint_saving: global_rank: ${settings.cuda_env.global_rank} experiment_id: ${settings.experiment_id} mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD + sharding_strategy: HYBRID_SHARD block_names: [TransformerBlock, VisionTransformerBlock] captioning_loss: @@ -171,7 +159,7 @@ wrapped_model: pass_type: BY_REFERENCE sync_module_states: true mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD + sharding_strategy: HYBRID_SHARD block_names: [TransformerBlock, VisionTransformerBlock] model: @@ -186,16 +174,16 @@ model: vision_encoder_config: sample_key: images prediction_key: vision_embeddings - img_size: 224 + img_size: 256 # 288 in the original coca n_classes: Null # Disable vision transformer head - n_layer: 4 + n_layer: 12 attention_config: attention_engine_type: default_attention n_head: 12 n_embd: 768 dropout: 0.0 - patch_size: 16 - patch_stride: 16 + patch_size: 16 # 18 in the original coca + patch_stride: 16 # 18 in the original coca n_img_channels: 3 add_cls_token: False bias: True @@ -203,9 +191,9 @@ model: sample_key: 
${settings.referencing_keys.sample_key} prediction_key: ${model.config.prediction_key} block_size: 1024 - vocab_size: 50304 - n_layer_text: 4 - n_layer_multimodal_text: 4 + vocab_size: 50304 # 64k in the original coca + n_layer_text: 12 + n_layer_multimodal_text: 12 attention_config: attention_engine_type: default_attention n_head: 12 @@ -215,7 +203,7 @@ model: bias: true activation: fused_swiglu epsilon: 1e-5 - n_pool_head: 8 + n_pool_head: 12 n_vision_queries: 256 bias_attn_pool: False epsilon_attn_pool: 1e-5 @@ -225,26 +213,26 @@ model: scheduler: component_key: scheduler - variant_key: onecycle_lr + variant_key: onecycle_lr # COCA uses linear decay config: optimizer: instance_key: optimizer pass_type: BY_REFERENCE - max_lr: 6e-4 + max_lr: 8e-4 div_factor: 10 final_div_factor: 1 - total_steps: 625 - pct_start: 0.01 + total_steps: 491 # 500_000 # depends on 500.000 iterations on 65,536 image-text pairs -> 5 epochs on JFT -> 32.7B image-text pairs + pct_start: 0.02 anneal_strategy: cos optimizer: component_key: optimizer variant_key: adam_w config: - lr: 0.0001 - betas: [0.9, 0.95] + lr: 8e-4 + betas: [0.9, 0.999] eps: 1e-8 - weight_decay: 1e-1 + weight_decay: 0.01 wrapped_model: instance_key: wrapped_model pass_type: BY_REFERENCE @@ -262,7 +250,7 @@ batch_progress_subscriber: component_key: progress_subscriber variant_key: rich config: - local_rank: ${settings.cuda_env.local_rank} + local_rank: ${settings.cuda_env.global_rank} world_size: ${settings.cuda_env.world_size} global_num_seen_steps: ${settings.training.global_num_seen_steps} train_dataloader: @@ -276,9 +264,9 @@ evaluation_subscriber: component_key: results_subscriber variant_key: wandb config: - local_rank: ${settings.cuda_env.local_rank} + local_rank: ${settings.cuda_env.global_rank} project: modalities - mode: OFFLINE + mode: ONLINE experiment_id: ${settings.experiment_id} directory: "." 
config_file_path: ${settings.config_file_path} From c7308e2d21a11274b5c9b078c16bed9d68f3db1d Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 15:53:38 +0200 Subject: [PATCH 019/161] fix: tokenizer config of coca --- config_files/training/config_example_coca.yaml | 6 ++++-- config_files/training/config_example_coca_webdataset.yaml | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index 65f5604d4..b000ef122 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -23,9 +23,11 @@ settings: tokenizer: component_key: tokenizer - variant_key: gpt2_tokenizer_fast + variant_key: pretrained_hf_tokenizer config: - tokenizer_file: data/tokenizer/tokenizer_gpt2.json + pretrained_model_name_or_path: gpt2 + padding: true + max_length: 256 collate_fn: component_key: collate_fn diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index 2bee396cf..0fb51870e 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -23,9 +23,11 @@ settings: tokenizer: component_key: tokenizer - variant_key: gpt2_tokenizer_fast + variant_key: pretrained_hf_tokenizer config: - tokenizer_file: data/tokenizer/tokenizer_gpt2.json + pretrained_model_name_or_path: gpt2 + padding: true + max_length: 256 collate_fn: component_key: collate_fn From 32d0b19dcaade4d19d5b9679d4fff1b1cb3e7d64 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 16:00:53 +0200 Subject: [PATCH 020/161] fix: add multinode splitter to webdataset --- src/modalities/dataloader/dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index e4f9c2a1f..0c18ae25d 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -267,7 +267,8 @@ def __init__( num_samples: int, image_transform_config: ImageTransformConfig, ): - super().__init__(urls=urls) + # TODO auto node splitter + super().__init__(urls=urls, nodesplitter=wds.shardlists.split_by_node, repeat=True) self.num_samples = num_samples self.append(wds.filters.shuffle(1000)) @@ -280,7 +281,7 @@ def __init__( def make_sample(sample): # print(sample["json"]) batch_encoding: BatchEncoding = tokenizer.tokenizer( - sample["json"]["text0"], # [source_text_key], + sample[source_text_key], max_length=block_size, padding="max_length", truncation=True, From 55c039fdab81bcd558dcd5096fdb819e71811861 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 21:28:06 +0200 Subject: [PATCH 021/161] fix: webdataset slow loading --- src/modalities/dataloader/dataset.py | 39 +++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 0c18ae25d..744af9c09 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -6,6 +6,7 @@ import jq import numpy as np +import torch import webdataset as wds from pydantic import BaseModel from timm.data import create_transform @@ -252,6 +253,27 @@ class WebDatasetConfig(BaseModel): block_size: int num_samples: int image_transform_config: Optional[ImageTransformConfig] = None + shardshuffle: Optional[int] = None + repeat: bool = False + 
resample: bool = False + shuffle: int = 0 + + +def nodesplitter(src, group=None): + if torch.distributed.is_initialized(): + if group is None: + group = torch.distributed.group.WORLD + rank = torch.distributed.get_rank(group=group) + size = torch.distributed.get_world_size(group=group) + print(f"nodesplitter: rank={rank} size={size}") + count = 0 + for i, item in enumerate(src): + if i % size == rank: + yield item + count += 1 + print(f"nodesplitter: rank={rank} size={size} count={count} DONE") + else: + yield from src class WebDataset(wds.WebDataset): @@ -266,12 +288,23 @@ def __init__( block_size: int, num_samples: int, image_transform_config: ImageTransformConfig, + shardshuffle: int, + repeat: bool, + resample: bool, + shuffle: int, ): - # TODO auto node splitter - super().__init__(urls=urls, nodesplitter=wds.shardlists.split_by_node, repeat=True) + super().__init__( + urls=urls, + nodesplitter=nodesplitter if not resample else None, + workersplitter=wds.shardlists.split_by_worker, + shardshuffle=shardshuffle, + repeat=repeat, + handler=wds.ignore_and_continue, + resampled=resample, + ) self.num_samples = num_samples - self.append(wds.filters.shuffle(1000)) + self.append(wds.filters.shuffle(shuffle)) self.append(wds.filters.decode("pil")) tokenizer.tokenizer.pad_token = tokenizer.tokenizer.eos_token From 966d23785b4af0d08522ddc89d28f6307192e4b2 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 21:31:55 +0200 Subject: [PATCH 022/161] fix: add batching --- src/modalities/dataloader/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index 0a192553e..a3a6065c0 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -151,6 +151,7 @@ def __len__(self) -> int: class WebLoader(DataLoaderIF): def __init__(self, dataloader_tag: str, dataset: Dataset[T_co], batch_size: Optional[int] = 1, *args, **kwargs): self.num_batches = len(dataset) // batch_size + dataset = dataset.batched(batch_size) self.webloader = wd.WebLoader(dataset=dataset, batch_size=None) # self.webloader = self.webloader.unbatched().shuffle(1000).batched(batch_size) self.webloader = self.webloader.with_epoch(self.num_batches) From dacf639579ef235313a418fbd12c7ef030b0f7ea Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 21:36:13 +0200 Subject: [PATCH 023/161] fix: add more options to webloader --- src/modalities/dataloader/dataloader.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index a3a6065c0..7cf3389fb 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -149,10 +149,18 @@ def __len__(self) -> int: class WebLoader(DataLoaderIF): - def __init__(self, dataloader_tag: str, dataset: Dataset[T_co], batch_size: Optional[int] = 1, *args, **kwargs): + def __init__( + self, + dataloader_tag: str, + dataset: Dataset[T_co], + batch_size: Optional[int] = 1, + num_workers: int = 0, + collate_fn: Optional[_collate_fn_t] = None, + pin_memory: bool = False, + ): self.num_batches = len(dataset) // batch_size - dataset = dataset.batched(batch_size) - self.webloader = wd.WebLoader(dataset=dataset, batch_size=None) + dataset = dataset.batched(batch_size, collation_fn=collate_fn) + self.webloader = wd.WebLoader(dataset=dataset, batch_size=None, num_workers=num_workers, pin_memory=pin_memory) # self.webloader 
= self.webloader.unbatched().shuffle(1000).batched(batch_size) self.webloader = self.webloader.with_epoch(self.num_batches) self.dataloader_tag = dataloader_tag From b40ecd52d18621fd87ab7a05698380d9c0314d37 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 21:41:54 +0200 Subject: [PATCH 024/161] fix: webloader --- src/modalities/config/config.py | 1 + src/modalities/dataloader/dataloader.py | 1 - src/modalities/dataloader/dataloader_factory.py | 14 +++++++------- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 5079ffa93..d0a5828db 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -305,6 +305,7 @@ class WebLoaderConfig(BaseModel): batch_size: int collate_fn: PydanticCollateFnIFType num_workers: Annotated[int, Field(strict=True, ge=0)] + pin_memory: bool class RepeatingDataLoaderConfig(BaseModel): diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index 7cf3389fb..c92321f74 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -161,7 +161,6 @@ def __init__( self.num_batches = len(dataset) // batch_size dataset = dataset.batched(batch_size, collation_fn=collate_fn) self.webloader = wd.WebLoader(dataset=dataset, batch_size=None, num_workers=num_workers, pin_memory=pin_memory) - # self.webloader = self.webloader.unbatched().shuffle(1000).batched(batch_size) self.webloader = self.webloader.with_epoch(self.num_batches) self.dataloader_tag = dataloader_tag self.batch_size = batch_size diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 0b533aabe..6241bbd62 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -41,14 +41,14 @@ def get_repeating_dataloader( @staticmethod def get_web_loader( - dataloader_tag: str, - dataset: Dataset, - batch_size: int, - collate_fn: Callable, - num_workers: int, + dataloader_tag: str, dataset: Dataset, batch_size: int, collate_fn: Callable, num_workers: int, pin_memory: bool ) -> WebLoader: - dataset = dataset.batched(batch_size, collation_fn=collate_fn) dataloader = WebLoader( - dataloader_tag=dataloader_tag, dataset=dataset, batch_size=batch_size, num_workers=num_workers + dataloader_tag=dataloader_tag, + dataset=dataset, + batch_size=batch_size, + collate_fn=collate_fn, + num_workers=num_workers, + pin_memory=pin_memory, ) return dataloader From a9ce132fd663a884d3c20e803507bb8874600939 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 21:47:22 +0200 Subject: [PATCH 025/161] fix: dataset factory --- src/modalities/dataloader/dataset_factory.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py index 1aea2c5e0..261765c0c 100644 --- a/src/modalities/dataloader/dataset_factory.py +++ b/src/modalities/dataloader/dataset_factory.py @@ -104,6 +104,10 @@ def get_web_dataset( block_size: int, num_samples: int, image_transform_config: Optional[ImageTransformConfig] = None, + shardshuffle: Optional[int] = None, + repeat: bool = False, + resample: bool = False, + shuffle: int = 0, ) -> WebDataset: dataset = WebDataset( urls=urls, @@ -115,5 +119,9 @@ def get_web_dataset( block_size=block_size, num_samples=num_samples, image_transform_config=image_transform_config, + shardshuffle=shardshuffle, + 
repeat=repeat, + resample=resample, + shuffle=shuffle, ) return dataset From 63ef47c1a18491755aa92ebc9e5174cf93387c8f Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 9 May 2024 21:52:46 +0200 Subject: [PATCH 026/161] fix: webdataset --- src/modalities/dataloader/dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 744af9c09..74218a493 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -296,7 +296,6 @@ def __init__( super().__init__( urls=urls, nodesplitter=nodesplitter if not resample else None, - workersplitter=wds.shardlists.split_by_worker, shardshuffle=shardshuffle, repeat=repeat, handler=wds.ignore_and_continue, From d6d84dc919b9d53384ab261fb5b5d6f404bb5c06 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Fri, 10 May 2024 08:47:17 +0200 Subject: [PATCH 027/161] fix: loss accumulation --- src/modalities/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 3db46a837..c89307f14 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -151,7 +151,7 @@ def train( synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time # TODO: insert reducer from outside so Trainer is independent of FSDP # add the loss and gradient norm for the LAST batch - cumulated_losses[1] = batch_loss.item() + # cumulated_losses[1] = batch_loss.item() reduced_losses = Reducer.reduce( tensor=cumulated_losses, From 3d04f788dd38c2ad9340be7388117daf1b19c474 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Fri, 10 May 2024 08:47:55 +0200 Subject: [PATCH 028/161] fix: loss average for eval --- src/modalities/evaluator.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 300820625..4cd406ed7 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List +from typing import Dict, List import torch import torch.distributed as dist @@ -9,7 +9,7 @@ from modalities.logging_broker.messages import BatchProgressUpdate, ExperimentStatus, MessageTypes from modalities.logging_broker.publisher import MessagePublisher from modalities.loss_functions import Loss -from modalities.models.model import NNModel, model_predict_batch +from modalities.models.model import model_predict_batch from modalities.running_env.fsdp.reducer import Reducer from modalities.trainer import ThroughputAggregationKeys from modalities.util import Aggregator, TimeRecorder @@ -42,7 +42,7 @@ def evaluate_batch( loss = lfn(result_batch) # Add loss to total loss - weighted_loss = loss * lfn.weight # / self.gradient_acc_steps + weighted_loss = loss * lfn.weight if total_loss is None: total_loss = weighted_loss else: @@ -108,16 +108,24 @@ def evaluate( num_samples_per_second = synced_num_samples / synced_forward_backward_time # Agreggate loss from all ranks - total_losses = Reducer.reduce( + reduced_losses = Reducer.reduce( tensor=cumulated_loss, operation=dist.ReduceOp.SUM, - post_processing_fun=lambda t: t[:-1] / t[-1], + post_processing_fun=lambda t: torch.cat([t[:-1] / t[-1], t[-1:] / dist.get_world_size()]), ) # Fill logging dict with total loss and the individual losses - losses = {"total_loss": total_losses[0]} + loss_avg, loss_last_batch = ( + reduced_losses[0], + reduced_losses[-1], + ) + + losses = { + "total_loss 
average": loss_avg / loss_last_batch, + "total_loss last step": loss_last_batch, + } for i, lfn in enumerate(loss_fun): - losses[lfn.tag] = total_losses[i + 1] + losses[lfn.tag] = reduced_losses[i + 1] evaluation_result = EvaluationResultBatch( losses=losses, From e65a3cdd0135deb6469098d758d87a53b91e9293 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Fri, 10 May 2024 08:48:40 +0200 Subject: [PATCH 029/161] refactor: remove unused code from coca collator --- src/modalities/models/coca/collator.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index 42b40ee03..2ca768e2e 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -31,11 +31,9 @@ def __init__(self, sample_keys: List[str], target_keys: List[str], text_sample_k self.text_target_key = text_target_key def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: - samples = { - sample_key: torch.stack([torch.tensor(d[sample_key]) for d in batch]) for sample_key in self.sample_keys - } + samples = {sample_key: torch.stack([d[sample_key] for d in batch]) for sample_key in self.sample_keys} if "attention_mask" in batch[0]: - samples["attention_mask"] = torch.stack([torch.tensor(d["attention_mask"]) for d in batch]) + samples["attention_mask"] = torch.stack([d["attention_mask"] for d in batch]) targets = {target_key: torch.stack([d[target_key] for d in batch]) for target_key in self.target_keys} From 622570d9668eb28548a4f9291def1eacfd155868 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Fri, 10 May 2024 09:02:58 +0200 Subject: [PATCH 030/161] fix: coca collator --- src/modalities/models/coca/collator.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index 2ca768e2e..e220aa677 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -31,11 +31,17 @@ def __init__(self, sample_keys: List[str], target_keys: List[str], text_sample_k self.text_target_key = text_target_key def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: - samples = {sample_key: torch.stack([d[sample_key] for d in batch]) for sample_key in self.sample_keys} + samples = { + sample_key: torch.stack([self._prepare_sample(d[sample_key]) for d in batch]) + for sample_key in self.sample_keys + } if "attention_mask" in batch[0]: - samples["attention_mask"] = torch.stack([d["attention_mask"] for d in batch]) + samples["attention_mask"] = torch.stack([self._prepare_sample(d["attention_mask"]) for d in batch]) - targets = {target_key: torch.stack([d[target_key] for d in batch]) for target_key in self.target_keys} + targets = { + target_key: torch.stack([self._prepare_sample(d[target_key]) for d in batch]) + for target_key in self.target_keys + } # Create target for text input targets[self.text_target_key] = samples[self.text_sample_key][:, 1:].clone().detach() @@ -46,3 +52,9 @@ def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: samples["attention_mask"] = samples["attention_mask"][:, :-1] return DatasetBatch(targets=targets, samples=samples) + + @staticmethod + def _prepare_sample(x): + if isinstance(x, torch.Tensor): + return x + return torch.tensor(x) From efedc77022311b21d2785b6178a344467b9d1483 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Fri, 10 May 2024 10:21:28 +0200 Subject: [PATCH 031/161] fix: loss 
normalization --- src/modalities/evaluator.py | 2 +- src/modalities/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 4cd406ed7..c02f9f38d 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -121,7 +121,7 @@ def evaluate( ) losses = { - "total_loss average": loss_avg / loss_last_batch, + "total_loss average": loss_avg, "total_loss last step": loss_last_batch, } for i, lfn in enumerate(loss_fun): diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index c89307f14..8eb0db674 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -167,7 +167,7 @@ def train( ) losses = { - "total_loss average": train_loss_avg / train_loss_last_batch, + "total_loss average": train_loss_avg, "total_loss last step": train_loss_last_batch, } for i, lfn in enumerate(loss_fun): From eba23a9dd1ff570cec7b0af13f13b819f18f8530 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Fri, 10 May 2024 11:16:52 +0200 Subject: [PATCH 032/161] feat: add clip loss --- src/modalities/loss_functions.py | 76 +++++++++++++++++++++++++++ src/modalities/registry/components.py | 10 +++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 923cf9ae2..3f8b9dd14 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod import torch +import torch.distributed as dist +import torch.nn.functional as F from pydantic import BaseModel from torch.nn import CrossEntropyLoss as TorchCrossEntropyLoss @@ -143,3 +145,77 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: contiguous_embedding1, contiguous_embedding2, embedding1.device, self.is_asymmetric, self.temperature ) return loss + + +class ClipLossConfig(BaseModel): + logit_scale_key: str + prediction_key1: str + prediction_key2: str + tag: str = "ClipLoss" + + +class ClipLoss(Loss): + def __init__( + self, + logit_scale_key: str, + prediction_key1: str, + prediction_key2: str, + tag: str = "ClipLoss", + ): + """ + CLIP Loss (Source: https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py) + + Args: + logit_scale_key (str): Value of a learnable logit scale parameter. + prediction_key1 (str): Key to access embedding 1. + prediction_key2 (str): Key to access embedding 2. + tag (str, optional): Defaults to "ClipLoss". + """ + super().__init__(tag) + self.logit_scale_key = logit_scale_key + self.prediction_key1 = prediction_key1 + self.prediction_key2 = prediction_key2 + + def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: + """ + Args: + forward_batch (InferenceResultBatch): data batch. + + Returns: + torch.Tensor: loss tensor. 
+ """ + logit_scale = forward_batch.get_predictions(self.logit_scale_key) + embedding1 = forward_batch.get_predictions(self.prediction_key1).contiguous() + embedding2 = forward_batch.get_predictions(self.prediction_key2).contiguous() + device = embedding1.device + + # Gather all embeddings from each rank + world_size = dist.get_world_size() + rank = dist.get_rank() + gathered_embedding1 = [torch.zeros_like(embedding1) for _ in range(world_size)] + gathered_embedding2 = [torch.zeros_like(embedding2) for _ in range(world_size)] + dist.all_gather(gathered_embedding1, embedding1) + dist.all_gather(gathered_embedding2, embedding2) + + # Make sure we have gradients for the "local" embeddings + gathered_embedding1[rank] = embedding1 + gathered_embedding2[rank] = embedding2 + + # Combine embeddings + gathered_embedding1 = torch.cat(gathered_embedding1, dim=0) + gathered_embedding2 = torch.cat(gathered_embedding2, dim=0) + + # Calculate logits + logits_per_embedding1 = logit_scale * gathered_embedding1 @ gathered_embedding2.T + logits_per_embedding2 = logits_per_embedding1.T + + # Build gt labels for diagonal + num_logits = logits_per_embedding1.shape[0] + labels = torch.arange(num_logits, device=device, dtype=torch.long) + + # Calculate loss + clip_loss = ( + F.cross_entropy(logits_per_embedding1, labels) + F.cross_entropy(logits_per_embedding2, labels) + ) / 2 + + return clip_loss diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 34450d74a..77544ca33 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -57,7 +57,14 @@ ProgressSubscriberFactory, ResultsSubscriberFactory, ) -from modalities.loss_functions import CrossEntropyLoss, CrossEntropyLossConfig, NCELoss, NCELossConfig +from modalities.loss_functions import ( + ClipLoss, + ClipLossConfig, + CrossEntropyLoss, + CrossEntropyLossConfig, + NCELoss, + NCELossConfig, +) from modalities.models.coca.coca_model import CoCa, CoCaConfig from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig @@ -103,6 +110,7 @@ class ComponentEntity: # losses ComponentEntity("loss", "cross_entropy_loss", CrossEntropyLoss, CrossEntropyLossConfig), ComponentEntity("loss", "nce_loss", NCELoss, NCELossConfig), + ComponentEntity("loss", "clip_loss", ClipLoss, ClipLossConfig), # optmizers ComponentEntity("optimizer", "adam", OptimizerFactory.get_adam, AdamOptimizerConfig), ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig), From 9bdc830b7f2a44c75406bca9fd7829c089748c2c Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Fri, 10 May 2024 11:48:12 +0200 Subject: [PATCH 033/161] fix: use clip loss in coca --- config_files/training/config_example_coca.yaml | 3 ++- config_files/training/config_example_coca_webdataset.yaml | 3 ++- src/modalities/models/coca/coca_model.py | 8 ++++++++ start.sh | 2 +- 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index b000ef122..2e7d65213 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -160,10 +160,11 @@ captioning_loss: contrastive_loss: component_key: loss - variant_key: nce_loss + variant_key: clip_loss config: prediction_key1: ${model.config.vision_cls_prediction_key} prediction_key2: 
${model.config.text_cls_prediction_key} + logit_scale_key: ${model.config.logit_scale_prediction_key} tag: contrastive_loss loss_fn: diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index 0fb51870e..f71ee7aad 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -140,10 +140,11 @@ captioning_loss: contrastive_loss: component_key: loss - variant_key: nce_loss + variant_key: clip_loss config: prediction_key1: ${model.config.vision_cls_prediction_key} prediction_key2: ${model.config.text_cls_prediction_key} + logit_scale_key: ${model.config.logit_scale_prediction_key} tag: contrastive_loss loss_fn: diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 0b9125984..885230352 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -2,6 +2,7 @@ from functools import partial from typing import Annotated, Dict, Tuple +import numpy as np import torch from einops import repeat from pydantic import BaseModel, Field @@ -39,6 +40,7 @@ class CoCaConfig(BaseModel): text_embd_prediction_key: str vision_cls_prediction_key: str text_cls_prediction_key: str + logit_scale_prediction_key: str vision_encoder_config: VisionTransformerConfig text_decoder_config: TextDecoderConfig n_pool_head: Annotated[int, Field(ge=1)] @@ -65,6 +67,7 @@ def __init__( text_cls_prediction_key: str, vision_embd_prediction_key: str, text_embd_prediction_key: str, + logit_scale_prediction_key: str, n_vision_queries: int, n_pool_head: int, bias_attn_pool: bool, @@ -79,6 +82,7 @@ def __init__( self.text_cls_prediction_key = text_cls_prediction_key self.vision_embd_prediction_key = vision_embd_prediction_key self.text_embd_prediction_key = text_embd_prediction_key + self.logit_scale_prediction_key = logit_scale_prediction_key self.vision_encoder = VisionTransformer(**dict(vision_encoder_config)) self.text_decoder = TextDecoder( @@ -126,6 +130,9 @@ def __init__( attention_config=text_decoder_config.attention_config, ) + # Logit scale for contrastive loss + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + # init all weights self.apply(partial(self._init_weights, weight_init=weight_init)) # apply special scaled init to the residual projections, per GPT-2 paper @@ -154,6 +161,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: self.prediction_key: logits, self.vision_cls_prediction_key: vision_cls_token, self.text_cls_prediction_key: text_cls_token, + self.logit_scale_prediction_key: self.logit_scale.exp(), } def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/start.sh b/start.sh index 1d8f5b66f..9878368a0 100644 --- a/start.sh +++ b/start.sh @@ -1,3 +1,3 @@ #!/bin/bash -CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 1 --nproc_per_node 1 --rdzv-endpoint=0.0.0.0:29502 src/modalities/__main__.py run --config_file_path config_files/config_example_coca_webdataset.yaml \ No newline at end of file +CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 1 --nproc_per_node 1 --rdzv-endpoint=0.0.0.0:29502 src/modalities/__main__.py run --config_file_path config_files/training/config_example_coca.yaml \ No newline at end of file From b3aff910995b9b678a0a5a45a39d3986f6e01ecf Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Fri, 10 May 2024 11:52:30 +0200 Subject: [PATCH 034/161] chore: update coca 
webdataset config --- .../training/config_example_coca_webdataset.yaml | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index f71ee7aad..c139311e1 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -51,11 +51,16 @@ train_dataset: text_key: input_ids block_size: ${settings.training.sequence_length} num_samples: 566_747 + shardshuffle: 1000 + repeat: true + resample: true + shuffle: true tokenizer: instance_key: tokenizer pass_type: BY_REFERENCE image_transform_config: is_training: True + input_size: 256 val_dataset: component_key: dataset @@ -73,14 +78,14 @@ val_dataset: pass_type: BY_REFERENCE image_transform_config: is_training: False + input_size: 256 train_dataloader: component_key: data_loader variant_key: web_loader config: - num_workers: 16 + num_workers: 8 pin_memory: true - shuffle: true dataloader_tag: "train" dataset: instance_key: train_dataset @@ -94,9 +99,8 @@ val_dataloader: component_key: data_loader variant_key: web_loader config: - num_workers: 16 + num_workers: 8 pin_memory: true - shuffle: false dataloader_tag: "val" dataset: instance_key: val_dataset @@ -224,7 +228,7 @@ scheduler: max_lr: 8e-4 div_factor: 10 final_div_factor: 1 - total_steps: 491 # 500_000 # depends on 500.000 iterations on 65,536 image-text pairs -> 5 epochs on JFT -> 32.7B image-text pairs + total_steps: 500_000 # depends on 500.000 iterations on 65,536 image-text pairs -> 5 epochs on JFT -> 32.7B image-text pairs pct_start: 0.02 anneal_strategy: cos From dffe644b3209a64cc19c06695153bcc6e4679c93 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Fri, 10 May 2024 11:55:20 +0200 Subject: [PATCH 035/161] chore: update coca webdataset config --- config_files/training/config_example_coca_webdataset.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index c139311e1..f181a1dc8 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -178,6 +178,7 @@ model: text_embd_prediction_key: text_embeddings vision_cls_prediction_key: vision_cls text_cls_prediction_key: text_cls + logit_scale_prediction_key: logit_scale vision_encoder_config: sample_key: images prediction_key: vision_embeddings From d8bdebcb06d84d39b2c29f2e4e356d6c75a38ef5 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 11:38:48 +0200 Subject: [PATCH 036/161] fix: gradient accumulation --- src/modalities/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 8eb0db674..3a1cfc271 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -58,7 +58,7 @@ def _train_batch( loss = lfn(result_batch) # Add loss to total loss - weighted_loss = (loss * lfn.weight) / self.gradient_acc_steps + weighted_loss = loss * lfn.weight if total_loss is None: total_loss = weighted_loss else: From 244fac9264747cb1743283293e2d1bc3d014356b Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 13:47:53 +0200 Subject: [PATCH 037/161] fix: normalize cls token of coca --- src/modalities/models/coca/coca_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git 
a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 885230352..8a33055a4 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -4,6 +4,7 @@ import numpy as np import torch +import torch.nn.functional as F from einops import repeat from pydantic import BaseModel, Field from torch import nn @@ -168,12 +169,12 @@ def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch vision_embd = self.vision_encoder(inputs)[self.vision_embd_prediction_key] queries = repeat(self.vision_queries, "n d -> b n d", b=vision_embd.shape[0]) vision_embd = self.attn_pool(queries, context=vision_embd) - vision_embd, vision_cls_token = vision_embd[:, :-1, :], vision_embd[:, -1, :] + vision_embd, vision_cls_token = vision_embd[:, :-1, :], F.normalize(vision_embd[:, -1, :], dim=-1) return vision_embd, vision_cls_token def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: text_embd = self.text_decoder(inputs)[self.text_embd_prediction_key] - text_embd, text_cls_token = text_embd[:, :-1, :], text_embd[:, -1, :] + text_embd, text_cls_token = text_embd[:, :-1, :], F.normalize(text_embd[:, -1, :], dim=-1) return text_embd, text_cls_token def _forward_decode(self, text_embd: torch.Tensor, vision_embd: torch.Tensor) -> torch.Tensor: From 0524acde0387cb42261bd7c285ff45675899840b Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 13:50:21 +0200 Subject: [PATCH 038/161] feat: add weight option to loss config --- src/modalities/loss_functions.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 3f8b9dd14..6513f971b 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -30,12 +30,13 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: class CrossEntropyLossConfig(BaseModel): target_key: str prediction_key: str + weight: float = 1 tag: str = "CLMCrossEntropyLoss" class CrossEntropyLoss(Loss): - def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEntropyLoss"): - super().__init__(tag) + def __init__(self, target_key: str, prediction_key: str, weight: float, tag: str = "CLMCrossEntropyLoss"): + super().__init__(tag, weight) self.target_key = target_key self.prediction_key = prediction_key # Mean over the tokens in the local-batch (batch per rank) @@ -99,6 +100,7 @@ class NCELossConfig(BaseModel): prediction_key2: str is_asymmetric: bool = True temperature: float = 1.0 + weight: float = 1 tag: str = "NCELoss" @@ -107,8 +109,9 @@ def __init__( self, prediction_key1: str, prediction_key2: str, - is_asymmetric: bool = True, - temperature: float = 1.0, + is_asymmetric: bool, + temperature: float, + weight: float, tag: str = "NCELoss", ): """ @@ -121,7 +124,7 @@ def __init__( temperature (float, optional): temperature. Defaults to 1.0. tag (str, optional): Defaults to "NCELoss". 
""" - super().__init__(tag) + super().__init__(tag, weight) self.prediction_key1 = prediction_key1 self.prediction_key2 = prediction_key2 self.is_asymmetric = is_asymmetric @@ -151,6 +154,7 @@ class ClipLossConfig(BaseModel): logit_scale_key: str prediction_key1: str prediction_key2: str + weight: float = 1 tag: str = "ClipLoss" @@ -160,6 +164,7 @@ def __init__( logit_scale_key: str, prediction_key1: str, prediction_key2: str, + weight: float, tag: str = "ClipLoss", ): """ @@ -171,7 +176,7 @@ def __init__( prediction_key2 (str): Key to access embedding 2. tag (str, optional): Defaults to "ClipLoss". """ - super().__init__(tag) + super().__init__(tag, weight) self.logit_scale_key = logit_scale_key self.prediction_key1 = prediction_key1 self.prediction_key2 = prediction_key2 From df48b6266f92cd0cb25ea435032a7953bc783986 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 14:24:08 +0200 Subject: [PATCH 039/161] fix: log weighted loss --- src/modalities/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 3a1cfc271..550a38af8 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -65,7 +65,7 @@ def _train_batch( total_loss += weighted_loss # Append individual losses (for logging) - losses.append(loss) + losses.append(weighted_loss) (total_loss / self.gradient_acc_steps).backward() From 2945a03bb3521fc9a9a45dceb87db6b1594cc204 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 14:44:00 +0200 Subject: [PATCH 040/161] feat: add cosine scheduler with warmup --- src/modalities/config/config.py | 7 +++++++ src/modalities/registry/components.py | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index d0a5828db..30b2299d4 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -183,6 +183,13 @@ class CosineAnnealingLRSchedulerConfig(BaseModel): verbose: bool = False +class CosineAnnealingWithWarmupLRSchedulerConfig(BaseModel): + optimizer: PydanticOptimizerIFType + num_warmup_steps: Annotated[int, Field(strict=True, gt=0)] + num_training_steps: Annotated[int, Field(strict=True, gt=0)] + last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1 + + class CheckpointedOptimizerConfig(BaseModel): checkpoint_loading: PydanticCheckpointLoadingIFType checkpoint_path: Path diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 77544ca33..739bf6506 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +import transformers from pydantic import BaseModel from torch.utils.data import BatchSampler, DistributedSampler @@ -23,6 +24,7 @@ CheckpointSavingConfig, ConstantLRSchedulerConfig, CosineAnnealingLRSchedulerConfig, + CosineAnnealingWithWarmupLRSchedulerConfig, DistributedSamplerConfig, DummyLRSchedulerConfig, DummyProgressSubscriberConfig, @@ -125,6 +127,12 @@ class ComponentEntity: ComponentEntity( "scheduler", "cosine_annealing_lr", torch.optim.lr_scheduler.CosineAnnealingLR, CosineAnnealingLRSchedulerConfig ), + ComponentEntity( + "scheduler", + "cosine_annealing_with_warmup_lr", + transformers.get_linear_schedule_with_warmup, + CosineAnnealingWithWarmupLRSchedulerConfig, + ), # tokenizers ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig), ComponentEntity("tokenizer", 
"pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig), From 2d0e3f37af88b1d32f6dd8d466cfb3c7b4a88c1d Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 15:12:27 +0200 Subject: [PATCH 041/161] fix: loss logging --- src/modalities/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 550a38af8..acf6afaf0 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -55,17 +55,16 @@ def _train_batch( losses = [] for lfn in loss_fun: # Calculate loss - loss = lfn(result_batch) + weighted_loss = lfn(result_batch) * lfn.weight # Add loss to total loss - weighted_loss = loss * lfn.weight if total_loss is None: total_loss = weighted_loss else: total_loss += weighted_loss # Append individual losses (for logging) - losses.append(weighted_loss) + losses.append(weighted_loss.clone().detach()) (total_loss / self.gradient_acc_steps).backward() From 1c0993fa13b6281ae35f8fd381066660f8c2966f Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 15:57:25 +0200 Subject: [PATCH 042/161] feat: add local clip loss --- src/modalities/loss_functions.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 6513f971b..311b0788e 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -155,6 +155,7 @@ class ClipLossConfig(BaseModel): prediction_key1: str prediction_key2: str weight: float = 1 + local_loss: bool = True tag: str = "ClipLoss" @@ -165,6 +166,7 @@ def __init__( prediction_key1: str, prediction_key2: str, weight: float, + local_loss: bool, tag: str = "ClipLoss", ): """ @@ -180,6 +182,7 @@ def __init__( self.logit_scale_key = logit_scale_key self.prediction_key1 = prediction_key1 self.prediction_key2 = prediction_key2 + self.local_loss = local_loss def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: """ @@ -203,20 +206,27 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: dist.all_gather(gathered_embedding2, embedding2) # Make sure we have gradients for the "local" embeddings - gathered_embedding1[rank] = embedding1 - gathered_embedding2[rank] = embedding2 + if not self.local_loss: + gathered_embedding1[rank] = embedding1 + gathered_embedding2[rank] = embedding2 # Combine embeddings gathered_embedding1 = torch.cat(gathered_embedding1, dim=0) gathered_embedding2 = torch.cat(gathered_embedding2, dim=0) # Calculate logits - logits_per_embedding1 = logit_scale * gathered_embedding1 @ gathered_embedding2.T - logits_per_embedding2 = logits_per_embedding1.T + if self.local_loss: + logits_per_embedding1 = logit_scale * embedding1 @ gathered_embedding2.T + logits_per_embedding2 = logit_scale * embedding2 @ gathered_embedding1.T + else: + logits_per_embedding1 = logit_scale * gathered_embedding1 @ gathered_embedding2.T + logits_per_embedding2 = logits_per_embedding1.T # Build gt labels for diagonal num_logits = logits_per_embedding1.shape[0] labels = torch.arange(num_logits, device=device, dtype=torch.long) + if world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank # Calculate loss clip_loss = ( From 0e3b239fbbbebb24a11407b6a59d8346be7e3216 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 16:01:13 +0200 Subject: [PATCH 043/161] fix: clip loss --- src/modalities/loss_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 311b0788e..8c1ec96ba 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -226,7 +226,7 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: num_logits = logits_per_embedding1.shape[0] labels = torch.arange(num_logits, device=device, dtype=torch.long) if world_size > 1 and self.local_loss: - labels = labels + num_logits * self.rank + labels = labels + num_logits * rank # Calculate loss clip_loss = ( From 7459ee098faeb719f21ac75dd24041f2f60109b6 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 16:44:18 +0200 Subject: [PATCH 044/161] fix: add barrier to eval --- src/modalities/evaluator.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index c02f9f38d..a8f7fff4c 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -74,6 +74,10 @@ def evaluate( dataloader_tag=data_loader.dataloader_tag, ) thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() + + dist.barrier() + print("All ranks reached the eval step") + with TimeRecorder() as forward_backward_timer_recorder: for batch_id, batch in enumerate(data_loader): batch_losses = self.evaluate_batch( @@ -96,6 +100,10 @@ def evaluate( dataloader_tag=data_loader.dataloader_tag, ) + print(f"Rank {dist.get_rank()} is done with eval step") + dist.barrier() + print("All ranks are done with the eval step") + # TODO: insert reducer from outside so Evaluator is independent of FSDP forward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to(device) thoughput_aggregator.add_value( From 1c0f4565ddee5eb78aab0da2f40dbd5a9324b75b Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 16:47:18 +0200 Subject: [PATCH 045/161] feat: print global batch size --- src/modalities/__main__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index bb4addff8..82d84cd89 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -227,6 +227,12 @@ def run(self, components: TrainingComponentsInstantiationModel): f"Training model with {compute_number_of_trainable_parameters(wrapped_model)} parameters (per process)." 
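A single-process sketch of the local CLIP loss added and corrected above: with local_loss enabled, each rank scores only its local embeddings against the globally gathered ones, so the positive diagonal starts at num_logits * rank. The all_gather step is simulated here by slicing one precomputed tensor.

    import torch
    import torch.nn.functional as F

    world_size, local_bs, dim, rank = 4, 3, 16, 2
    gathered_emb1 = F.normalize(torch.randn(world_size * local_bs, dim), dim=-1)  # "all_gathered"
    gathered_emb2 = F.normalize(torch.randn(world_size * local_bs, dim), dim=-1)
    emb1 = gathered_emb1[rank * local_bs:(rank + 1) * local_bs]  # this rank's slice
    emb2 = gathered_emb2[rank * local_bs:(rank + 1) * local_bs]

    logit_scale = torch.tensor(100.0)
    logits_per_emb1 = logit_scale * emb1 @ gathered_emb2.T  # (local_bs, world_size * local_bs)
    logits_per_emb2 = logit_scale * emb2 @ gathered_emb1.T
    num_logits = logits_per_emb1.shape[0]
    labels = torch.arange(num_logits) + num_logits * rank   # offset into the gathered axis
    clip_loss = (F.cross_entropy(logits_per_emb1, labels) + F.cross_entropy(logits_per_emb2, labels)) / 2
    print(clip_loss.item())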
) + # Print global batch size + world_size = dist.get_world_size() + acc_steps = components.settings.training.gradient_acc_steps + local_batch_size = components.settings.training.local_train_micro_batch_size + print(f"Training model with a global batch size of {world_size * acc_steps* local_batch_size} samples.") + if components.settings.training.do_apply_activation_checkpointing: apply_activation_checkpointing_inplace(wrapped_model) From a04bf97008f3a7b156f76a82d1032f738cd5cb29 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 16:54:12 +0200 Subject: [PATCH 046/161] fix: force integer for rank env variable --- src/modalities/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 82d84cd89..3d567bc3e 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -221,7 +221,7 @@ def run(self, components: TrainingComponentsInstantiationModel): ) wrapped_model = components.wrapped_model - if os.environ["RANK"] == 0: + if int(os.environ["RANK"]) == 0: # TODO calculate parameters for full model print( f"Training model with {compute_number_of_trainable_parameters(wrapped_model)} parameters (per process)." From 9086af6550117a48fa9b8bebb524c6736bb97cde Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 19:11:16 +0200 Subject: [PATCH 047/161] fix: validation set loading --- .../config_example_coca_webdataset.yaml | 27 +++++++++---------- src/modalities/dataloader/dataset.py | 9 +++++-- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index f181a1dc8..62225a29a 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -5,13 +5,13 @@ settings: sample_key: input_ids target_key: target_ids training: - global_training_log_interval_in_steps: 4 - global_checkpointing_interval_in_steps: 100 - global_evaluation_interval_in_steps: 100 + global_training_log_interval_in_steps: 30 # Needs to be a multiple of gradient_acc_steps + global_checkpointing_interval_in_steps: 30 + global_evaluation_interval_in_steps: 30 global_num_training_samples: ${train_dataset.config.num_samples} # 491 steps with 8 gpus and global bs of 1152 global_num_seen_steps: 0 do_apply_activation_checkpointing: true - gradient_acc_steps: 1 + gradient_acc_steps: 30 local_train_micro_batch_size: 144 sequence_length: 256 cuda_env: @@ -44,7 +44,7 @@ train_dataset: component_key: dataset variant_key: web_dataset config: - urls: "coco/train/{000000..000566}.tar" + urls: "/hpcwork/rwth1597/coco/train/{000000..000566}.tar" source_image_key: jpg image_key: images source_text_key: txt @@ -66,7 +66,7 @@ val_dataset: component_key: dataset variant_key: web_dataset config: - urls: "coco/val/{000000..000025}.tar" + urls: "/hpcwork/rwth1597/coco/val/{000000..000025}.tar" source_image_key: jpg image_key: images source_text_key: txt @@ -141,6 +141,7 @@ captioning_loss: target_key: ${settings.referencing_keys.target_key} prediction_key: ${model.config.prediction_key} tag: captioning_loss + weight: 2.0 contrastive_loss: component_key: loss @@ -150,6 +151,7 @@ contrastive_loss: prediction_key2: ${model.config.text_cls_prediction_key} logit_scale_key: ${model.config.logit_scale_prediction_key} tag: contrastive_loss + weight: 1.0 loss_fn: - instance_key: captioning_loss @@ -221,17 +223,13 @@ model: scheduler: 
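A worked example of the global-batch-size print above, plugging in the values from this config; the 8-GPU figure comes from the yaml comment rather than from code.

    world_size, gradient_acc_steps, local_micro_batch_size = 8, 30, 144
    print(world_size * gradient_acc_steps * local_micro_batch_size)  # 34560 samples per optimizer step
    print(world_size * local_micro_batch_size)                       # 1152, the "global bs" quoted without accumulation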
component_key: scheduler - variant_key: onecycle_lr # COCA uses linear decay + variant_key: cosine_annealing_with_warmup_lr # COCA uses linear decay config: optimizer: instance_key: optimizer pass_type: BY_REFERENCE - max_lr: 8e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: 500_000 # depends on 500.000 iterations on 65,536 image-text pairs -> 5 epochs on JFT -> 32.7B image-text pairs - pct_start: 0.02 - anneal_strategy: cos + num_warmup_steps: 2_000 + num_training_steps: 500_000 optimizer: component_key: optimizer @@ -247,12 +245,13 @@ optimizer: gradient_clipper: component_key: gradient_clipper - variant_key: fsdp_logging_only + variant_key: fsdp config: wrapped_model: instance_key: wrapped_model pass_type: BY_REFERENCE norm_type: P2_NORM + max_norm: 1.0 batch_progress_subscriber: component_key: progress_subscriber diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 74218a493..6e588a134 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -293,9 +293,12 @@ def __init__( resample: bool, shuffle: int, ): + # Dont apply nodesplitting + # This is not required for training due to resample + # For validation the datasets are small and we dont get an even split between all nodes super().__init__( urls=urls, - nodesplitter=nodesplitter if not resample else None, + nodesplitter=None, shardshuffle=shardshuffle, repeat=repeat, handler=wds.ignore_and_continue, @@ -303,7 +306,9 @@ def __init__( ) self.num_samples = num_samples - self.append(wds.filters.shuffle(shuffle)) + if shuffle > 0: + self.append(wds.filters.shuffle(shuffle)) + self.append(wds.filters.decode("pil")) tokenizer.tokenizer.pad_token = tokenizer.tokenizer.eos_token From 351990ab3cbaadee26d80b5182bbec6348aab839 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 19:23:26 +0200 Subject: [PATCH 048/161] fix: webdataset splitter --- src/modalities/dataloader/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 6e588a134..2d579482f 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -298,7 +298,7 @@ def __init__( # For validation the datasets are small and we dont get an even split between all nodes super().__init__( urls=urls, - nodesplitter=None, + nodesplitter=wds.single_node_only if not resample else None, shardshuffle=shardshuffle, repeat=repeat, handler=wds.ignore_and_continue, From c90df9d51f8df0a3e6491425b8c1d16eef997c59 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 19:26:17 +0200 Subject: [PATCH 049/161] feat: add drop_last to webloader --- config_files/training/config_example_coca_webdataset.yaml | 6 +++++- src/modalities/config/config.py | 1 + src/modalities/dataloader/dataloader.py | 3 ++- src/modalities/dataloader/dataloader_factory.py | 8 +++++++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index 62225a29a..b9daee357 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -54,7 +54,7 @@ train_dataset: shardshuffle: 1000 repeat: true resample: true - shuffle: true + shuffle: 1_000 tokenizer: instance_key: tokenizer pass_type: BY_REFERENCE @@ -73,6 +73,8 @@ val_dataset: text_key: input_ids block_size: 
${settings.training.sequence_length} num_samples: 25_010 + resample: false + shuffle: 0 # Disable shuffling tokenizer: instance_key: tokenizer pass_type: BY_REFERENCE @@ -86,6 +88,7 @@ train_dataloader: config: num_workers: 8 pin_memory: true + drop_last: true dataloader_tag: "train" dataset: instance_key: train_dataset @@ -101,6 +104,7 @@ val_dataloader: config: num_workers: 8 pin_memory: true + drop_last: false dataloader_tag: "val" dataset: instance_key: val_dataset diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 30b2299d4..fc950845e 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -313,6 +313,7 @@ class WebLoaderConfig(BaseModel): collate_fn: PydanticCollateFnIFType num_workers: Annotated[int, Field(strict=True, ge=0)] pin_memory: bool + drop_last: bool class RepeatingDataLoaderConfig(BaseModel): diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index c92321f74..436315912 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -157,8 +157,9 @@ def __init__( num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, pin_memory: bool = False, + drop_last: bool = False, ): - self.num_batches = len(dataset) // batch_size + self.num_batches = len(dataset) // batch_size + int(not drop_last) dataset = dataset.batched(batch_size, collation_fn=collate_fn) self.webloader = wd.WebLoader(dataset=dataset, batch_size=None, num_workers=num_workers, pin_memory=pin_memory) self.webloader = self.webloader.with_epoch(self.num_batches) diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 6241bbd62..00bebfedf 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -41,7 +41,13 @@ def get_repeating_dataloader( @staticmethod def get_web_loader( - dataloader_tag: str, dataset: Dataset, batch_size: int, collate_fn: Callable, num_workers: int, pin_memory: bool + dataloader_tag: str, + dataset: Dataset, + batch_size: int, + collate_fn: Callable, + num_workers: int, + pin_memory: bool, + drop_last: bool, ) -> WebLoader: dataloader = WebLoader( dataloader_tag=dataloader_tag, From 860b8fd00f5d2bdf3fd79c2f910414b09cb45ee3 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 22:48:08 +0200 Subject: [PATCH 050/161] fix: val dataset with webdataset --- src/modalities/dataloader/dataset.py | 32 ++++++++-------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 2d579482f..4a4c4e52d 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -6,7 +6,6 @@ import jq import numpy as np -import torch import webdataset as wds from pydantic import BaseModel from timm.data import create_transform @@ -259,21 +258,13 @@ class WebDatasetConfig(BaseModel): shuffle: int = 0 -def nodesplitter(src, group=None): - if torch.distributed.is_initialized(): - if group is None: - group = torch.distributed.group.WORLD - rank = torch.distributed.get_rank(group=group) - size = torch.distributed.get_world_size(group=group) - print(f"nodesplitter: rank={rank} size={size}") - count = 0 - for i, item in enumerate(src): - if i % size == rank: - yield item - count += 1 - print(f"nodesplitter: rank={rank} size={size} count={count} DONE") - else: - yield from src +def dummy_nodesplitter(src, group=None): + 
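A worked example of the WebLoader batch count above, using the 25_010 validation samples from this config and assuming the same 144 micro-batch size as training.

    num_samples, batch_size = 25_010, 144
    print(num_samples // batch_size + int(not False))  # drop_last=False -> 174 (173 full batches + 1 partial)
    print(num_samples // batch_size + int(not True))   # drop_last=True  -> 173
    # When num_samples divides evenly, this formula still adds one batch for
    # drop_last=False; ceil division, -(-num_samples // batch_size), would not.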
# This node splitter is not actually splitting the data over the nodes + # but keeps the complete dataset on each node. + # This is required so that each node has the same amount of data. + # In the case of 25 shards and 16 ranks for example 7 ranks are + # without data in the second iteration. This will cause a crash once all_gather is called. + yield from src class WebDataset(wds.WebDataset): @@ -293,30 +284,25 @@ def __init__( resample: bool, shuffle: int, ): - # Dont apply nodesplitting - # This is not required for training due to resample - # For validation the datasets are small and we dont get an even split between all nodes super().__init__( urls=urls, - nodesplitter=wds.single_node_only if not resample else None, + nodesplitter=wds.dummy_nodesplitter if not resample else None, shardshuffle=shardshuffle, repeat=repeat, handler=wds.ignore_and_continue, resampled=resample, ) self.num_samples = num_samples + tokenizer.tokenizer.pad_token = tokenizer.tokenizer.eos_token if shuffle > 0: self.append(wds.filters.shuffle(shuffle)) self.append(wds.filters.decode("pil")) - tokenizer.tokenizer.pad_token = tokenizer.tokenizer.eos_token - transform = create_transform(**image_transform_config.model_dump()) def make_sample(sample): - # print(sample["json"]) batch_encoding: BatchEncoding = tokenizer.tokenizer( sample[source_text_key], max_length=block_size, From 962c4caa4c17d8da43d5548b8bad3ca74331a1d8 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 22:48:29 +0200 Subject: [PATCH 051/161] refactor: evaluator --- src/modalities/evaluator.py | 87 +++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index a8f7fff4c..def731501 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -32,27 +32,26 @@ def evaluate_batch( model: nn.Module, loss_fun: List[Loss], ): - with torch.no_grad(): - result_batch = model_predict_batch(model=model, batch=batch) + result_batch = model_predict_batch(model=model, batch=batch) - total_loss = None - losses = [] - for lfn in loss_fun: - # Calculate loss - loss = lfn(result_batch) + total_loss = None + losses = [] + for lfn in loss_fun: + # Calculate loss + weighted_loss = lfn(result_batch) * lfn.weight - # Add loss to total loss - weighted_loss = loss * lfn.weight - if total_loss is None: - total_loss = weighted_loss - else: - total_loss += weighted_loss + # Add loss to total loss + if total_loss is None: + total_loss = weighted_loss + else: + total_loss += weighted_loss - # Append individual losses (for logging) - losses.append(loss) + # Append individual losses (for logging) + losses.append(weighted_loss.clone().detach()) return total_loss, *losses + @torch.no_grad() def evaluate( self, model: nn.Module, @@ -75,37 +74,39 @@ def evaluate( ) thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() + # Make sure that all ranks reach this point at the same time dist.barrier() - print("All ranks reached the eval step") - - with TimeRecorder() as forward_backward_timer_recorder: - for batch_id, batch in enumerate(data_loader): - batch_losses = self.evaluate_batch( - batch=batch, - model=model, - loss_fun=loss_fun, - ) - - # Accumulate losses - for i, batch_loss in enumerate(batch_losses): - cumulated_loss[i] += batch_loss.item() - cumulated_loss[-1] += 1 - - batch_length_tensor = torch.tensor(len(batch)).to(device) - thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) - - 
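A toy contrast between a modulo-based shard split and the keep-everything splitter described above, using the 25 shards and 16 ranks mentioned in the comment: seven ranks end up with a single shard and run dry one pass early, which is what breaks the later all_gather. Both splitter functions here are illustrative stand-ins, not the webdataset API.

    def rank_nodesplitter(shards, rank, world_size):
        for i, shard in enumerate(shards):
            if i % world_size == rank:
                yield shard

    def keep_all_nodesplitter(shards):
        yield from shards  # every rank sees every shard, so iteration lengths match

    shards = [f"{i:06d}.tar" for i in range(25)]
    print([len(list(rank_nodesplitter(shards, r, 16))) for r in range(16)])
    # [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1]
    print(len(list(keep_all_nodesplitter(shards))))  # 25 on every rank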
Evaluator._publish_progress( - batch_progress_publisher=self.batch_progress_publisher, - eval_step_id=batch_id, - dataloader_tag=data_loader.dataloader_tag, - ) - - print(f"Rank {dist.get_rank()} is done with eval step") - dist.barrier() - print("All ranks are done with the eval step") + + forward_backward_time_recorder = TimeRecorder() + forward_backward_time_recorder.start() + for batch_id, batch in enumerate(data_loader): + batch_losses = self.evaluate_batch( + batch=batch, + model=model, + loss_fun=loss_fun, + ) + forward_backward_time_recorder.stop() + + # Accumulate losses + for i, batch_loss in enumerate(batch_losses): + cumulated_loss[i] += batch_loss.item() + cumulated_loss[-1] += 1 + + batch_length_tensor = torch.tensor(len(batch)).to(device) + thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) + + Evaluator._publish_progress( + batch_progress_publisher=self.batch_progress_publisher, + eval_step_id=batch_id, + dataloader_tag=data_loader.dataloader_tag, + ) + + # we start the time recoder here again to also capture the time spend loading + # via the dataloader. + forward_backward_time_recorder.start() # TODO: insert reducer from outside so Evaluator is independent of FSDP - forward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to(device) + forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device) thoughput_aggregator.add_value( key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time ) From 2d1ea92f188dd48892857c0228ceedda2e69ca57 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 22:52:29 +0200 Subject: [PATCH 052/161] fix: nodesplitter in webdataset --- src/modalities/dataloader/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 4a4c4e52d..4e8d33995 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -286,7 +286,7 @@ def __init__( ): super().__init__( urls=urls, - nodesplitter=wds.dummy_nodesplitter if not resample else None, + nodesplitter=dummy_nodesplitter if not resample else None, shardshuffle=shardshuffle, repeat=repeat, handler=wds.ignore_and_continue, From 1fb4fba7505d2bbde95acf67bce0181003e76d04 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 22:59:30 +0200 Subject: [PATCH 053/161] feat: add wandb grouping --- .../logging_broker/subscriber_impl/results_subscriber.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index 8a5d8dc4a..c5cc99c02 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -68,11 +68,11 @@ def consume_message(self, message: Message[EvaluationResultBatch]): eval_result = message.payload losses = { - f"{eval_result.dataloader_tag} {loss_key}": loss_values + f"{eval_result.dataloader_tag}/{loss_key}": loss_values for loss_key, loss_values in eval_result.losses.items() } metrics = { - f"{eval_result.dataloader_tag} {metric_key}": metric_values + f"{eval_result.dataloader_tag}/{metric_key}": metric_values for metric_key, metric_values in eval_result.metrics.items() } # TODO step is not semantically correct here. 
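A generic sketch of the timing pattern in the refactored evaluation loop above: the recorder is restarted at the end of each iteration so the measured span also covers the dataloader's next() call, not only the forward pass. time.perf_counter stands in for the project's TimeRecorder.

    import time

    def timed_eval(batches, step_fn):
        elapsed = 0.0
        start = time.perf_counter()
        for batch in batches:                      # fetching the batch is inside the timed span
            step_fn(batch)
            elapsed += time.perf_counter() - start
            # ... accumulate losses and publish progress here, outside the timed span ...
            start = time.perf_counter()
        return elapsed

    print(timed_eval(range(3), lambda b: None))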
Need to check if we can rename step to num_samples @@ -83,7 +83,7 @@ def consume_message(self, message: Message[EvaluationResultBatch]): data=metrics, step=eval_result.train_step_id + 1 ) # (eval_result.train_local_sample_id + 1) * self.num_ranks) throughput_metrics = { - f"{eval_result.dataloader_tag} {metric_key}": metric_values + f"{eval_result.dataloader_tag}/{metric_key}": metric_values for metric_key, metric_values in eval_result.throughput_metrics.items() } From 22c77cd57bb2dca0c2350e78cdde10b8a4bed62a Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 23:08:12 +0200 Subject: [PATCH 054/161] chore: update coca config --- .../training/config_example_coca_webdataset.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index b9daee357..e07339636 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -5,15 +5,15 @@ settings: sample_key: input_ids target_key: target_ids training: - global_training_log_interval_in_steps: 30 # Needs to be a multiple of gradient_acc_steps - global_checkpointing_interval_in_steps: 30 - global_evaluation_interval_in_steps: 30 + global_training_log_interval_in_steps: 129_600 # Needs to be a multiple of gradient_acc_steps + global_checkpointing_interval_in_steps: 9_990 + global_evaluation_interval_in_steps: 4_980 global_num_training_samples: ${train_dataset.config.num_samples} # 491 steps with 8 gpus and global bs of 1152 global_num_seen_steps: 0 do_apply_activation_checkpointing: true gradient_acc_steps: 30 local_train_micro_batch_size: 144 - sequence_length: 256 + sequence_length: 64 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} @@ -27,7 +27,7 @@ tokenizer: config: pretrained_model_name_or_path: gpt2 padding: true - max_length: 256 + max_length: ${settings.training.sequence_length} collate_fn: component_key: collate_fn From 6d980e76c3fa8b7639ccaf89091bd6285ba5c938 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 13 May 2024 23:14:29 +0200 Subject: [PATCH 055/161] fix: coca config --- config_files/training/config_example_coca_webdataset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index e07339636..70a2d304c 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -5,7 +5,7 @@ settings: sample_key: input_ids target_key: target_ids training: - global_training_log_interval_in_steps: 129_600 # Needs to be a multiple of gradient_acc_steps + global_training_log_interval_in_steps: 30 # Needs to be a multiple of gradient_acc_steps global_checkpointing_interval_in_steps: 9_990 global_evaluation_interval_in_steps: 4_980 global_num_training_samples: ${train_dataset.config.num_samples} # 491 steps with 8 gpus and global bs of 1152 From dfb788444ac1838804543f908d37ed006e1884b7 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 10 Jun 2024 15:51:53 +0200 Subject: [PATCH 056/161] feat: add universal multimodal dataset --- src/modalities/dataloader/dataset.py | 192 +++++++++++++++++++++------ 1 file changed, 149 insertions(+), 43 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 4e8d33995..40c935768 100644 --- 
a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -10,6 +10,7 @@ from pydantic import BaseModel from timm.data import create_transform from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import IterableDataset from torch.utils.data.dataset import Dataset as TorchdataSet from tqdm import tqdm from transformers import BatchEncoding @@ -213,7 +214,18 @@ def _generate_packing_index(self) -> List[Tuple[int, int]]: return index -class ImageTransformConfig(BaseModel): +class ModalityEnum(Enum): + TEXT = "text" + IMAGE = "image" + VIDEO = "video" + AUDIO = "audio" + + +class TransformConfig(BaseModel): + pass + + +class ImageTransformConfig(TransformConfig): input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224 is_training: bool = False no_aug: bool = False @@ -242,6 +254,14 @@ class ImageTransformConfig(BaseModel): separate: bool = False +class TextTransformConfig(TransformConfig): + tokenizer: TokenizerWrapper + max_length: int = 77 + padding: str = "max_length" + truncation: bool = True + return_attention_mask: bool = True + + class WebDatasetConfig(BaseModel): urls: Union[List[str], str] source_image_key: str @@ -267,57 +287,143 @@ def dummy_nodesplitter(src, group=None): yield from src -class WebDataset(wds.WebDataset): +class WebDataset(IterableDataset): def __init__( self, - urls: Union[List[str], str], - source_image_key: str, - image_key: str, - source_text_key: str, - text_key: str, - tokenizer: TokenizerWrapper, - block_size: int, - num_samples: int, - image_transform_config: ImageTransformConfig, - shardshuffle: int, - repeat: bool, - resample: bool, - shuffle: int, + urls: Dict[Tuple[ModalityEnum, ...], Union[List[str], str]], + modality_key_mapping: Dict[ModalityEnum, Tuple[str, str]], + modality_transforms_configs: Dict[ModalityEnum, TransformConfig], + num_samples: Dict[Tuple[ModalityEnum, ...], int], + mixing_ratios: Optional[Dict[Tuple[ModalityEnum, ...], int]] = None, + shardshuffle: int = 100, + repeat: bool = False, + resample: bool = True, + shuffle_buffer: Optional[int] = 10_000, ): - super().__init__( - urls=urls, - nodesplitter=dummy_nodesplitter if not resample else None, - shardshuffle=shardshuffle, - repeat=repeat, - handler=wds.ignore_and_continue, - resampled=resample, - ) + """WebDataset for loading and combining multimodal datasets. + + Args: + urls: Webdataset urls for each modality combination. + For example: {(ModalityEnum.IMAGE, ModalityEnum.TEXT): "/data/path/{00000..00012.tar"}} + + modality_key_mapping: Mapping from dataset keys to keys expected by the forward pass of the model. + For example: {ModalityEnum.IMAGE: ("jpg", "image"), ModalityEnum.TEXT: ("text", "caption")}} + modality_transforms_configs: The transform config for each modality. + num_samples: The number of samples for each modality combination. + For example: {(ModalityEnum.IMAGE, ModalityEnum.TEXT): 1_234_567}} + mixing_ratios: Mixing ratios of the different modality combinations. + For example: { + (ModalityEnum.IMAGE, ModalityEnum.TEXT): 0.7, + (ModalityEnum.VIDEO, ModalityEnum.TEXT): 0.3} + } + shardshuffle: Number of sharfs that should be used for shuffling. Defaults to 100. + repeat: Repeat the dataset. Defaults to False. + resample: Instead if iterating in order sample random shards. + This has the issue that the model will see sample multiple times but if significantly more + efficient. Defaults to True. + shuffle_buffer: Number of samples that should be used for shuffling. Defaults to 10_000. 
+ """ self.num_samples = num_samples - tokenizer.tokenizer.pad_token = tokenizer.tokenizer.eos_token + self.total_num_samples = sum([self.num_samples[k] for k in urls.keys()]) + self.modality_key_mapping = modality_key_mapping + self.modality_transforms_configs = modality_transforms_configs + + # Create webdatasets for each modality combination + self.web_datasets = { + k: wds.WebDataset( + urls=u, + nodesplitter=dummy_nodesplitter if not resample else None, + shardshuffle=shardshuffle, + repeat=repeat, + handler=wds.ignore_and_continue, + resampled=resample, + ) + for k, u in urls.items() + } + + # Setup mixing ratios + self.mixing_ratios = mixing_ratios + if self.mixing_ratios is None and len(self.web_datasets) > 1: + uniform_ratio = 1 / len(self.web_datasets) + self.mixing_ratios = {k: uniform_ratio for k in self.web_datasets.keys()} + + # Mapping between modality and the decode "function" + self.modality_to_decode_fn = { + ModalityEnum.TEXT: None, + ModalityEnum.IMAGE: "pil", + ModalityEnum.VIDEO: wds.torch_video, + ModalityEnum.AUDIO: wds.torch_audio, + } + + # Some transforms require objects such as image + if ModalityEnum.IMAGE in self.modality_transforms_configs: + self._timm_image_transform = create_transform( + **self.modality_transforms_configs[ModalityEnum.IMAGE].model_dump() + ) - if shuffle > 0: - self.append(wds.filters.shuffle(shuffle)) + # Mapping between modality and transform + self.modality_to_transform_fn = { + ModalityEnum.TEXT: self._transform_text, + ModalityEnum.IMAGE: self._transform_image, + ModalityEnum.VIDEO: self._transform_video, + ModalityEnum.AUDIO: self._transform_audio, + } + + for k, web_dataset in self.web_datasets: + # Apply shuffling to samples + if shuffle_buffer is not None and shuffle_buffer > 0: + web_dataset.append(wds.filters.shuffle(shuffle_buffer)) + + # Load the actual data + for modality_key in k: + transform_fn = self.modality_to_transform_fn[modality_key] + self.append(wds.filters.map(transform_fn)) + + def _transform_text(self, sample): + source_key, target_key = self.modality_key_mapping[ModalityEnum.TEXT] + config: TextTransformConfig = self.modality_transforms_configs[ModalityEnum.TEXT] + batch_encoding: BatchEncoding = config.tokenizer.tokenizer( + sample[source_key], + max_length=config.block_size, + padding=config.padding, + truncation=config.truncation, + return_attention_mask=config.return_attention_mask, + ) + del sample[source_key] + sample[target_key] = batch_encoding.input_ids + sample["attention_mask"] = batch_encoding.attention_mask + return sample - self.append(wds.filters.decode("pil")) + def _transform_image(self, sample): + source_key, target_key = self.modality_key_mapping[ModalityEnum.IMAGE] + sample[target_key] = self._timm_image_transform(sample[source_key]) + del sample[source_key] + return sample - transform = create_transform(**image_transform_config.model_dump()) + def _transform_video(self, sample): + source_key, target_key = self.modality_key_mapping[ModalityEnum.VIDEO] + # config: VideoTransformConfig = self.modality_transforms_configs[ModalityEnum.VIDEO] + # TODO add video transform + return sample - def make_sample(sample): - batch_encoding: BatchEncoding = tokenizer.tokenizer( - sample[source_text_key], - max_length=block_size, - padding="max_length", - truncation=True, - return_attention_mask=True, - ) + def _transform_audio(self, sample): + source_key, target_key = self.modality_key_mapping[ModalityEnum.AUDIO] + # config: AudioTransformConfig = self.modality_transforms_configs[ModalityEnum.AUDIO] + # TODO 
add audio transform + return sample - return { - image_key: transform(sample[source_image_key]), - text_key: batch_encoding.input_ids, - "attention_mask": batch_encoding.attention_mask, - } + def __iter__(self): + if len(self.web_datasets) > 1: + datasets = [] + ratios = [] + for k in self.web_datasets.keys(): + datasets.append(self.web_datasets[k]) + ratios.append(self.mixing_ratios[k]) + dataset = wds.RandomMix(datasets, ratios) # Apply mixing at sample level + return iter(dataset) - self.append(wds.filters.map(make_sample)) + dataset = next(iter(self.web_datasets.values())) + return iter(dataset) def __len__(self): - return self.num_samples + return self.total_num_samples From da27f0d8407e34e34c58115fbcc698c5a6b68dc2 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 10 Jun 2024 17:46:40 +0200 Subject: [PATCH 057/161] feat: add flatten_dict function --- src/modalities/util.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/modalities/util.py b/src/modalities/util.py index e769262df..b6dd70a53 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -136,3 +136,25 @@ def get_all_reduced_value( post_processing_fun=postprocessing_fun, # lambda t: t[0] / t[1], ) return value + + +def flatten_dict(d, parent_key="", sep="_"): + """ + Flatten a nested dictionary. + + Args: + d: The dictionary to flatten. + parent_key: The base key to use for concatenation. + sep: The separator to use between concatenated keys. + + Return: + A flattened dictionary with concatenated keys. + """ + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) From 016bed34077837243b2cc1451eaa50c7da690598 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 10 Jun 2024 17:47:37 +0200 Subject: [PATCH 058/161] fix: webdataset --- src/modalities/dataloader/dataset.py | 190 +++++++++++++++++---------- 1 file changed, 124 insertions(+), 66 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 40c935768..d2f8f92b9 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -16,10 +16,10 @@ from transformers import BatchEncoding from modalities.config.config import PydanticTokenizerIFType +from modalities.dataloader.create_packed_data import EmbeddedStreamData +from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper - -from ..dataloader.large_file_lines_reader import LargeFileLinesReader -from .create_packed_data import EmbeddedStreamData +from modalities.util import flatten_dict class Dataset(TorchdataSet): @@ -255,7 +255,7 @@ class ImageTransformConfig(TransformConfig): class TextTransformConfig(TransformConfig): - tokenizer: TokenizerWrapper + tokenizer: PydanticTokenizerIFType max_length: int = 77 padding: str = "max_length" truncation: bool = True @@ -287,65 +287,30 @@ def dummy_nodesplitter(src, group=None): yield from src -class WebDataset(IterableDataset): +class MultimodalWebDatasetBuilder: def __init__( self, - urls: Dict[Tuple[ModalityEnum, ...], Union[List[str], str]], + urls: Union[List[str], str], modality_key_mapping: Dict[ModalityEnum, Tuple[str, str]], modality_transforms_configs: Dict[ModalityEnum, TransformConfig], - num_samples: Dict[Tuple[ModalityEnum, ...], int], - mixing_ratios: 
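A usage sketch for the flatten_dict helper added above: after the webdataset json decode, nested fields become underscore-joined top-level keys, which appears to be where source keys such as json_text0 in the later builder config come from (an inference, not stated in the patch).

    from modalities.util import flatten_dict

    decoded_sample = {"jpg": b"<raw image bytes>", "json": {"text0": "a photo of a cat"}}
    print(flatten_dict(decoded_sample))
    # {'jpg': b'<raw image bytes>', 'json_text0': 'a photo of a cat'}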
Optional[Dict[Tuple[ModalityEnum, ...], int]] = None, - shardshuffle: int = 100, - repeat: bool = False, - resample: bool = True, - shuffle_buffer: Optional[int] = 10_000, + num_samples: int, ): - """WebDataset for loading and combining multimodal datasets. + """A multimodal dataset instance for the WebDataset. Args: - urls: Webdataset urls for each modality combination. - For example: {(ModalityEnum.IMAGE, ModalityEnum.TEXT): "/data/path/{00000..00012.tar"}} - + urls: A webdataset url. For example: "/data/path/{00000..00012.tar". modality_key_mapping: Mapping from dataset keys to keys expected by the forward pass of the model. For example: {ModalityEnum.IMAGE: ("jpg", "image"), ModalityEnum.TEXT: ("text", "caption")}} modality_transforms_configs: The transform config for each modality. num_samples: The number of samples for each modality combination. - For example: {(ModalityEnum.IMAGE, ModalityEnum.TEXT): 1_234_567}} - mixing_ratios: Mixing ratios of the different modality combinations. - For example: { - (ModalityEnum.IMAGE, ModalityEnum.TEXT): 0.7, - (ModalityEnum.VIDEO, ModalityEnum.TEXT): 0.3} - } - shardshuffle: Number of sharfs that should be used for shuffling. Defaults to 100. - repeat: Repeat the dataset. Defaults to False. - resample: Instead if iterating in order sample random shards. - This has the issue that the model will see sample multiple times but if significantly more - efficient. Defaults to True. - shuffle_buffer: Number of samples that should be used for shuffling. Defaults to 10_000. """ - self.num_samples = num_samples - self.total_num_samples = sum([self.num_samples[k] for k in urls.keys()]) + self.urls = urls self.modality_key_mapping = modality_key_mapping self.modality_transforms_configs = modality_transforms_configs - - # Create webdatasets for each modality combination - self.web_datasets = { - k: wds.WebDataset( - urls=u, - nodesplitter=dummy_nodesplitter if not resample else None, - shardshuffle=shardshuffle, - repeat=repeat, - handler=wds.ignore_and_continue, - resampled=resample, - ) - for k, u in urls.items() - } - - # Setup mixing ratios - self.mixing_ratios = mixing_ratios - if self.mixing_ratios is None and len(self.web_datasets) > 1: - uniform_ratio = 1 / len(self.web_datasets) - self.mixing_ratios = {k: uniform_ratio for k in self.web_datasets.keys()} + assert self.modality_key_mapping.keys() == self.modality_transforms_configs.keys() + self.modalities = list(self.modality_key_mapping.keys()) + self.num_samples = num_samples + self.web_dataset = None # Mapping between modality and the decode "function" self.modality_to_decode_fn = { @@ -355,12 +320,17 @@ def __init__( ModalityEnum.AUDIO: wds.torch_audio, } + self.additional_extreacted_keys = [] + # Some transforms require objects such as image if ModalityEnum.IMAGE in self.modality_transforms_configs: self._timm_image_transform = create_transform( **self.modality_transforms_configs[ModalityEnum.IMAGE].model_dump() ) + if ModalityEnum.TEXT in self.modality_transforms_configs: + self.additional_extreacted_keys.append("attention_mask") + # Mapping between modality and transform self.modality_to_transform_fn = { ModalityEnum.TEXT: self._transform_text, @@ -369,22 +339,46 @@ def __init__( ModalityEnum.AUDIO: self._transform_audio, } - for k, web_dataset in self.web_datasets: - # Apply shuffling to samples - if shuffle_buffer is not None and shuffle_buffer > 0: - web_dataset.append(wds.filters.shuffle(shuffle_buffer)) + def prepare( + self, shardshuffle: int = 100, resample: bool = True, repeat: bool = 
False, shuffle_buffer: int = 10_000 + ): + self.web_dataset = wds.WebDataset( + urls=self.urls, + nodesplitter=dummy_nodesplitter if not resample else None, + shardshuffle=shardshuffle, + repeat=repeat, + handler=wds.ignore_and_continue, + resampled=resample, + ) + + # Apply shuffling to samples + if shuffle_buffer is not None and shuffle_buffer > 0: + self.web_dataset.append(wds.filters.shuffle(shuffle_buffer)) + + # Flatten the json structure for convenience + self.web_dataset.append(wds.filters.decode(partial=True)) # Decode json byte string + self.web_dataset.append(wds.filters.map(self._flatten_sample)) + + # Load the actual data + for modality_key in self.modalities: + decode_fn = self.modality_to_decode_fn[modality_key] + if decode_fn is None: + continue + self.web_dataset.append(wds.filters.decode(decode_fn, partial=True)) + + # Transform the data + for modality_key in self.modalities: + transform_fn = self.modality_to_transform_fn[modality_key] + self.web_dataset.append(wds.filters.map(transform_fn)) - # Load the actual data - for modality_key in k: - transform_fn = self.modality_to_transform_fn[modality_key] - self.append(wds.filters.map(transform_fn)) + self.web_dataset.append(wds.filters.map(self._select_keys)) def _transform_text(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.TEXT] config: TextTransformConfig = self.modality_transforms_configs[ModalityEnum.TEXT] batch_encoding: BatchEncoding = config.tokenizer.tokenizer( sample[source_key], - max_length=config.block_size, + max_length=config.max_length, padding=config.padding, truncation=config.truncation, return_attention_mask=config.return_attention_mask, @@ -404,26 +398,90 @@ def _transform_video(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.VIDEO] # config: VideoTransformConfig = self.modality_transforms_configs[ModalityEnum.VIDEO] # TODO add video transform + sample[target_key] = sample[source_key] + del sample[source_key] return sample def _transform_audio(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.AUDIO] # config: AudioTransformConfig = self.modality_transforms_configs[ModalityEnum.AUDIO] # TODO add audio transform + sample[target_key] = sample[source_key] + del sample[source_key] return sample + def _flatten_sample(self, sample): + return flatten_dict(sample) + + def _select_keys(self, sample): + select_keys = self.additional_extreacted_keys + [v[1] for v in self.modality_key_mapping.values()] + new_sample = {} + for k, v in sample.items(): + if k not in select_keys: + continue + new_sample[k] = v + return new_sample + + +class MultimodalWebDataset(IterableDataset): + def __init__( + self, + builders: List[MultimodalWebDatasetBuilder], + mixing_ratios: Optional[List[int]] = None, + shardshuffle: int = 100, + repeat: bool = False, + resample: bool = True, + shuffle_buffer: Optional[int] = 10_000, + ): + """WebDataset for loading and combining multimodal datasets. + + Args: + builders: WebDatasetBuilder instances. + mixing_ratios: Mixing ratios of the different modality combinations. + For example: [0.3, 0.7] + shardshuffle: Number of sharfs that should be used for shuffling. Defaults to 100. + repeat: Repeat the dataset. Defaults to False. + resample: Instead if iterating in order sample random shards. + This has the issue that the model will see sample multiple times but if significantly more + efficient. Defaults to True. + shuffle_buffer: Number of samples that should be used for shuffling. Defaults to 10_000. 
+ """ + self.builders = builders + + self.output_keys_by_modality = {} + for b in builders: + for k, v in b.modality_key_mapping.items(): + if k not in self.output_keys_by_modality: + self.output_keys_by_modality[k] = v[1] + else: + assert ( + self.output_keys_by_modality[k] == v[1] + ), "Output keys for the same modality of all builders should be the same." + + # Build datasets + [ + b.prepare(shardshuffle=shardshuffle, resample=resample, repeat=repeat, shuffle_buffer=shuffle_buffer) + for b in self.builders + ] + + # Setup mixing ratios + self.mixing_ratios = mixing_ratios + if self.mixing_ratios is None: + uniform_ratio = 1 / len(self.builders) + self.mixing_ratios = [uniform_ratio for _ in self.builders] + assert len(self.mixing_ratios) == len(self.builders) + def __iter__(self): - if len(self.web_datasets) > 1: + if len(self.builders) > 1: datasets = [] - ratios = [] - for k in self.web_datasets.keys(): - datasets.append(self.web_datasets[k]) - ratios.append(self.mixing_ratios[k]) - dataset = wds.RandomMix(datasets, ratios) # Apply mixing at sample level + for b in self.builders: + datasets.append(b.web_dataset) + dataset = wds.RandomMix(datasets, self.mixing_ratios) # Apply mixing at sample level return iter(dataset) - dataset = next(iter(self.web_datasets.values())) + dataset = self.builders[0].web_dataset return iter(dataset) def __len__(self): - return self.total_num_samples + total_num_samples = sum([b.num_samples for b in self.builders]) + return total_num_samples From faeec418f39c45ad4f2f29d3037a82cd53165d9c Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 11 Jun 2024 11:22:11 +0200 Subject: [PATCH 059/161] fix: web dataset integration --- .../config_example_coca_webdataset.yaml | 74 +++++---- src/modalities/dataloader/dataset.py | 151 +++++++++++------- src/modalities/dataloader/dataset_factory.py | 37 +---- src/modalities/registry/components.py | 18 ++- 4 files changed, 155 insertions(+), 125 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index 70a2d304c..fdf259e3d 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -8,7 +8,7 @@ settings: global_training_log_interval_in_steps: 30 # Needs to be a multiple of gradient_acc_steps global_checkpointing_interval_in_steps: 9_990 global_evaluation_interval_in_steps: 4_980 - global_num_training_samples: ${train_dataset.config.num_samples} # 491 steps with 8 gpus and global bs of 1152 + global_num_training_samples: 566748 # 491 steps with 8 gpus and global bs of 1152 global_num_seen_steps: 0 do_apply_activation_checkpointing: true gradient_acc_steps: 30 @@ -25,7 +25,7 @@ tokenizer: component_key: tokenizer variant_key: pretrained_hf_tokenizer config: - pretrained_model_name_or_path: gpt2 + pretrained_model_name_or_path: openai/clip-vit-base-patch32 padding: true max_length: ${settings.training.sequence_length} @@ -40,47 +40,61 @@ collate_fn: text_sample_key: ${settings.referencing_keys.sample_key} text_target_key: ${settings.referencing_keys.target_key} +train_image_transform: + component_key: transform + variant_key: image_transform + config: + is_training: True + input_size: 256 + +text_transform: + component_key: transform + variant_key: text_transform + config: + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + +train_coco_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: 
"/nm-raid/video/multimodal_data/coco_captions/data/train/{000000..000011}.tar" + modality_key_mapping: + TEXT: ["json_text0", "input_ids"] + IMAGE: ["jpg", "images"] + modality_transforms: + IMAGE: + instance_key: train_image_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 566_748 + train_dataset: component_key: dataset variant_key: web_dataset config: - urls: "/hpcwork/rwth1597/coco/train/{000000..000566}.tar" - source_image_key: jpg - image_key: images - source_text_key: txt - text_key: input_ids - block_size: ${settings.training.sequence_length} - num_samples: 566_747 + builders: + - instance_key: train_coco_dataset_builder + pass_type: BY_REFERENCE shardshuffle: 1000 repeat: true resample: true shuffle: 1_000 - tokenizer: - instance_key: tokenizer - pass_type: BY_REFERENCE - image_transform_config: - is_training: True - input_size: 256 val_dataset: component_key: dataset variant_key: web_dataset config: - urls: "/hpcwork/rwth1597/coco/val/{000000..000025}.tar" - source_image_key: jpg - image_key: images - source_text_key: txt - text_key: input_ids - block_size: ${settings.training.sequence_length} - num_samples: 25_010 - resample: false - shuffle: 0 # Disable shuffling - tokenizer: - instance_key: tokenizer - pass_type: BY_REFERENCE - image_transform_config: - is_training: False - input_size: 256 + builders: + - instance_key: train_coco_dataset_builder + pass_type: BY_REFERENCE + shardshuffle: 1000 + repeat: true + resample: true + shuffle: 1_000 train_dataloader: component_key: data_loader diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index d2f8f92b9..722d224e5 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -2,20 +2,21 @@ from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Annotated, Dict, List, Optional, Tuple, Union import jq import numpy as np import webdataset as wds -from pydantic import BaseModel +from pydantic import BaseModel, Field from timm.data import create_transform from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from torch.utils.data import IterableDataset from torch.utils.data.dataset import Dataset as TorchdataSet from tqdm import tqdm from transformers import BatchEncoding from modalities.config.config import PydanticTokenizerIFType +from modalities.config.lookup_enum import LookupEnum +from modalities.config.pydanctic_if_types import PydanticThirdPartyTypeIF from modalities.dataloader.create_packed_data import EmbeddedStreamData from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper @@ -214,7 +215,7 @@ def _generate_packing_index(self) -> List[Tuple[int, int]]: return index -class ModalityEnum(Enum): +class ModalityEnum(LookupEnum): TEXT = "text" IMAGE = "image" VIDEO = "video" @@ -225,6 +226,13 @@ class TransformConfig(BaseModel): pass +class Transform: + pass + + +PydanticTransformIFType = Annotated[Transform, PydanticThirdPartyTypeIF(Transform)] + + class ImageTransformConfig(TransformConfig): input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224 is_training: bool = False @@ -254,6 +262,15 @@ class ImageTransformConfig(TransformConfig): separate: bool = False +# @register_component("transform", "image_transform", ImageTransformConfig) +class ImageTransform(Transform): + def __init__(self, **kwargs): + 
self._timm_image_transform = create_transform(**kwargs) + + def __call__(self, *args, **kwargs): + return self._timm_image_transform(*args, **kwargs) + + class TextTransformConfig(TransformConfig): tokenizer: PydanticTokenizerIFType max_length: int = 77 @@ -262,37 +279,47 @@ class TextTransformConfig(TransformConfig): return_attention_mask: bool = True -class WebDatasetConfig(BaseModel): - urls: Union[List[str], str] - source_image_key: str - image_key: str - source_text_key: str - text_key: str - tokenizer: PydanticTokenizerIFType - block_size: int - num_samples: int - image_transform_config: Optional[ImageTransformConfig] = None - shardshuffle: Optional[int] = None - repeat: bool = False - resample: bool = False - shuffle: int = 0 +# @register_component("transform", "text_transform", TextTransformConfig) +class TextTransform(Transform): + def __init__( + self, + tokenizer: TokenizerWrapper, + max_length: int = 77, + padding: str = "max_length", + truncation: bool = True, + return_attention_mask: bool = True, + ): + self.tokenizer = tokenizer + self.max_length = max_length + self.padding = padding + self.truncation = truncation + self.return_attention_mask = return_attention_mask + + def __call__(self, text): + batch_encoding: BatchEncoding = self.tokenizer.tokenizer( + text, + max_length=self.max_length, + padding=self.padding, + truncation=self.truncation, + return_attention_mask=self.return_attention_mask, + ) + return batch_encoding -def dummy_nodesplitter(src, group=None): - # This node splitter is not actually splitting the data over the nodes - # but keeps the complete dataset on each node. - # This is required so that each node has the same amount of data. - # In the case of 25 shards and 16 ranks for example 7 ranks are - # without data in the second iteration. This will cause a crash once all_gather is called. - yield from src +class MultimodalWebDatasetBuilderConfig(BaseModel): + urls: Union[List[str], str] + modality_key_mapping: Dict[ModalityEnum, Tuple[str, str]] + modality_transforms: Dict[ModalityEnum, PydanticTransformIFType] + num_samples: Annotated[int, Field(ge=1)] +# @register_component("dataset", "web_dataset_builder", MultimodalWebDatasetBuilderConfig) class MultimodalWebDatasetBuilder: def __init__( self, urls: Union[List[str], str], - modality_key_mapping: Dict[ModalityEnum, Tuple[str, str]], - modality_transforms_configs: Dict[ModalityEnum, TransformConfig], + modality_key_mapping: Dict[str, Tuple[str, str]], + modality_transforms: Dict[str, Transform], num_samples: int, ): """A multimodal dataset instance for the WebDataset. @@ -301,13 +328,13 @@ def __init__( urls: A webdataset url. For example: "/data/path/{00000..00012.tar". modality_key_mapping: Mapping from dataset keys to keys expected by the forward pass of the model. For example: {ModalityEnum.IMAGE: ("jpg", "image"), ModalityEnum.TEXT: ("text", "caption")}} - modality_transforms_configs: The transform config for each modality. + modality_transforms: The transforms for each modality. num_samples: The number of samples for each modality combination. 
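A standalone sketch of what the two transform wrappers above boil down to, calling timm's create_transform and a Hugging Face tokenizer directly instead of going through the component registry; the model name and sizes are taken from the example config.

    from PIL import Image
    from timm.data import create_transform
    from transformers import AutoTokenizer

    image_transform = create_transform(input_size=256, is_training=True)
    image_tensor = image_transform(Image.new("RGB", (512, 384)))
    print(image_tensor.shape)  # torch.Size([3, 256, 256])

    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    encoding = tokenizer(
        "a photo of a cat",
        max_length=64,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
    )
    print(len(encoding.input_ids), sum(encoding.attention_mask))  # 64 padded ids, a handful of real tokens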
""" self.urls = urls self.modality_key_mapping = modality_key_mapping - self.modality_transforms_configs = modality_transforms_configs - assert self.modality_key_mapping.keys() == self.modality_transforms_configs.keys() + self.modality_transforms = modality_transforms + assert self.modality_key_mapping.keys() == self.modality_transforms.keys() self.modalities = list(self.modality_key_mapping.keys()) self.num_samples = num_samples self.web_dataset = None @@ -321,14 +348,7 @@ def __init__( } self.additional_extreacted_keys = [] - - # Some transforms require objects such as image - if ModalityEnum.IMAGE in self.modality_transforms_configs: - self._timm_image_transform = create_transform( - **self.modality_transforms_configs[ModalityEnum.IMAGE].model_dump() - ) - - if ModalityEnum.TEXT in self.modality_transforms_configs: + if ModalityEnum.TEXT in self.modality_transforms: self.additional_extreacted_keys.append("attention_mask") # Mapping between modality and transform @@ -344,7 +364,7 @@ def prepare( ): self.web_dataset = wds.WebDataset( urls=self.urls, - nodesplitter=dummy_nodesplitter if not resample else None, + nodesplitter=self.dummy_nodesplitter if not resample else None, shardshuffle=shardshuffle, repeat=repeat, handler=wds.ignore_and_continue, @@ -375,14 +395,8 @@ def prepare( def _transform_text(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.TEXT] - config: TextTransformConfig = self.modality_transforms_configs[ModalityEnum.TEXT] - batch_encoding: BatchEncoding = config.tokenizer.tokenizer( - sample[source_key], - max_length=config.max_length, - padding=config.padding, - truncation=config.truncation, - return_attention_mask=config.return_attention_mask, - ) + transform: TextTransform = self.modality_transforms[ModalityEnum.TEXT] + batch_encoding: BatchEncoding = transform(sample[source_key]) del sample[source_key] sample[target_key] = batch_encoding.input_ids sample["attention_mask"] = batch_encoding.attention_mask @@ -390,7 +404,8 @@ def _transform_text(self, sample): def _transform_image(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.IMAGE] - sample[target_key] = self._timm_image_transform(sample[source_key]) + transform: TextTransform = self.modality_transforms[ModalityEnum.IMAGE] + sample[target_key] = transform(sample[source_key]) del sample[source_key] return sample @@ -422,8 +437,33 @@ def _select_keys(self, sample): new_sample[k] = v return new_sample + @staticmethod + def dummy_nodesplitter(src, group=None): + # This node splitter is not actually splitting the data over the nodes + # but keeps the complete dataset on each node. + # This is required so that each node has the same amount of data. + # In the case of 25 shards and 16 ranks for example 7 ranks are + # without data in the second iteration. This will cause a crash once all_gather is called. + # This is only relevant for validation. 
+ yield from src -class MultimodalWebDataset(IterableDataset): + +PydanticMultimodalWebDatasetBuilderIFType = Annotated[ + MultimodalWebDatasetBuilder, PydanticThirdPartyTypeIF(MultimodalWebDatasetBuilder) +] + + +class MultimodalWebDatasetConfig(BaseModel): + builders: List[PydanticMultimodalWebDatasetBuilderIFType] + mixing_ratios: Optional[List[int]] = None + shardshuffle: int = 100 + repeat: bool = False + resample: bool = True + shuffle_buffer: Optional[int] = 10_000 + + +# @register_component("dataset", "web_dataset", MultimodalWebDatasetConfig) +class MultimodalWebDataset(wds.DataPipeline, wds.compat.FluidInterface): def __init__( self, builders: List[MultimodalWebDatasetBuilder], @@ -446,6 +486,7 @@ def __init__( efficient. Defaults to True. shuffle_buffer: Number of samples that should be used for shuffling. Defaults to 10_000. """ + super().__init__() self.builders = builders self.output_keys_by_modality = {} @@ -471,17 +512,13 @@ def __init__( self.mixing_ratios = [uniform_ratio for _ in self.builders] assert len(self.mixing_ratios) == len(self.builders) - def __iter__(self): if len(self.builders) > 1: datasets = [] for b in self.builders: datasets.append(b.web_dataset) dataset = wds.RandomMix(datasets, self.mixing_ratios) # Apply mixing at sample level - return iter(dataset) - - dataset = self.builders[0].web_dataset - return iter(dataset) + self.pipeline.extend(dataset.pipeline) + else: + self.pipeline.extend(self.builders[0].web_dataset.pipeline) - def __len__(self): - total_num_samples = sum([b.num_samples for b in self.builders]) - return total_num_samples + self.with_length(sum([b.num_samples for b in self.builders])) diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py index 261765c0c..1d31e27cf 100644 --- a/src/modalities/dataloader/dataset_factory.py +++ b/src/modalities/dataloader/dataset_factory.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple from pydantic import FilePath from torch.utils.data.dataset import Dataset @@ -8,11 +8,9 @@ from modalities.dataloader.dataset import ( DummyDataset, DummySampleConfig, - ImageTransformConfig, MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron, - WebDataset, ) from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset @@ -92,36 +90,3 @@ def get_open_gptx_mmap_dataset( # TODO: Fix the OpenGPTX implementation and get rid of this hack. 
dataset_wrapped = OpenGPTXDatasetWrapper(open_gptx_dataset=dataset, num_samples=num_samples) return dataset_wrapped - - @staticmethod - def get_web_dataset( - urls: Union[List[str], str], - source_image_key: str, - image_key: str, - source_text_key: str, - text_key: str, - tokenizer: PreTrainedTokenizer, - block_size: int, - num_samples: int, - image_transform_config: Optional[ImageTransformConfig] = None, - shardshuffle: Optional[int] = None, - repeat: bool = False, - resample: bool = False, - shuffle: int = 0, - ) -> WebDataset: - dataset = WebDataset( - urls=urls, - source_image_key=source_image_key, - image_key=image_key, - source_text_key=source_text_key, - text_key=text_key, - tokenizer=tokenizer, - block_size=block_size, - num_samples=num_samples, - image_transform_config=image_transform_config, - shardshuffle=shardshuffle, - repeat=repeat, - resample=resample, - shuffle=shuffle, - ) - return dataset diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 739bf6506..5dab70737 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -53,7 +53,17 @@ WebLoaderConfig, ) from modalities.dataloader.dataloader_factory import DataloaderFactory -from modalities.dataloader.dataset import DummyDatasetConfig, WebDatasetConfig +from modalities.dataloader.dataset import ( + DummyDatasetConfig, + ImageTransform, + ImageTransformConfig, + MultimodalWebDataset, + MultimodalWebDatasetBuilder, + MultimodalWebDatasetBuilderConfig, + MultimodalWebDatasetConfig, + TextTransform, + TextTransformConfig, +) from modalities.dataloader.dataset_factory import DatasetFactory from modalities.logging_broker.subscriber_impl.subscriber_factory import ( ProgressSubscriberFactory, @@ -155,7 +165,11 @@ class ComponentEntity: "dataset", "open_gptx_mmap_dataset", DatasetFactory.get_open_gptx_mmap_dataset, OpenGPTXMMapDatasetConfig ), ComponentEntity("dataset", "dummy_dataset", DatasetFactory.get_dummy_dataset, DummyDatasetConfig), - ComponentEntity("dataset", "web_dataset", DatasetFactory.get_web_dataset, WebDatasetConfig), + ComponentEntity("dataset", "web_dataset", MultimodalWebDataset, MultimodalWebDatasetConfig), + ComponentEntity("dataset", "web_dataset_builder", MultimodalWebDatasetBuilder, MultimodalWebDatasetBuilderConfig), + # Data transforms & augmentations + ComponentEntity("transform", "text_transform", TextTransform, TextTransformConfig), + ComponentEntity("transform", "image_transform", ImageTransform, ImageTransformConfig), # samplers ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), # batch samplers From 1c6fcffb29dff0653f8ba21f12a81f72333ff931 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 11 Jun 2024 11:23:52 +0200 Subject: [PATCH 060/161] chore: add todo statement --- src/modalities/dataloader/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 722d224e5..b1a83ffb1 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -488,6 +488,7 @@ def __init__( """ super().__init__() self.builders = builders + assert len(builders) == 1, "Multiple dataset builders are not supported yet" # TODO self.output_keys_by_modality = {} for b in builders: From 42f469b144300eb709e9cddc4004b97811a0fe71 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 11 Jun 2024 11:52:24 +0200 Subject: [PATCH 061/161] chore: update start script --- start.sh | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/start.sh b/start.sh index 9878368a0..236ff7e68 100644 --- a/start.sh +++ b/start.sh @@ -1,3 +1,3 @@ #!/bin/bash -CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 1 --nproc_per_node 1 --rdzv-endpoint=0.0.0.0:29502 src/modalities/__main__.py run --config_file_path config_files/training/config_example_coca.yaml \ No newline at end of file +CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 1 --nproc_per_node 1 --rdzv-endpoint=0.0.0.0:29502 src/modalities/__main__.py run --config_file_path config_files/training/config_example_coca_webdataset.yaml \ No newline at end of file From 78efe14ecb43776c967e3b4d75d98bd73d524c19 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 11 Jun 2024 11:57:39 +0200 Subject: [PATCH 062/161] fix: coca webdataset config --- .../training/config_example_coca_webdataset.yaml | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index fdf259e3d..4161d4fec 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -79,10 +79,10 @@ train_dataset: builders: - instance_key: train_coco_dataset_builder pass_type: BY_REFERENCE - shardshuffle: 1000 + shardshuffle: 100 repeat: true resample: true - shuffle: 1_000 + shuffle_buffer: 10_000 val_dataset: component_key: dataset @@ -94,7 +94,7 @@ val_dataset: shardshuffle: 1000 repeat: true resample: true - shuffle: 1_000 + shuffle_buffer: 10_000 train_dataloader: component_key: data_loader @@ -148,9 +148,6 @@ checkpoint_saving: checkpoint_path: ${settings.paths.checkpointing_path} global_rank: ${settings.cuda_env.global_rank} experiment_id: ${settings.experiment_id} - mixed_precision_settings: FP_16 - sharding_strategy: HYBRID_SHARD - block_names: [TransformerBlock, VisionTransformerBlock] captioning_loss: component_key: loss @@ -276,7 +273,6 @@ batch_progress_subscriber: variant_key: rich config: local_rank: ${settings.cuda_env.global_rank} - world_size: ${settings.cuda_env.world_size} global_num_seen_steps: ${settings.training.global_num_seen_steps} train_dataloader: instance_key: train_dataloader From deb07886ee552b3c04097711db19e43a902f08b1 Mon Sep 17 00:00:00 2001 From: Sogol Haghighat Date: Fri, 7 Jun 2024 18:30:36 +0200 Subject: [PATCH 063/161] feat: extend vision transformer model to video data --- .../vision_transformer_model.py | 144 ++++++++++++++++-- 1 file changed, 133 insertions(+), 11 deletions(-) diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py index 78cbbb7af..1611d6ad5 100644 --- a/src/modalities/models/vision_transformer/vision_transformer_model.py +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -25,6 +25,8 @@ class VisionTransformerConfig(BaseModel): n_img_channels: Annotated[int, Field(ge=1)] = 3 add_cls_token: bool = True bias: bool = True + num_video_frames: Annotated[int, Field(ge=0)] = 1 # TODO: read this from dataloader/train config + n_latents: Annotated[int, Field(ge=1)] = 64 class ImagePatchEmbedding(nn.Module): @@ -61,6 +63,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class VideoPatchEmbedding(nn.Module): + def __init__( + self, + n_img_channels: int = 3, + n_embd: int = 768, + patch_size: int = 16, + patch_stride: int = 16, + add_cls_token: bool = True, + ) -> None: + super().__init__() + 
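            # A temporal kernel and stride of 1 patchify each frame independently;
            # nothing is mixed across the time axis at this stage.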
self.conv = nn.Conv3d( + in_channels=n_img_channels, + out_channels=n_embd, + kernel_size=(1, patch_size, patch_size), + stride=(1, patch_size, patch_stride), + ) # TODO: check the 3D conv again + + # See https://github.com/arogozhnikov/einops/wiki/Using-torch.compile-with-einops + self.rearrange = Rearrange("b c T h w -> b T (h w) c") # TODO: this might change when implementing dataloader + + self.cls_token = None + if add_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, n_embd)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.rearrange(x) + B, T = x.shape[:2] + if self.cls_token is not None: # TODO: remove cls token at the frame level + x = torch.cat([self.cls_token.repeat(B, T, 1, 1), x], dim=2) + return x # [b T S D] + + class VisionTransformerBlock(nn.Module): def __init__( self, @@ -88,6 +123,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +# TODO: extend to all modalities based on the original paper (https://arxiv.org/pdf/2103.03206)! +# TODO: extend this to work with video and images! +class PerceiverTransformerBlock(nn.Module): + """Perceiver Resampler + + This is a transformer based architecture that performs cross and self attention to compress and embed video inputs. + paper: 'Flamingo: a Visual Language Model for Few-Shot Learning' + Link: https://github.com/mlfoundations/open_flamingo + """ + + def __init__( + self, + n_embd: int = 768, + n_head: int = 8, + ffn_hidden: int = 3072, + bias: bool = True, + dropout: float = 0.0, + attention_config: AttentionConfig = None, + ) -> None: + super().__init__() + self.norm_latents = nn.LayerNorm(n_embd) + self.norm = nn.LayerNorm(n_embd) + self.attention = MultiHeadAttention( + n_embd=n_embd, + n_head=n_head, + attention_config=attention_config, + attention_type=AttentionType.CROSS_ATTENTION, + ) + self.mlp = MLP(in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout) + + def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + latents = self.norm_latents(latents) + x = self.norm(x) + context = torch.cat((x, latents), dim=-2) # video features and the latent together + latents = latents + self.attention(latents, context=context) + latents = latents + self.mlp(latents) + return latents + + class VisionTransformer(nn.Module): """ViT @@ -96,6 +170,8 @@ class VisionTransformer(nn.Module): Paper: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` Link: https://arxiv.org/abs/2010.11929 + + This architecture is extended to encode videos using a perceiver resampler transformer model """ def __init__( @@ -115,19 +191,40 @@ def __init__( n_img_channels: int = 3, add_cls_token: bool = True, bias: bool = True, + num_video_frames: int = 1, # when dealing with video this is bigger than 1 + n_latents: int = 64, ) -> None: super().__init__() self.sample_key = sample_key self.prediction_key = prediction_key self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) self.block_size = self._calculate_block_size(self.img_size, patch_size, patch_stride, add_cls_token) - - self.embedding_fn = ImagePatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride, add_cls_token) - self.positional_embedding_fn = nn.Embedding(num_embeddings=self.block_size, embedding_dim=n_embd) self.dropout = nn.Dropout(dropout) + + self.head = None + if n_classes is not None: + self.norm = nn.LayerNorm(n_embd) + self.head = nn.Linear(in_features=n_embd, out_features=n_classes, bias=bias) + + self.vision_input = "Image" + 
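        # With num_video_frames > 1 the encoder switches into video mode: per-frame patch
        # embeddings plus a learned temporal embedding are compressed into n_latents tokens
        # by the Perceiver-style blocks; otherwise the standard image path is used.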
if num_video_frames > 1: # video data + self.vision_input = "Video" + self.embedding_fn = VideoPatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride) # [b T S D] + self.time_embd = nn.Parameter(torch.randn(num_video_frames, 1, n_embd)) # [T,1,d] + if add_cls_token: + # self.block_size -= 1 # to remove cls token at frame level + n_latents += 1 # to count for a video level cls token + self.latents = nn.Parameter(torch.randn(n_latents, n_embd)) # [R,d] + self.rearrange = Rearrange("b T S D -> b (T S) D") + else: + self.embedding_fn = ImagePatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride, add_cls_token) + + self.positional_embedding_fn = nn.Embedding(num_embeddings=self.block_size, embedding_dim=n_embd) # [S D] + block_classes = {"Video": PerceiverTransformerBlock, "Image": VisionTransformerBlock} + self.blocks = nn.ModuleList( [ - VisionTransformerBlock( + block_classes[self.vision_input]( n_embd=n_embd, n_head=n_head, ffn_hidden=ffn_hidden, @@ -139,11 +236,6 @@ def __init__( ] ) - self.head = None - if n_classes is not None: - self.norm = nn.LayerNorm(n_embd) - self.head = nn.Linear(in_features=n_embd, out_features=n_classes, bias=bias) - def forward_images(self, x: torch.Tensor) -> torch.Tensor: x = self.embedding_fn(x) x = self.dropout(x + self.positional_embedding_fn.weight) @@ -151,9 +243,39 @@ def forward_images(self, x: torch.Tensor) -> torch.Tensor: x = block(x) return x - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward_videos(self, x: torch.Tensor) -> torch.Tensor: + """Encode video data into a shorter sequence of tokens + + Args: + x (torch.Tensor): images from multiple video frames + shape (b c T h w) + b: batch size + T: temporal dim + h,w: spatial dims (S=h*w) + c: embedding dim (D) + + Returns: + torch.Tensor: latents + shape (b R D) R << T*S + """ + x = self.embedding_fn(x) # [b T S D] + b, T = x.shape[:2] + # TODO: check this! 
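        # Shape note: positional_embedding_fn.weight is [S, D] and broadcasts over batch and
        # time, while time_embd is repeated to [b, T, 1, D] and broadcasts over the spatial
        # positions, so every patch token receives both a spatial and a temporal offset.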
+ x = self.dropout(x + self.positional_embedding_fn.weight) + x = self.dropout(x + self.time_embd.repeat(b, 1, 1, 1)) + x = self.rearrange(x) # [b T*S D] + latents = self.latents.repeat(b, 1, 1) # [b,R,d] with R< Dict[str, torch.Tensor]: # TODO video adapt + # TODO: add video_sample_key and video_prediction_key x = inputs[self.sample_key] - x = self.forward_images(x) + if self.vision_input == "Video": + x = self.forward_videos(x) + else: + x = self.forward_images(x) if self.head: if self.embedding_fn.cls_token is not None: x = x[:, 0] From 4a9695d7a06a8ad66bb1d6fea400998f224dc6b0 Mon Sep 17 00:00:00 2001 From: Sogol Haghighat Date: Fri, 7 Jun 2024 18:32:20 +0200 Subject: [PATCH 064/161] test: extend vision transformer test to video data --- .../test_vision_transformer.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/models/vision_transformer/test_vision_transformer.py b/tests/models/vision_transformer/test_vision_transformer.py index 24b03921a..bb68a7f03 100644 --- a/tests/models/vision_transformer/test_vision_transformer.py +++ b/tests/models/vision_transformer/test_vision_transformer.py @@ -33,6 +33,31 @@ def test_vision_transformer(): assert "logits" in out assert out["logits"].shape == (1, 1000) + # Test for video input + # Create model + config_file_path2 = _ROOT_DIR / Path("tests/models/vision_transformer/vision_transformer_config2.yaml") + config_dict2 = load_app_config_dict(config_file_path=config_file_path2) + config2 = VisionTransformerConfig.model_validate(config_dict2) + model2 = VisionTransformer(**dict(config2)) + + # Create dummy inputs + dummy_input_video = torch.randn(1, 3, 16, 224, 224) # [b c T h w] + dummy_input2 = dict(videos=dummy_input_video) + + # Create optimizer + optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.001, momentum=0.9) + + # Run one training step + optimizer2.zero_grad() + out2 = model2(dummy_input2) + loss2 = out2["logits"].sum() + loss2.backward() + optimizer2.step() + + # Test outputs + assert "logits" in out2 + assert out2["logits"].shape == (1, 1000) + @pytest.mark.parametrize( "img_size,patch_size,patch_stride,add_cls_token,target_block_size", From f04243a5803aed8bc673e3c6a5a86ac4d8e0de9b Mon Sep 17 00:00:00 2001 From: Sogol Haghighat Date: Fri, 7 Jun 2024 18:33:32 +0200 Subject: [PATCH 065/161] test: add and update config for testing vision transformer with image and video --- .../vision_transformer_config.yaml | 2 ++ .../vision_transformer_config2.yaml | 15 +++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 tests/models/vision_transformer/vision_transformer_config2.yaml diff --git a/tests/models/vision_transformer/vision_transformer_config.yaml b/tests/models/vision_transformer/vision_transformer_config.yaml index d6657c5c1..507719791 100644 --- a/tests/models/vision_transformer/vision_transformer_config.yaml +++ b/tests/models/vision_transformer/vision_transformer_config.yaml @@ -11,3 +11,5 @@ patch_stride: 16 n_img_channels: 3 add_cls_token: True bias: True +num_video_frames: 1 +n_latents: 64 diff --git a/tests/models/vision_transformer/vision_transformer_config2.yaml b/tests/models/vision_transformer/vision_transformer_config2.yaml new file mode 100644 index 000000000..7951ec737 --- /dev/null +++ b/tests/models/vision_transformer/vision_transformer_config2.yaml @@ -0,0 +1,15 @@ +sample_key: videos +prediction_key: logits +img_size: 224 +n_classes: 1000 +n_layer: 6 +n_head: 8 +n_embd: 768 +dropout: 0.0 +patch_size: 16 +patch_stride: 16 +n_img_channels: 3 +add_cls_token: True +bias: True 
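# num_video_frames > 1 selects the video path (VideoPatchEmbedding + Perceiver blocks);
# n_latents sets the length of the compressed latent sequence.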
+num_video_frames: 16 +n_latents: 64 From cad8d11d4b7f5ab2e94aaf8b8dc6db3e25cfe7a4 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 11 Jun 2024 17:21:24 +0200 Subject: [PATCH 066/161] feat: add video transforms --- src/modalities/dataloader/dataset.py | 51 ++++++++++++++++++++++++--- src/modalities/registry/components.py | 3 ++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index b1a83ffb1..515c29139 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -1,16 +1,19 @@ from __future__ import annotations +import random from enum import Enum from pathlib import Path from typing import Annotated, Dict, List, Optional, Tuple, Union import jq import numpy as np +import torch import webdataset as wds from pydantic import BaseModel, Field from timm.data import create_transform from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torch.utils.data.dataset import Dataset as TorchdataSet +from torchvision import transforms from tqdm import tqdm from transformers import BatchEncoding @@ -267,8 +270,8 @@ class ImageTransform(Transform): def __init__(self, **kwargs): self._timm_image_transform = create_transform(**kwargs) - def __call__(self, *args, **kwargs): - return self._timm_image_transform(*args, **kwargs) + def __call__(self, image): + return self._timm_image_transform(image) class TextTransformConfig(TransformConfig): @@ -306,6 +309,45 @@ def __call__(self, text): return batch_encoding +class RandomTemporalCrop: + def __init__(self, num_frames): + self.num_frames = num_frames + + def __call__(self, video): + total_frames = len(video) + start = random.randint(0, total_frames - self.num_frames) + return video[start : start + self.num_frames] + + +class VideoTransformConfig(TransformConfig): + input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224 + is_training: bool = False + num_frames: int = 16 + + +class VideoTransform(Transform): + def __init__( + self, + input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224, + is_training: bool = False, + num_frames: int = 16, + ): + self.spatial_transform = transforms.Compose( + [ + transforms.RandomResizedCrop(input_size), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + self.temporal_transform = RandomTemporalCrop(num_frames=16) + + def __call__(self, video): + video = self.temporal_transform(video) + return torch.stack([self.spatial_transform(frame) for frame in video]) + + class MultimodalWebDatasetBuilderConfig(BaseModel): urls: Union[List[str], str] modality_key_mapping: Dict[ModalityEnum, Tuple[str, str]] @@ -411,9 +453,8 @@ def _transform_image(self, sample): def _transform_video(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.VIDEO] - # config: VideoTransformConfig = self.modality_transforms_configs[ModalityEnum.VIDEO] - # TODO add video transform - sample[target_key] = sample[source_key] + transform: VideoTransform = self.modality_transforms[ModalityEnum.VIDEO] + sample[target_key] = transform(sample[source_key]) del sample[source_key] return sample diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index a53388956..6acfcceea 100644 --- a/src/modalities/registry/components.py +++ 
b/src/modalities/registry/components.py @@ -63,6 +63,8 @@ MultimodalWebDatasetConfig, TextTransform, TextTransformConfig, + VideoTransform, + VideoTransformConfig, ) from modalities.dataloader.dataset_factory import DatasetFactory from modalities.logging_broker.subscriber_impl.subscriber_factory import ( @@ -173,6 +175,7 @@ class ComponentEntity: # Data transforms & augmentations ComponentEntity("transform", "text_transform", TextTransform, TextTransformConfig), ComponentEntity("transform", "image_transform", ImageTransform, ImageTransformConfig), + ComponentEntity("transform", "video_transform", VideoTransform, VideoTransformConfig), # samplers ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), # batch samplers From 36539226346db2061ae705a6b8bbad5675d27e08 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 11 Jun 2024 17:21:37 +0200 Subject: [PATCH 067/161] chore: add video config --- .../config_example_video_coca_webdataset.yaml | 296 ++++++++++++++++++ 1 file changed, 296 insertions(+) create mode 100644 config_files/training/config_example_video_coca_webdataset.yaml diff --git a/config_files/training/config_example_video_coca_webdataset.yaml b/config_files/training/config_example_video_coca_webdataset.yaml new file mode 100644 index 000000000..539d396ba --- /dev/null +++ b/config_files/training/config_example_video_coca_webdataset.yaml @@ -0,0 +1,296 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + global_training_log_interval_in_steps: 30 # Needs to be a multiple of gradient_acc_steps + global_checkpointing_interval_in_steps: 9_990 + global_evaluation_interval_in_steps: 4_980 + global_num_training_samples: 566748 # 491 steps with 8 gpus and global bs of 1152 + global_num_seen_steps: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 30 + local_train_micro_batch_size: 8 + sequence_length: 64 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +tokenizer: + component_key: tokenizer + variant_key: pretrained_hf_tokenizer + config: + pretrained_model_name_or_path: openai/clip-vit-base-patch32 + padding: true + max_length: ${settings.training.sequence_length} + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - images + - ${settings.referencing_keys.sample_key} + target_keys: [] + text_sample_key: ${settings.referencing_keys.sample_key} + text_target_key: ${settings.referencing_keys.target_key} + +train_video_transform: + component_key: transform + variant_key: video_transform + config: + is_training: True + input_size: 256 + num_frames: ${model.config.vision_encoder_config.num_video_frames} + +text_transform: + component_key: transform + variant_key: text_transform + config: + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + +train_video_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "/path/to/data/{000000..000009}.tar" + modality_key_mapping: + TEXT: ["caption", "input_ids"] + VIDEO: ["mp4", "videos"] + modality_transforms: + VIDEO: + instance_key: train_video_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 566_748 + +train_dataset: + component_key: dataset + variant_key: web_dataset + 
config: + builders: + - instance_key: train_video_builder + pass_type: BY_REFERENCE + shardshuffle: 100 + repeat: true + resample: true + shuffle_buffer: 10_000 + +val_dataset: + component_key: dataset + variant_key: web_dataset + config: + builders: + - instance_key: train_video_builder + pass_type: BY_REFERENCE + shardshuffle: 1000 + repeat: true + resample: true + shuffle_buffer: 10_000 + +train_dataloader: + component_key: data_loader + variant_key: web_loader + config: + num_workers: 8 + pin_memory: true + drop_last: true + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_size: ${settings.training.local_train_micro_batch_size} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: web_loader + config: + num_workers: 8 + pin_memory: true + drop_last: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_size: ${settings.training.local_train_micro_batch_size} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: fsdp + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +captioning_loss: + component_key: loss + variant_key: cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${model.config.prediction_key} + tag: captioning_loss + weight: 2.0 + +contrastive_loss: + component_key: loss + variant_key: clip_loss + config: + prediction_key1: ${model.config.vision_cls_prediction_key} + prediction_key2: ${model.config.text_cls_prediction_key} + logit_scale_key: ${model.config.logit_scale_prediction_key} + tag: contrastive_loss + weight: 1.0 + +loss_fn: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss + pass_type: BY_REFERENCE + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: FP_16 + sharding_strategy: HYBRID_SHARD + block_names: [TransformerBlock, PerceiverTransformerBlock] + +model: + component_key: model + variant_key: coca + config: + prediction_key: logits + vision_embd_prediction_key: vision_embeddings + text_embd_prediction_key: text_embeddings + vision_cls_prediction_key: vision_cls + text_cls_prediction_key: text_cls + logit_scale_prediction_key: logit_scale + vision_encoder_config: + sample_key: images + prediction_key: vision_embeddings + img_size: 256 # 288 in the original coca + n_classes: Null # Disable vision transformer head + n_layer: 12 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 0.0 + patch_size: 16 # 18 in the original coca + patch_stride: 16 # 18 in the original coca + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 + text_decoder_config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: 
${model.config.prediction_key} + block_size: 1024 + vocab_size: 50304 # 64k in the original coca + n_layer_text: 12 + n_layer_multimodal_text: 12 + attention_config: + attention_engine_type: default_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 + n_pool_head: 12 + n_vision_queries: 256 + bias_attn_pool: False + epsilon_attn_pool: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 + +scheduler: + component_key: scheduler + variant_key: cosine_annealing_with_warmup_lr # COCA uses linear decay + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + num_warmup_steps: 2_000 + num_training_steps: 500_000 + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 8e-4 + betas: [0.9, 0.999] + eps: 1e-8 + weight_decay: 0.01 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp + config: + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.global_rank} + global_num_seen_steps: ${settings.training.global_num_seen_steps} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.global_rank} + project: modalities + mode: ONLINE + experiment_id: ${settings.experiment_id} + directory: "." + config_file_path: ${settings.config_file_path} From 3b72cd2df4157d31c8736d7f7601bd33861a1daf Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 11 Jun 2024 18:17:08 +0200 Subject: [PATCH 068/161] fix: video coca --- .../config_example_video_coca_webdataset.yaml | 8 ++++---- src/modalities/dataloader/dataset.py | 9 +++++---- .../vision_transformer/vision_transformer_model.py | 14 ++++---------- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/config_files/training/config_example_video_coca_webdataset.yaml b/config_files/training/config_example_video_coca_webdataset.yaml index 539d396ba..58d2225d6 100644 --- a/config_files/training/config_example_video_coca_webdataset.yaml +++ b/config_files/training/config_example_video_coca_webdataset.yaml @@ -34,7 +34,7 @@ collate_fn: variant_key: coca_collator config: sample_keys: - - images + - videos - ${settings.referencing_keys.sample_key} target_keys: [] text_sample_key: ${settings.referencing_keys.sample_key} @@ -60,9 +60,9 @@ train_video_builder: component_key: dataset variant_key: web_dataset_builder config: - urls: "/path/to/data/{000000..000009}.tar" + urls: "/nm-raid/video/Kinetics/kinetics-dataset/k400/dummy_wds/{000000..000010}.tar" modality_key_mapping: - TEXT: ["caption", "input_ids"] + TEXT: ["json", "input_ids"] VIDEO: ["mp4", "videos"] modality_transforms: VIDEO: @@ -198,7 +198,7 @@ model: text_cls_prediction_key: text_cls logit_scale_prediction_key: logit_scale vision_encoder_config: - sample_key: images + sample_key: videos prediction_key: vision_embeddings img_size: 256 # 288 in the original coca n_classes: Null # Disable vision transformer head diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 515c29139..b6513fa33 100644 --- a/src/modalities/dataloader/dataset.py +++ 
b/src/modalities/dataloader/dataset.py @@ -316,7 +316,7 @@ def __init__(self, num_frames): def __call__(self, video): total_frames = len(video) start = random.randint(0, total_frames - self.num_frames) - return video[start : start + self.num_frames] + return video[start : start + self.num_frames].permute(0, 3, 1, 2) # F C H W class VideoTransformConfig(TransformConfig): @@ -334,18 +334,19 @@ def __init__( ): self.spatial_transform = transforms.Compose( [ - transforms.RandomResizedCrop(input_size), + transforms.RandomResizedCrop(input_size, antialias=True), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), - transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] ) self.temporal_transform = RandomTemporalCrop(num_frames=16) def __call__(self, video): + video = video[0] video = self.temporal_transform(video) - return torch.stack([self.spatial_transform(frame) for frame in video]) + return self.spatial_transform(video) class MultimodalWebDatasetBuilderConfig(BaseModel): diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py index 1611d6ad5..de58a83ef 100644 --- a/src/modalities/models/vision_transformer/vision_transformer_model.py +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -70,9 +70,9 @@ def __init__( n_embd: int = 768, patch_size: int = 16, patch_stride: int = 16, - add_cls_token: bool = True, ) -> None: super().__init__() + self.input_rearrange = Rearrange("b T c h w -> b c T h w") self.conv = nn.Conv3d( in_channels=n_img_channels, out_channels=n_embd, @@ -83,16 +83,10 @@ def __init__( # See https://github.com/arogozhnikov/einops/wiki/Using-torch.compile-with-einops self.rearrange = Rearrange("b c T h w -> b T (h w) c") # TODO: this might change when implementing dataloader - self.cls_token = None - if add_cls_token: - self.cls_token = nn.Parameter(torch.zeros(1, 1, n_embd)) - def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.input_rearrange(x) x = self.conv(x) x = self.rearrange(x) - B, T = x.shape[:2] - if self.cls_token is not None: # TODO: remove cls token at the frame level - x = torch.cat([self.cls_token.repeat(B, T, 1, 1), x], dim=2) return x # [b T S D] @@ -196,6 +190,7 @@ def __init__( ) -> None: super().__init__() self.sample_key = sample_key + self.has_cls_token = add_cls_token self.prediction_key = prediction_key self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) self.block_size = self._calculate_block_size(self.img_size, patch_size, patch_stride, add_cls_token) @@ -212,7 +207,6 @@ def __init__( self.embedding_fn = VideoPatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride) # [b T S D] self.time_embd = nn.Parameter(torch.randn(num_video_frames, 1, n_embd)) # [T,1,d] if add_cls_token: - # self.block_size -= 1 # to remove cls token at frame level n_latents += 1 # to count for a video level cls token self.latents = nn.Parameter(torch.randn(n_latents, n_embd)) # [R,d] self.rearrange = Rearrange("b T S D -> b (T S) D") @@ -277,7 +271,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: else: x = self.forward_images(x) if self.head: - if self.embedding_fn.cls_token is not None: + if self.has_cls_token: x = x[:, 0] else: x = x.mean(dim=1) From 0c0317aa02c4f1e62c8d61b7173532d02c1635cb Mon Sep 17 00:00:00 2001 From: 
manasMauryax Date: Wed, 27 Mar 2024 15:38:46 +0000 Subject: [PATCH 069/161] feat: add conformer audio encoder --- pyproject.toml | 1 + .../models/audio_transformer/__init__.py | 0 .../audio_transformer_model.py | 107 ++++++++++++++++++ tests/models/audio_transformer/__init__.py | 0 .../test_audio_transformer_model.py | 77 +++++++++++++ 5 files changed, 185 insertions(+) create mode 100644 src/modalities/models/audio_transformer/__init__.py create mode 100644 src/modalities/models/audio_transformer/audio_transformer_model.py create mode 100644 tests/models/audio_transformer/__init__.py create mode 100644 tests/models/audio_transformer/test_audio_transformer_model.py diff --git a/pyproject.toml b/pyproject.toml index 6094d5256..2acb8d5d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ requires-python = ">=3.8,<3.12" description = "Modalities, a python framework for distributed and reproducible foundation model training." dependencies = [ "torch>=2.0", +"torchaudio", "tqdm", "pyyaml", "transformers", diff --git a/src/modalities/models/audio_transformer/__init__.py b/src/modalities/models/audio_transformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py new file mode 100644 index 000000000..f9142c356 --- /dev/null +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -0,0 +1,107 @@ +from typing import Annotated, Dict + +import torch +from pydantic import BaseModel, Field +from torch import nn +from torchaudio.models import Conformer + + +class AudioTransformerConfig(BaseModel): + sample_key: str + prediction_key: str + input_dims: Annotated[int, Field(ge=1)] + pre_conformer_dropout: Annotated[float, Field(lt=1.0)] + conformer_dropout: Annotated[float, Field(lt=1.0)] + n_heads: Annotated[int, Field(ge=1)] + n_embd: Annotated[int, Field(ge=1)] + n_layers: Annotated[int, Field(ge=1)] + depthwise_conv_kernel_size: Annotated[int, Field(ge=1)] + + +class PreConformer(nn.Module): + def __init__( + self, + *, + n_input_dims: int, + dropout: float, + ): + super().__init__() + self.subsampler = nn.Sequential( + nn.Conv1d( + in_channels=n_input_dims, + out_channels=n_input_dims, + kernel_size=2, + stride=2, + ), + nn.Conv1d( + in_channels=n_input_dims, + out_channels=n_input_dims, + kernel_size=2, + stride=2, + ), + ) + self.linear = nn.Linear(n_input_dims, n_input_dims) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + # x.shape: batch_size, n_input_dims, n_input_frames + + x = self.subsampler(x) # x.shape: batch_size, n_input_dims, ceil(n_input_frames / 4) + x = x.transpose(1, 2) + x = self.linear(x) # x.shape: batch_size, ceil(n_input_frames / 4), n_input_dims + x = self.dropout(x) + return x + + +class AudioTransformer(nn.Module): + def __init__( + self, + *, + sample_key: str, + prediction_key: str, + input_dims: int, + n_heads: int, + n_embd: int, + n_layers: int, + depthwise_conv_kernel_size: int, + pre_conformer_dropout: float, + conformer_dropout: float, + ): + super().__init__() + self.sample_key = sample_key + self.prediction_key = prediction_key + self.pre_conformer = PreConformer( + n_input_dims=input_dims, + dropout=pre_conformer_dropout, + ) + + self.conformer = Conformer( + input_dim=input_dims, + num_heads=n_heads, + ffn_dim=n_embd, + num_layers=n_layers, + depthwise_conv_kernel_size=depthwise_conv_kernel_size, + 
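            # torchaudio's Conformer consumes (batch, frames, input_dim) together with a
            # per-example lengths tensor, matching the subsampled output of PreConformer.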
dropout=conformer_dropout, + ) + + self.post_conformer = nn.Sequential( + nn.Linear( + input_dims, + n_embd, + ), + nn.LayerNorm(n_embd), + ) + + def forward( + self, + inputs: Dict[str, tuple[torch.Tensor, torch.Tensor]], + ) -> Dict[str, tuple[torch.Tensor, torch.Tensor]]: + x, x_length = inputs[self.sample_key] # x.shape: batch_size, n_input_dims, n_input_frames + x = self.pre_conformer(x) # x.shape: batch_size, ceil(n_input_frames / 4), n_input_dims + x, x_length = self.conformer(x, x_length) # x.shape: batch_size, ceil(n_input_frames / 4), n_input_dims + x = self.post_conformer(x) + return {self.prediction_key: (x, x_length)} diff --git a/tests/models/audio_transformer/__init__.py b/tests/models/audio_transformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/models/audio_transformer/test_audio_transformer_model.py b/tests/models/audio_transformer/test_audio_transformer_model.py new file mode 100644 index 000000000..173e8239a --- /dev/null +++ b/tests/models/audio_transformer/test_audio_transformer_model.py @@ -0,0 +1,77 @@ +import pytest +import torch + +from modalities.models.audio_transformer.audio_transformer_model import AudioTransformer + + +@pytest.fixture +def pre_conformer_config(): + return { + "input_dims": 80, + "dropout": 0.1, + } + + +@pytest.fixture +def audio_transformer_config(): + return { + "sample_key": "audio_feats", + "prediction_key": "audio_embeddings", + "n_heads": 4, + "n_embd": 512, + "n_layers": 2, + "depthwise_conv_kernel_size": 3, + "dropout": 0.1, + } + + +@pytest.fixture +def audio_transformer( + pre_conformer_config, + audio_transformer_config, +): + return AudioTransformer( + sample_key=audio_transformer_config["sample_key"], + prediction_key=audio_transformer_config["prediction_key"], + input_dims=pre_conformer_config["input_dims"], + n_heads=audio_transformer_config["n_heads"], + n_embd=audio_transformer_config["n_embd"], + n_layers=audio_transformer_config["n_layers"], + depthwise_conv_kernel_size=audio_transformer_config["depthwise_conv_kernel_size"], + pre_conformer_dropout=pre_conformer_config["dropout"], + conformer_dropout=audio_transformer_config["dropout"], + ) + + +@pytest.fixture +def dummy_input_div4(): + return {"audio_feats": (torch.randn(4, 80, 1000), torch.Tensor([1000 / 4] * 4))} + + +@pytest.fixture +def dummy_input_notdiv4(): + return {"audio_feats": (torch.randn(4, 80, 750), torch.Tensor([750 // 4] * 4))} + + +def test_audio_transformer_output_shape_div4( + dummy_input_div4, + audio_transformer, + audio_transformer_config, +): + output = audio_transformer(dummy_input_div4) + audio_embeddings, audio_lengths = output[audio_transformer_config["prediction_key"]] + assert audio_embeddings.shape[0] == 4 + assert audio_embeddings.shape[1] == 1000 / 4 + assert audio_embeddings.shape[2] == 512 + + +def test_audio_transformer_output_shape_notdiv4( + dummy_input_notdiv4, + audio_transformer, + audio_transformer_config, +): + output = audio_transformer(dummy_input_notdiv4) + audio_embeddings, audio_lengths = output[audio_transformer_config["prediction_key"]] + assert audio_embeddings.shape[0] == 4 + assert audio_embeddings.shape[1] == 750 // 4 + assert audio_embeddings.shape[2] == 512 From 5f63246bff68c6265e87efd98a3cbbfdbfa3ba3b Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Wed, 27 Mar 2024 15:41:52 +0000 Subject: [PATCH 070/161] feat: make CoCa audio compatible --- src/modalities/models/coca/coca_model.py | 111 +++++++++++++----- tests/models/coca/coca_config_audio.yaml | 40 +++++++ 
tests/models/coca/coca_config_av.yaml | 57 +++++++++ ...ca_config.yaml => coca_config_vision.yaml} | 10 +- tests/models/coca/test_coca.py | 79 +++++++++++-- 5 files changed, 254 insertions(+), 43 deletions(-) create mode 100644 tests/models/coca/coca_config_audio.yaml create mode 100644 tests/models/coca/coca_config_av.yaml rename tests/models/coca/{coca_config.yaml => coca_config_vision.yaml} (81%) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 8a33055a4..09df2bd50 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from torch import nn +from modalities.models.audio_transformer.audio_transformer_model import AudioTransformer, AudioTransformerConfig from modalities.models.coca.attention_pooling import AttentionPooling from modalities.models.coca.multi_modal_decoder import MultiModalTextDecoder from modalities.models.coca.text_decoder import TextDecoder @@ -18,6 +19,11 @@ from modalities.nn.attention import AttentionConfig +class AVConfig(BaseModel): + audio_transformer_config: AudioTransformerConfig + vision_transformer_config: VisionTransformerConfig + + class TextDecoderConfig(BaseModel): sample_key: str prediction_key: str @@ -37,15 +43,17 @@ class TextDecoderConfig(BaseModel): class CoCaConfig(BaseModel): prediction_key: str = "logits" - vision_embd_prediction_key: str # same key as vision encoder + modality_key: str = "modality" + modality_embd_prediction_key: str text_embd_prediction_key: str - vision_cls_prediction_key: str + modality_cls_prediction_key: str text_cls_prediction_key: str logit_scale_prediction_key: str - vision_encoder_config: VisionTransformerConfig + modality_encoder_config: AudioTransformerConfig | VisionTransformerConfig | AVConfig text_decoder_config: TextDecoderConfig n_pool_head: Annotated[int, Field(ge=1)] - n_vision_queries: Annotated[int, Field(ge=1)] + n_vision_queries: Annotated[int, Field(ge=1)] | None + n_audio_queries: Annotated[int, Field(ge=1)] | None bias_attn_pool: bool epsilon_attn_pool: Annotated[float, Field(ge=0.0)] weight_init: WeightInitializationConfig @@ -64,28 +72,64 @@ class CoCa(NNModel): def __init__( self, prediction_key: str, - vision_cls_prediction_key: str, - text_cls_prediction_key: str, - vision_embd_prediction_key: str, + modality_key: str, + modality_embd_prediction_key: str, text_embd_prediction_key: str, logit_scale_prediction_key: str, + modality_cls_prediction_key: str, + text_cls_prediction_key: str, n_vision_queries: int, + n_audio_queries: int, n_pool_head: int, bias_attn_pool: bool, epsilon_attn_pool: float, - vision_encoder_config: VisionTransformerConfig, + modality_encoder_config: VisionTransformerConfig | AudioTransformerConfig | AVConfig, text_decoder_config: TextDecoderConfig, weight_init: WeightInitializationConfig, ) -> None: super().__init__() + + self.AUDIO = 0 + self.VISION = 1 + self.prediction_key = prediction_key - self.vision_cls_prediction_key = vision_cls_prediction_key - self.text_cls_prediction_key = text_cls_prediction_key - self.vision_embd_prediction_key = vision_embd_prediction_key + self.modality_key = modality_key + self.modality_embd_prediction_key = modality_embd_prediction_key self.text_embd_prediction_key = text_embd_prediction_key self.logit_scale_prediction_key = logit_scale_prediction_key - self.vision_encoder = VisionTransformer(**dict(vision_encoder_config)) + self.modality_cls_prediction_key = modality_cls_prediction_key + 
self.text_cls_prediction_key = text_cls_prediction_key + + self.n_pool_head = n_pool_head + self.bias_attn_pool = bias_attn_pool + self.epsilon_attn_pool = epsilon_attn_pool + self.text_decoder_config = text_decoder_config + + if isinstance(modality_encoder_config, VisionTransformerConfig): + self.vision_encoder, self.vision_queries, self.vision_attn_pool = self._init_modality( + VisionTransformer, + modality_encoder_config, + n_vision_queries, + ) + elif isinstance(modality_encoder_config, AudioTransformerConfig): + self.audio_encoder, self.audio_queries, self.audio_attn_pool = self._init_modality( + AudioTransformer, + modality_encoder_config, + n_audio_queries, + ) + else: + self.vision_encoder, self.vision_queries, self.vision_attn_pool = self._init_modality( + VisionTransformer, + modality_encoder_config.vision_transformer_config, + n_vision_queries, + ) + self.audio_encoder, self.audio_queries, self.audio_attn_pool = self._init_modality( + AudioTransformer, + modality_encoder_config.audio_transformer_config, + n_audio_queries, + ) + self.text_decoder = TextDecoder( sample_key=text_decoder_config.sample_key, prediction_key=text_embd_prediction_key, @@ -121,16 +165,6 @@ def __init__( self.multimodal_decoder.lm_head.weight ) # https://paperswithcode.com/method/weight-tying - # vision_queries: 256 queries for multimodal cross attention and 1 as vision cls token for contrastive learning - self.vision_queries = nn.Parameter(torch.randn(n_vision_queries + 1, vision_encoder_config.n_embd)) - self.attn_pool = AttentionPooling( - n_embd=vision_encoder_config.n_embd, - n_head=n_pool_head, - bias=bias_attn_pool, - epsilon=epsilon_attn_pool, - attention_config=text_decoder_config.attention_config, - ) - # Logit scale for contrastive loss self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) @@ -146,6 +180,18 @@ def __init__( / math.sqrt(2 * (text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text)), ) + def _init_modality(self, encoder_class, encoder_config, n_queries): + encoder = encoder_class(**dict(encoder_config)) + queries = nn.Parameter(torch.randn(n_queries + 1, encoder_config.n_embd)) + attn_pool = AttentionPooling( + n_embd=encoder_config.n_embd, + n_head=self.n_pool_head, + bias=self.bias_attn_pool, + epsilon=self.epsilon_attn_pool, + attention_config=self.text_decoder_config.attention_config, + ) + return encoder, queries, attn_pool + def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std) @@ -155,23 +201,34 @@ def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConf torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std) def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - vision_embd, vision_cls_token = self._forward_encode_vision(inputs) + # TODO: The "modality_key" needs to be implemented. 
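        # The first entry of the modality tensor selects the encoder branch, which implicitly
        # assumes that all samples within a batch share the same modality.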
+ if inputs[self.modality_key][0] == self.AUDIO: + modality_embd, modality_cls_token = self._forward_encode_audio(inputs) + if inputs[self.modality_key][0] == self.VISION: + modality_embd, modality_cls_token = self._forward_encode_vision(inputs) text_embd, text_cls_token = self._forward_encode_text(inputs) - logits = self._forward_decode(text_embd, vision_embd) + logits = self._forward_decode(text_embd, modality_embd) return { self.prediction_key: logits, - self.vision_cls_prediction_key: vision_cls_token, + self.modality_cls_prediction_key: modality_cls_token, self.text_cls_prediction_key: text_cls_token, self.logit_scale_prediction_key: self.logit_scale.exp(), } def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - vision_embd = self.vision_encoder(inputs)[self.vision_embd_prediction_key] + vision_embd = self.vision_encoder(inputs)[self.modality_embd_prediction_key] queries = repeat(self.vision_queries, "n d -> b n d", b=vision_embd.shape[0]) - vision_embd = self.attn_pool(queries, context=vision_embd) + vision_embd = self.vision_attn_pool(queries, context=vision_embd) vision_embd, vision_cls_token = vision_embd[:, :-1, :], F.normalize(vision_embd[:, -1, :], dim=-1) return vision_embd, vision_cls_token + def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + audio_embd, _ = self.audio_encoder(inputs)[self.modality_embd_prediction_key] + queries = repeat(self.audio_queries, "n d -> b n d", b=audio_embd.shape[0]) + audio_embd = self.audio_attn_pool(queries, context=audio_embd) + audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1:, :], dim=-1) + return audio_embd, audio_cls_token + def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: text_embd = self.text_decoder(inputs)[self.text_embd_prediction_key] text_embd, text_cls_token = text_embd[:, :-1, :], F.normalize(text_embd[:, -1, :], dim=-1) diff --git a/tests/models/coca/coca_config_audio.yaml b/tests/models/coca/coca_config_audio.yaml new file mode 100644 index 000000000..79c08df8b --- /dev/null +++ b/tests/models/coca/coca_config_audio.yaml @@ -0,0 +1,40 @@ +prediction_key: logits +modality_key: modality +modality_embd_prediction_key: modality_embeddings +text_embd_prediction_key: text_embeddings +modality_cls_prediction_key: modality_cls +text_cls_prediction_key: text_cls +modality_encoder_config: + sample_key: audio + prediction_key: modality_embeddings + input_dims: 128 + pre_conformer_dropout: 0.1 + conformer_dropout: 0.1 + n_heads: 8 + n_embd: 768 + n_layers: 17 + depthwise_conv_kernel_size: 31 +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 +n_pool_head: 8 +n_vision_queries: Null +n_audio_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 +weight_init: + mean: 0.0 + std: 0.02 \ No newline at end of file diff --git a/tests/models/coca/coca_config_av.yaml b/tests/models/coca/coca_config_av.yaml new file mode 100644 index 000000000..30025d45f --- /dev/null +++ b/tests/models/coca/coca_config_av.yaml @@ -0,0 +1,57 @@ +prediction_key: logits +modality_key: modality +modality_embd_prediction_key: modality_embeddings +text_embd_prediction_key: 
text_embeddings +modality_cls_prediction_key: modality_cls +text_cls_prediction_key: text_cls +modality_encoder_config: + vision_transformer_config: + sample_key: images + prediction_key: modality_embeddings + img_size: 224 + n_classes: Null # Disable vision transformer head + n_layer: 12 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + audio_transformer_config: + sample_key: audio + prediction_key: modality_embeddings + input_dims: 128 + pre_conformer_dropout: 0.1 + conformer_dropout: 0.1 + n_heads: 8 + n_embd: 768 + n_layers: 17 + depthwise_conv_kernel_size: 31 +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 +n_pool_head: 8 +n_vision_queries: 256 +n_audio_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 +weight_init: + mean: 0.0 + std: 0.02 \ No newline at end of file diff --git a/tests/models/coca/coca_config.yaml b/tests/models/coca/coca_config_vision.yaml similarity index 81% rename from tests/models/coca/coca_config.yaml rename to tests/models/coca/coca_config_vision.yaml index 952cda66e..704f5c9a4 100644 --- a/tests/models/coca/coca_config.yaml +++ b/tests/models/coca/coca_config_vision.yaml @@ -1,11 +1,12 @@ prediction_key: logits -vision_embd_prediction_key: vision_embeddings +modality_key: modality +modality_embd_prediction_key: modality_embeddings text_embd_prediction_key: text_embeddings -vision_cls_prediction_key: vision_cls +modality_cls_prediction_key: modality_cls text_cls_prediction_key: text_cls -vision_encoder_config: +modality_encoder_config: sample_key: images - prediction_key: vision_embeddings + prediction_key: modality_embeddings img_size: 224 n_classes: Null # Disable vision transformer head n_layer: 6 @@ -37,6 +38,7 @@ text_decoder_config: epsilon: 1e-5 n_pool_head: 8 n_vision_queries: 256 +n_audio_queries: Null bias_attn_pool: False epsilon_attn_pool: 1e-5 weight_init: diff --git a/tests/models/coca/test_coca.py b/tests/models/coca/test_coca.py index 9cb27ccd6..b7facd0cf 100644 --- a/tests/models/coca/test_coca.py +++ b/tests/models/coca/test_coca.py @@ -11,39 +11,94 @@ from tests.conftest import _ROOT_DIR -def test_coca(): +def dummy_image_sample(): + input_image = torch.randn(1, 3, 224, 224) + text_decoder_vocab_size = 50304 + text_decoder_block_size = 1024 + input_text = torch.randint(0, text_decoder_vocab_size, (1, text_decoder_block_size)) + VISION = torch.tensor([1]) + return dict( + images=input_image, + input_ids=input_text, + modality=VISION, + ) + + +def dummy_audio_sample(): + audio_features = torch.randn(1, 128, 1000) + audio_len = torch.Tensor([1000 / 4]) + text_decoder_vocab_size = 50304 + text_decoder_block_size = 1024 + input_text = torch.randint(0, text_decoder_vocab_size, (1, text_decoder_block_size)) + AUDIO = torch.tensor([0]) + return dict( + audio=(audio_features, audio_len), + input_ids=input_text, + modality=AUDIO, + ) + + +@pytest.mark.parametrize( + "yaml,dummy_sample", + [ + ("tests/models/coca/coca_config_vision.yaml", dummy_image_sample()), + ("tests/models/coca/coca_config_audio.yaml", dummy_audio_sample()), + ], +) +def test_coca(yaml, dummy_sample): # Create model - 
config_file_path = _ROOT_DIR / Path("tests/models/coca/coca_config.yaml") + config_file_path = _ROOT_DIR / Path(yaml) config_dict = load_app_config_dict(config_file_path=config_file_path) coca_config = CoCaConfig.model_validate(config_dict) model = CoCa(**dict(coca_config)) - # Create dummy inputs - dummy_input_image = torch.randn(1, 3, 224, 224) - dummy_input_text = torch.randint( - 0, coca_config.text_decoder_config.vocab_size, (1, coca_config.text_decoder_config.block_size) - ) - dummy_input = dict(images=dummy_input_image, input_ids=dummy_input_text) - # Create optimizer optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Run one training step optimizer.zero_grad() - out = model(dummy_input) + out = model(dummy_sample) loss = out["logits"].sum() loss.backward() optimizer.step() # Test outputs assert "logits" in out - assert "vision_cls" in out + assert "modality_cls" in out assert "text_cls" in out assert out["logits"].shape == (1, 1024, 50304) - assert out["vision_cls"].shape == (1, 1, 768) + assert out["modality_cls"].shape == (1, 1, 768) assert out["text_cls"].shape == (1, 1, 768) +def test_coca_audio_vision_together(): + # Create model + config_file_path = _ROOT_DIR / Path("tests/models/coca/coca_config_av.yaml") + config_dict = load_app_config_dict(config_file_path=config_file_path) + coca_config = CoCaConfig.model_validate(config_dict) + model = CoCa(**dict(coca_config)) + + # Create optimizer + optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) + + audio_sample = dummy_audio_sample() + image_sample = dummy_image_sample() + + # Run for image + optimizer.zero_grad() + out = model(image_sample) + loss = out["logits"].sum() + loss.backward() + optimizer.step() + + # Run for audio + optimizer.zero_grad() + out = model(audio_sample) + loss = out["logits"].sum() + loss.backward() + optimizer.step() + + @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This e2e test requires 1 GPU.") def test_e2e_coca_training_run_without_checkpoint(monkeypatch): monkeypatch.setenv("RANK", "0") From 857867e9469c5fcc669f890e25c96f9db8283fc0 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Thu, 28 Mar 2024 15:10:04 +0000 Subject: [PATCH 071/161] test: change config and dummy dataset for E2E CoCa test --- config_files/training/config_example_coca.yaml | 17 +++++++++++++---- src/modalities/dataloader/dataset.py | 5 +++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index 2e7d65213..8045410c8 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -36,6 +36,7 @@ collate_fn: sample_keys: - images - ${settings.referencing_keys.sample_key} + - modality target_keys: [] text_sample_key: ${settings.referencing_keys.sample_key} text_target_key: ${settings.referencing_keys.target_key} @@ -52,6 +53,9 @@ train_dataset: - sample_key: input_ids sample_shape: [256] sample_type: int + - sample_key: modality + sample_shape: [0] + sample_type: const val_dataset: component_key: dataset @@ -65,6 +69,9 @@ val_dataset: - sample_key: input_ids sample_shape: [256] sample_type: int + - sample_key: modality + sample_shape: [0] + sample_type: const train_dataloader: component_key: data_loader @@ -190,13 +197,14 @@ model: variant_key: coca config: prediction_key: logits - vision_embd_prediction_key: vision_embeddings + modality_key: modality + modality_embd_prediction_key: modality_embeddings 
text_embd_prediction_key: text_embeddings - vision_cls_prediction_key: vision_cls + modality_cls_prediction_key: modality_cls text_cls_prediction_key: text_cls - vision_encoder_config: + modality_encoder_config: sample_key: images - prediction_key: vision_embeddings + prediction_key: modality_embeddings img_size: 256 # 288 in the original coca n_classes: Null # Disable vision transformer head n_layer: 12 @@ -228,6 +236,7 @@ model: epsilon: 1e-5 n_pool_head: 12 n_vision_queries: 256 + n_audio_queries: Null bias_attn_pool: False epsilon_attn_pool: 1e-5 weight_init: diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index b6513fa33..ce7180d5f 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -40,6 +40,7 @@ def _check_if_inbounds(self, idx: int): class DummySampleDataType(str, Enum): FLOAT = "float" INT = "int" + CONSTANT = "const" class DummySampleConfig(BaseModel): @@ -64,6 +65,8 @@ def __init__(self, num_samples: int, sample_definition: Tuple[DummySampleConfig] self.num_samples = num_samples self.sample_definition = sample_definition + self.VISION = 1 + def __len__(self) -> int: return self.num_samples @@ -77,6 +80,8 @@ def _create_random_sample(self): data = np.random.randn(*s.sample_shape) elif s.sample_type == DummySampleDataType.INT: data = np.random.randint(low=0, high=512, size=s.sample_shape) + elif s.sample_type == DummySampleDataType.CONSTANT: + data = self.VISION else: raise NotImplementedError(f"DummyDataset does not support type { s.sample_type}") sample[s.sample_key] = data From 3f10bfa1fd76a65a08121872841cfdef9925ce8e Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 17 Jun 2024 11:56:25 +0200 Subject: [PATCH 072/161] fix: webdataset with multiple dataset builders --- src/modalities/dataloader/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index b6513fa33..576506d20 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -560,7 +560,7 @@ def __init__( for b in self.builders: datasets.append(b.web_dataset) dataset = wds.RandomMix(datasets, self.mixing_ratios) # Apply mixing at sample level - self.pipeline.extend(dataset.pipeline) + self.pipeline.append(dataset) else: self.pipeline.extend(self.builders[0].web_dataset.pipeline) From 599ed66010a54d0b9bf7d51ae5f4606a02539b27 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Mon, 17 Jun 2024 11:56:51 +0200 Subject: [PATCH 073/161] fix: vision transformer config --- .../models/vision_transformer/vision_transformer_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py index de58a83ef..7cc12c5e7 100644 --- a/src/modalities/models/vision_transformer/vision_transformer_model.py +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -19,6 +19,7 @@ class VisionTransformerConfig(BaseModel): attention_config: AttentionConfig = None n_head: Annotated[int, Field(ge=1)] = 8 n_embd: Annotated[int, Field(ge=1)] = 768 + ffn_hidden: Annotated[int, Field(ge=1)] = 3072 dropout: Annotated[float, Field(ge=0.0)] = 0.0 patch_size: Annotated[int, Field(ge=1)] = 16 patch_stride: Annotated[int, Field(ge=1)] = 16 From 2d3c3fa8ea62d66930b27b937d4bd7d7784508ec Mon Sep 17 00:00:00 2001 From: Sogol Haghighat Date: Mon, 8 Jul 2024 17:03:11 +0200 Subject: [PATCH 
074/161] refactor: add pyav dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6094d5256..d4d509bd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "einops>=0.7.0", "webdataset>=0.2.86", "timm>=0.9.16", + "pyav", "flash-attn", # install this directly via `pip install flash-attn --no-build-isolation` ] From 88ae41fa077efaa02e1d05445c53c9e27f422978 Mon Sep 17 00:00:00 2001 From: Sogol Haghighat Date: Mon, 8 Jul 2024 17:04:50 +0200 Subject: [PATCH 075/161] fix: block_size for video coca --- .../models/vision_transformer/vision_transformer_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py index 7cc12c5e7..a696e7c10 100644 --- a/src/modalities/models/vision_transformer/vision_transformer_model.py +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -209,12 +209,15 @@ def __init__( self.time_embd = nn.Parameter(torch.randn(num_video_frames, 1, n_embd)) # [T,1,d] if add_cls_token: n_latents += 1 # to count for a video level cls token + self.block_size -= 1 self.latents = nn.Parameter(torch.randn(n_latents, n_embd)) # [R,d] self.rearrange = Rearrange("b T S D -> b (T S) D") else: self.embedding_fn = ImagePatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride, add_cls_token) - self.positional_embedding_fn = nn.Embedding(num_embeddings=self.block_size, embedding_dim=n_embd) # [S D] + self.positional_embedding_fn = nn.Embedding( + num_embeddings=self.block_size, embedding_dim=n_embd + ) # [S D] #TODO: this needs to be adjusted for video with cls_token block_classes = {"Video": PerceiverTransformerBlock, "Image": VisionTransformerBlock} self.blocks = nn.ModuleList( From fe23a69461a5d77410681a131f6324872153979b Mon Sep 17 00:00:00 2001 From: Sogol Haghighat Date: Mon, 8 Jul 2024 17:07:11 +0200 Subject: [PATCH 076/161] test: add more test cases and merge video and image tests for vision transformer --- .../test_vision_transformer.py | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/tests/models/vision_transformer/test_vision_transformer.py b/tests/models/vision_transformer/test_vision_transformer.py index bb68a7f03..41a1afb2a 100644 --- a/tests/models/vision_transformer/test_vision_transformer.py +++ b/tests/models/vision_transformer/test_vision_transformer.py @@ -8,16 +8,33 @@ from tests.conftest import _ROOT_DIR -def test_vision_transformer(): +@pytest.mark.parametrize( + "input,sample_key,n_classes,num_video_frames,add_cls_token,out_put", + [ + (torch.randn(1, 3, 224, 224), "images", 1000, 1, True, (1, 1000)), + (torch.randn(1, 3, 224, 224), "images", None, 1, True, (1, 197, 768)), + (torch.randn(1, 3, 224, 224), "images", None, 1, False, (1, 196, 768)), + (torch.randn(1, 3, 224, 224), "images", 1000, 1, False, (1, 1000)), + (torch.randn(1, 16, 3, 224, 224), "videos", 1000, 16, True, (1, 1000)), + (torch.randn(1, 16, 3, 224, 224), "videos", None, 16, True, (1, 65, 768)), + (torch.randn(1, 16, 3, 224, 224), "videos", None, 16, False, (1, 64, 768)), + (torch.randn(1, 16, 3, 224, 224), "videos", 1000, 16, False, (1, 1000)), + ], +) +def test_vision_transformer(input, sample_key, n_classes, num_video_frames, add_cls_token, out_put): # Create model config_file_path = _ROOT_DIR / Path("tests/models/vision_transformer/vision_transformer_config.yaml") config_dict = 
load_app_config_dict(config_file_path=config_file_path) config = VisionTransformerConfig.model_validate(config_dict) + config.sample_key = sample_key + config.n_classes = n_classes + config.num_video_frames = num_video_frames + config.add_cls_token = add_cls_token + model = VisionTransformer(**dict(config)) # Create dummy inputs - dummy_input_image = torch.randn(1, 3, 224, 224) - dummy_input = dict(images=dummy_input_image) + dummy_input = {sample_key: input} # Create optimizer optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) @@ -31,32 +48,7 @@ def test_vision_transformer(): # Test outputs assert "logits" in out - assert out["logits"].shape == (1, 1000) - - # Test for video input - # Create model - config_file_path2 = _ROOT_DIR / Path("tests/models/vision_transformer/vision_transformer_config2.yaml") - config_dict2 = load_app_config_dict(config_file_path=config_file_path2) - config2 = VisionTransformerConfig.model_validate(config_dict2) - model2 = VisionTransformer(**dict(config2)) - - # Create dummy inputs - dummy_input_video = torch.randn(1, 3, 16, 224, 224) # [b c T h w] - dummy_input2 = dict(videos=dummy_input_video) - - # Create optimizer - optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.001, momentum=0.9) - - # Run one training step - optimizer2.zero_grad() - out2 = model2(dummy_input2) - loss2 = out2["logits"].sum() - loss2.backward() - optimizer2.step() - - # Test outputs - assert "logits" in out2 - assert out2["logits"].shape == (1, 1000) + assert out["logits"].shape == out_put @pytest.mark.parametrize( From 256f6b219fc6f899c4d813556babfd6cd774d63f Mon Sep 17 00:00:00 2001 From: Sogol Haghighat Date: Fri, 12 Jul 2024 10:30:37 +0200 Subject: [PATCH 077/161] fix: hard coded num_frames in video transform --- src/modalities/dataloader/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 576506d20..a65b0ca1b 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -341,7 +341,7 @@ def __init__( transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] ) - self.temporal_transform = RandomTemporalCrop(num_frames=16) + self.temporal_transform = RandomTemporalCrop(num_frames=num_frames) def __call__(self, video): video = video[0] From 28a3b82b693e1ced10fa01cfd64d84d4688f89d3 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 5 Aug 2024 11:33:01 +0200 Subject: [PATCH 078/161] refactor: use decord for loading videos --- pyproject.toml | 1 + src/modalities/dataloader/dataset.py | 34 ++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d4d509bd5..cd8837b9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "pyyaml", "transformers", "datasets", + "decord", "protobuf", "SentencePiece", "accelerate", diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index a65b0ca1b..b9ed0a8a6 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -1,10 +1,13 @@ from __future__ import annotations +import io import random +import re from enum import Enum from pathlib import Path from typing import Annotated, Dict, List, Optional, Tuple, Union +import decord import jq import numpy as np import torch @@ -13,7 +16,7 @@ from timm.data import create_transform from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 
from torch.utils.data.dataset import Dataset as TorchdataSet -from torchvision import transforms +from torchvision.transforms import v2 as transforms from tqdm import tqdm from transformers import BatchEncoding @@ -25,6 +28,8 @@ from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper from modalities.util import flatten_dict +decord.bridge.set_bridge("torch") + class Dataset(TorchdataSet): def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): @@ -356,6 +361,31 @@ class MultimodalWebDatasetBuilderConfig(BaseModel): num_samples: Annotated[int, Field(ge=1)] +def decord_video(key, data): + """Based on the torch_video decoder in webdataset + https://github.com/webdataset/webdataset/blob/5b12e0ba78bfb64741add2533c5d1e4cf088ffff/webdataset/autodecode.py#L394 + """ + extension = re.sub(r".*[.]", "", key) + if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): + return None + + file_obj = io.BytesIO(data) + + # we could replace this with torchaudio.load(data) + ar = decord.AudioReader(file_obj, mono=False) + audio = ar[:] + + # reset to start of file + file_obj.seek(0) + vr = decord.VideoReader(file_obj) + clip_num_frames = 64 + # sample clip_num_frames uniformly from the full video + frame_ids = torch.linspace(0, len(vr) - 1, clip_num_frames, dtype=torch.int64) + frames = vr.get_batch(frame_ids.tolist()) # T x H x W x C + + return (frames, audio) + + # @register_component("dataset", "web_dataset_builder", MultimodalWebDatasetBuilderConfig) class MultimodalWebDatasetBuilder: def __init__( @@ -386,7 +416,7 @@ def __init__( self.modality_to_decode_fn = { ModalityEnum.TEXT: None, ModalityEnum.IMAGE: "pil", - ModalityEnum.VIDEO: wds.torch_video, + ModalityEnum.VIDEO: decord_video, ModalityEnum.AUDIO: wds.torch_audio, } From c0e8d1f40682c4521abc0f6c39eaf573e49b8c58 Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 14 May 2024 08:17:00 +0000 Subject: [PATCH 079/161] fix: incorrect variable usage and audio input shape --- .../models/audio_transformer/audio_transformer_model.py | 5 +++-- src/modalities/models/coca/coca_model.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index f9142c356..a265813f6 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -48,7 +48,7 @@ def forward( self, x: torch.Tensor, ) -> torch.Tensor: - # x.shape: batch_size, n_input_dims, n_input_frames + x = x.transpose(1, 2) # x.shape: batch_size, n_input_dims, n_input_frames x = self.subsampler(x) # x.shape: batch_size, n_input_dims, ceil(n_input_frames / 4) x = x.transpose(1, 2) @@ -100,7 +100,8 @@ def forward( self, inputs: Dict[str, tuple[torch.Tensor, torch.Tensor]], ) -> Dict[str, tuple[torch.Tensor, torch.Tensor]]: - x, x_length = inputs[self.sample_key] # x.shape: batch_size, n_input_dims, n_input_frames + x = inputs[self.sample_key] # x.shape: batch_size, n_input_dims, n_input_frames + x_length = inputs["feats_len"] x = self.pre_conformer(x) # x.shape: batch_size, ceil(n_input_frames / 4), n_input_dims x, x_length = self.conformer(x, x_length) # x.shape: batch_size, ceil(n_input_frames / 4), n_input_dims x = self.post_conformer(x) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 09df2bd50..d7f7ccab0 100644 --- a/src/modalities/models/coca/coca_model.py +++ 
b/src/modalities/models/coca/coca_model.py @@ -234,10 +234,10 @@ def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.T text_embd, text_cls_token = text_embd[:, :-1, :], F.normalize(text_embd[:, -1, :], dim=-1) return text_embd, text_cls_token - def _forward_decode(self, text_embd: torch.Tensor, vision_embd: torch.Tensor) -> torch.Tensor: + def _forward_decode(self, text_embd: torch.Tensor, modality_embd: torch.Tensor) -> torch.Tensor: decoder_inputs = { self.text_embd_prediction_key: text_embd, - "context": vision_embd, + "context": modality_embd, } decoder_outputs = self.multimodal_decoder(decoder_inputs) logits = decoder_outputs[self.multimodal_decoder.prediction_key] From 83accc1337fb347a8eb95e19463425ddc6b3db42 Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 14 May 2024 08:27:31 +0000 Subject: [PATCH 080/161] fix: to avoid torch.tensor(tensor) --- src/modalities/models/coca/collator.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index e220aa677..c4570c2e5 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -32,14 +32,18 @@ def __init__(self, sample_keys: List[str], target_keys: List[str], text_sample_k def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: samples = { - sample_key: torch.stack([self._prepare_sample(d[sample_key]) for d in batch]) + sample_key: torch.stack( + [torch.tensor(d[sample_key]) if type(d[sample_key]) != torch.Tensor else d[sample_key] for d in batch] + ) for sample_key in self.sample_keys } if "attention_mask" in batch[0]: samples["attention_mask"] = torch.stack([self._prepare_sample(d["attention_mask"]) for d in batch]) targets = { - target_key: torch.stack([self._prepare_sample(d[target_key]) for d in batch]) + target_key: torch.stack( + [torch.tensor(d[target_key]) if type(d[target_key]) != torch.Tensor else d[target_key] for d in batch] + ) for target_key in self.target_keys } @@ -52,9 +56,3 @@ def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: samples["attention_mask"] = samples["attention_mask"][:, :-1] return DatasetBatch(targets=targets, samples=samples) - - @staticmethod - def _prepare_sample(x): - if isinstance(x, torch.Tensor): - return x - return torch.tensor(x) From 54906d62e0e5ef4b5617f34a9758a621f041af0b Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 14 May 2024 08:51:46 +0000 Subject: [PATCH 081/161] fix: add argument to ignore padding indices --- src/modalities/loss_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 8c1ec96ba..957fa4021 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -40,7 +40,10 @@ def __init__(self, target_key: str, prediction_key: str, weight: float, tag: str self.target_key = target_key self.prediction_key = prediction_key # Mean over the tokens in the local-batch (batch per rank) - self.loss_fun = TorchCrossEntropyLoss(reduction="mean") + self.loss_fun = TorchCrossEntropyLoss( + reduction="mean", + ignore_index=0, + ) def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: labels = forward_batch.get_targets(self.target_key) From 67d3778ddfb9356682a277447e2b5196a9182d34 Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 14 May 2024 09:16:07 +0000 Subject: [PATCH 082/161] test: uptate tests to comply with changes --- 
config_files/training/config_example_coca.yaml | 4 ++-- .../audio_transformer/test_audio_transformer_model.py | 4 ++-- tests/models/coca/test_coca.py | 9 +++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index 8045410c8..8004e190b 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -54,7 +54,7 @@ train_dataset: sample_shape: [256] sample_type: int - sample_key: modality - sample_shape: [0] + sample_shape: [1] sample_type: const val_dataset: @@ -70,7 +70,7 @@ val_dataset: sample_shape: [256] sample_type: int - sample_key: modality - sample_shape: [0] + sample_shape: [1] sample_type: const train_dataloader: diff --git a/tests/models/audio_transformer/test_audio_transformer_model.py b/tests/models/audio_transformer/test_audio_transformer_model.py index 173e8239a..7131c8de6 100644 --- a/tests/models/audio_transformer/test_audio_transformer_model.py +++ b/tests/models/audio_transformer/test_audio_transformer_model.py @@ -45,12 +45,12 @@ def audio_transformer( @pytest.fixture def dummy_input_div4(): - return {"audio_feats": (torch.randn(4, 80, 1000), torch.Tensor([1000 / 4] * 4))} + return {"audio_feats": torch.randn(4, 1000, 80), "feats_len": torch.Tensor([1000 / 4] * 4)} @pytest.fixture def dummy_input_notdiv4(): - return {"audio_feats": (torch.randn(4, 80, 750), torch.Tensor([750 // 4] * 4))} + return {"audio_feats": torch.randn(4, 750, 80), "feats_len": torch.Tensor([750 / 4] * 4)} def test_audio_transformer_output_shape_div4( diff --git a/tests/models/coca/test_coca.py b/tests/models/coca/test_coca.py index b7facd0cf..3f2bd7ba6 100644 --- a/tests/models/coca/test_coca.py +++ b/tests/models/coca/test_coca.py @@ -25,14 +25,15 @@ def dummy_image_sample(): def dummy_audio_sample(): - audio_features = torch.randn(1, 128, 1000) + audio_features = torch.randn(1, 1000, 128) audio_len = torch.Tensor([1000 / 4]) text_decoder_vocab_size = 50304 text_decoder_block_size = 1024 input_text = torch.randint(0, text_decoder_vocab_size, (1, text_decoder_block_size)) AUDIO = torch.tensor([0]) return dict( - audio=(audio_features, audio_len), + audio=audio_features, + feats_len=audio_len, input_ids=input_text, modality=AUDIO, ) @@ -67,8 +68,8 @@ def test_coca(yaml, dummy_sample): assert "modality_cls" in out assert "text_cls" in out assert out["logits"].shape == (1, 1024, 50304) - assert out["modality_cls"].shape == (1, 1, 768) - assert out["text_cls"].shape == (1, 1, 768) + assert out["modality_cls"].shape == (1, 768) + assert out["text_cls"].shape == (1, 768) def test_coca_audio_vision_together(): From 0b1698d575f2e93ef36e1cba93c474bb09c9771d Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 14 May 2024 09:31:51 +0000 Subject: [PATCH 083/161] chore: add configs These can help run audio-only, vision-only or audio-vision experiments! 
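For quick experimentation, the model block of these example configs can also be instantiated on its own, mirroring what the CoCa unit tests do. A minimal sketch, assuming the import path used by the tests and resolving the two ${...} references by hand (the training entry point resolves these when the full config is loaded; paths and the nested model.config access are illustrative):

    from pathlib import Path

    import yaml

    from modalities.models.coca.coca_model import CoCa, CoCaConfig  # import path as assumed from the tests

    # Illustrative: read only the nested model section of the new example config.
    raw = yaml.safe_load(Path("config_files/config_example_coca_audio.yaml").read_text())
    model_cfg = raw["model"]["config"]
    # Resolve the interpolated keys manually for this standalone sketch.
    model_cfg["text_decoder_config"]["sample_key"] = "input_ids"
    model_cfg["text_decoder_config"]["prediction_key"] = "logits"
    coca_config = CoCaConfig.model_validate(model_cfg)
    model = CoCa(**dict(coca_config))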
--- config_files/config_example_coca_audio.yaml | 291 +++++++++++++++ .../config_example_coca_audio_vision.yaml | 339 ++++++++++++++++++ config_files/config_example_coca_vision.yaml | 328 +++++++++++++++++ 3 files changed, 958 insertions(+) create mode 100644 config_files/config_example_coca_audio.yaml create mode 100644 config_files/config_example_coca_audio_vision.yaml create mode 100644 config_files/config_example_coca_vision.yaml diff --git a/config_files/config_example_coca_audio.yaml b/config_files/config_example_coca_audio.yaml new file mode 100644 index 000000000..4e7af6437 --- /dev/null +++ b/config_files/config_example_coca_audio.yaml @@ -0,0 +1,291 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + callback_interval_in_samples: 10000 + global_num_training_samples: 1925398 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 124 + sequence_length: 512 + gradient_clipping: + mode: p2_norm + threshold: 1.0 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +tokenizer: + component_key: tokenizer + variant_key: gpt2_tokenizer_fast + config: + tokenizer_file: data/tokenizer/tokenizer_gpt2.json + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - feats + - feats_len + - ${settings.referencing_keys.sample_key} + - modality + target_keys: [] + text_sample_key: ${settings.referencing_keys.sample_key} + text_target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: simple_dataset + config: + type_: train + audio_dataset_arrows: gertv-arrow-remove-na/train + vision_dataset_arrows: coco_captions/val + bpe_to_ind: + bpecodes: bpecodesgertv500 + num_feats: 80 + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + + +val_dataset: + component_key: dataset + variant_key: simple_dataset + config: + type_: val + audio_dataset_arrows: gertv-arrow-remove-na/test + vision_dataset_arrows: coco_captions/val + bpe_to_ind: + bpecodes: bpecodesgertv500 + num_feats: 80 + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + + +train_dataloader: + component_key: data_loader + variant_key: repeating_data_loader + config: + dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: true + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + + num_epochs: 1 + + +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + 
sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, AudioTransformer] + + +captioning_loss: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: logits + +contrastive_loss: + component_key: loss + variant_key: nce_loss + config: + prediction_key1: ${model.config.modality_cls_prediction_key} + prediction_key2: ${model.config.text_cls_prediction_key} + tag: contrastive_loss + +loss_fn: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss + pass_type: BY_REFERENCE + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, AudioTransformer] + +model: + component_key: model + variant_key: coca + config: + prediction_key: logits + modality_key: modality + modality_embd_prediction_key: modality_embeddings + text_embd_prediction_key: text_embeddings + modality_cls_prediction_key: modality_cls + text_cls_prediction_key: text_cls + modality_encoder_config: + sample_key: feats + prediction_key: modality_embeddings + input_dims: 80 + pre_conformer_dropout: 0.1 + conformer_dropout: 0.1 + n_heads: 4 + n_embd: 512 + n_layers: 3 + depthwise_conv_kernel_size: 31 + text_decoder_config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${model.config.prediction_key} + block_size: 512 + vocab_size: 634 + n_layer_text: 1 + n_layer_multimodal_text: 2 + attention_config: + attention_engine_type: default_attention + n_head: 4 + ffn_hidden: 1024 + n_embd: 512 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 + n_pool_head: 8 + n_vision_queries: Null + n_audio_queries: 256 + bias_attn_pool: False + epsilon_attn_pool: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 + +scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: 962699 + pct_start: 0.01 + anneal_strategy: cos + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.00001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: 
${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." diff --git a/config_files/config_example_coca_audio_vision.yaml b/config_files/config_example_coca_audio_vision.yaml new file mode 100644 index 000000000..0716a2311 --- /dev/null +++ b/config_files/config_example_coca_audio_vision.yaml @@ -0,0 +1,339 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + callback_interval_in_samples: 10000 + global_num_training_samples: 1925398 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 124 + sequence_length: 512 + gradient_clipping: + mode: p2_norm + threshold: 1.0 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +tokenizer: + component_key: tokenizer + variant_key: gpt2_tokenizer_fast + config: + tokenizer_file: data/tokenizer/tokenizer_gpt2.json + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - feats + - feats_len + - ${settings.referencing_keys.sample_key} + - modality + target_keys: [] + text_sample_key: ${settings.referencing_keys.sample_key} + text_target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: simple_dataset + config: + type_: train + audio_dataset_arrows: gertv-arrow-remove-na/train + vision_dataset_arrows: coco_captions_arrow/train + bpe_to_ind: bpe_to_ind_test.pkl + bpecodes: bpecodes_test + num_feats: 80 + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + + +val_dataset: + component_key: dataset + variant_key: simple_dataset + config: + type_: val + audio_dataset_arrows: gertv-arrow-remove-na/test + vision_dataset_arrows: coco_captions_arrow/val + bpe_to_ind: bpe_to_ind_test.pkl + bpecodes: bpecodes_test + num_feats: 80 + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + + +# train_dataloader: +# component_key: data_loader +# variant_key: repeating_data_loader +# config: +# dataloader: +# component_key: data_loader +# variant_key: default +# config: +# num_workers: 2 +# pin_memory: true +# shuffle: false +# dataloader_tag: "train" +# dataset: +# instance_key: train_dataset +# pass_type: BY_REFERENCE +# batch_sampler: +# component_key: batch_sampler +# variant_key: default +# config: +# batch_size: ${settings.training.local_train_micro_batch_size} +# drop_last: false +# sampler: +# component_key: sampler +# variant_key: distributed_sampler +# config: +# rank: ${settings.cuda_env.global_rank} +# num_replicas: ${settings.cuda_env.world_size} +# shuffle: true +# dataset: +# instance_key: train_dataset +# pass_type: BY_REFERENCE +# collate_fn: +# instance_key: collate_fn +# pass_type: BY_REFERENCE + +# num_epochs: 1 + + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: 
BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, AudioTransformer, VisionTransformerBlock] + + +captioning_loss: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: logits + +contrastive_loss: + component_key: loss + variant_key: nce_loss + config: + prediction_key1: ${model.config.modality_cls_prediction_key} + prediction_key2: ${model.config.text_cls_prediction_key} + tag: contrastive_loss + +loss_fn: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss + pass_type: BY_REFERENCE + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, AudioTransformer, VisionTransformerBlock] + +model: + component_key: model + variant_key: coca + config: + prediction_key: logits + modality_key: modality + modality_embd_prediction_key: modality_embeddings + text_embd_prediction_key: text_embeddings + modality_cls_prediction_key: modality_cls + text_cls_prediction_key: text_cls + modality_encoder_config: + vision_transformer_config: + sample_key: feats # need to fix these + prediction_key: modality_embeddings + img_size: 224 + n_classes: Null # Disable vision transformer head + n_layer: 3 + attention_config: + attention_engine_type: default_attention + n_head: 4 + n_embd: 512 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + audio_transformer_config: + sample_key: feats + prediction_key: 
modality_embeddings + input_dims: 80 + pre_conformer_dropout: 0.1 + conformer_dropout: 0.1 + n_heads: 4 + n_embd: 512 + n_layers: 3 + depthwise_conv_kernel_size: 31 + text_decoder_config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${model.config.prediction_key} + block_size: 512 + vocab_size: 634 + n_layer_text: 1 + n_layer_multimodal_text: 2 + attention_config: + attention_engine_type: default_attention + n_head: 4 + ffn_hidden: 1024 + n_embd: 512 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 + n_pool_head: 8 + n_vision_queries: 256 + n_audio_queries: 256 + bias_attn_pool: False + epsilon_attn_pool: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 + +scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: 962699 + pct_start: 0.01 + anneal_strategy: cos + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.00001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." 
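Between the two single-modality variants, the combined config above relies on the modality flag carried in every batch to pick an encoder at runtime. A minimal sketch of that convention, following the dummy samples in tests/models/coca/test_coca.py (the 0 = audio / 1 = vision ids and the shapes below are taken from those tests and are assumptions for any real dataset; note that this example config still uses the shared feats sample key for both encoders, which is flagged above as needing a fix):

    import torch

    AUDIO, VISION = torch.tensor([0]), torch.tensor([1])

    image_batch = {
        "images": torch.randn(1, 3, 224, 224),            # (batch, channels, height, width)
        "input_ids": torch.randint(0, 50304, (1, 1024)),
        "modality": VISION,  # routed through the vision encoder and vision attention pool
    }
    audio_batch = {
        "audio": torch.randn(1, 1000, 128),                # (batch, frames, mel bins) before 4x subsampling
        "feats_len": torch.tensor([1000 / 4]),             # frame count after subsampling
        "input_ids": torch.randint(0, 50304, (1, 1024)),
        "modality": AUDIO,   # routed through the conformer encoder and audio attention pool
    }
    # model(image_batch) and model(audio_batch) each return a dict with logits,
    # modality_cls and text_cls, so both the captioning and contrastive losses
    # apply regardless of which modality a batch carries.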
diff --git a/config_files/config_example_coca_vision.yaml b/config_files/config_example_coca_vision.yaml new file mode 100644 index 000000000..b93a3da0d --- /dev/null +++ b/config_files/config_example_coca_vision.yaml @@ -0,0 +1,328 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + callback_interval_in_samples: 10000 + global_num_training_samples: 1925398 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 512 + sequence_length: 512 + gradient_clipping: + mode: p2_norm + threshold: 1.0 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +tokenizer: + component_key: tokenizer + variant_key: gpt2_tokenizer_fast + config: + tokenizer_file: data/tokenizer/tokenizer_gpt2.json + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - feats + - feats_len + - ${settings.referencing_keys.sample_key} + - modality + target_keys: [] + text_sample_key: ${settings.referencing_keys.sample_key} + text_target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: simple_dataset + config: + type_: train + audio_dataset_arrows: gertv-arrow-remove-na/train + vision_dataset_arrows: coco_captions_arrow/train + bpe_to_ind: + bpecodes: bpecodescococaptions500 + num_feats: 80 + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + + +val_dataset: + component_key: dataset + variant_key: simple_dataset + config: + type_: val + audio_dataset_arrows: gertv-arrow-remove-na/test + vision_dataset_arrows: coco_captions_arrow/val + bpe_to_ind: + bpecodes: bpecodescococaptions500 + num_feats: 80 + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + + +# train_dataloader: +# component_key: data_loader +# variant_key: repeating_data_loader +# config: +# dataloader: +# component_key: data_loader +# variant_key: default +# config: +# num_workers: 2 +# pin_memory: true +# shuffle: false +# dataloader_tag: "train" +# dataset: +# instance_key: train_dataset +# pass_type: BY_REFERENCE +# batch_sampler: +# component_key: batch_sampler +# variant_key: default +# config: +# batch_size: ${settings.training.local_train_micro_batch_size} +# drop_last: false +# sampler: +# component_key: sampler +# variant_key: distributed_sampler +# config: +# rank: ${settings.cuda_env.global_rank} +# num_replicas: ${settings.cuda_env.world_size} +# shuffle: true +# dataset: +# instance_key: train_dataset +# pass_type: BY_REFERENCE +# collate_fn: +# instance_key: collate_fn +# pass_type: BY_REFERENCE + +# num_epochs: 1 + + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader 
+ variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, VisionTransformerBlock] + + +captioning_loss: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: logits + +contrastive_loss: + component_key: loss + variant_key: nce_loss + config: + prediction_key1: ${model.config.modality_cls_prediction_key} + prediction_key2: ${model.config.text_cls_prediction_key} + tag: contrastive_loss + +loss_fn: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss + pass_type: BY_REFERENCE + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, VisionTransformerBlock] + +model: + component_key: model + variant_key: coca + config: + prediction_key: logits + modality_key: modality + modality_embd_prediction_key: modality_embeddings + text_embd_prediction_key: text_embeddings + modality_cls_prediction_key: modality_cls + text_cls_prediction_key: text_cls + modality_encoder_config: + sample_key: feats + prediction_key: modality_embeddings + img_size: 224 # 288 in the original coca + n_classes: Null # Disable vision transformer head + n_layer: 3 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 0.0 + patch_size: 16 # 18 in the original coca + patch_stride: 16 # 18 in the original coca + n_img_channels: 3 + add_cls_token: False + bias: True + text_decoder_config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${model.config.prediction_key} + block_size: 512 + vocab_size: 610 + n_layer_text: 3 + n_layer_multimodal_text: 1 + attention_config: + attention_engine_type: default_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 + n_pool_head: 8 + n_vision_queries: 256 + n_audio_queries: Null + bias_attn_pool: False + epsilon_attn_pool: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 + +scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + 
optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: 962699 + pct_start: 0.01 + anneal_strategy: cos + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.00001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." From b89bff853e0110571fef96624d2a0a772f3b947a Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 28 May 2024 12:18:31 +0000 Subject: [PATCH 084/161] feat: allow masking of "pad" keys --- src/modalities/nn/attention.py | 39 +++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/modalities/nn/attention.py b/src/modalities/nn/attention.py index dd8b5db57..ceecff542 100644 --- a/src/modalities/nn/attention.py +++ b/src/modalities/nn/attention.py @@ -59,25 +59,42 @@ def __init__( ) self.resid_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() - def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor, context: Optional[Tensor] = None, mask: Tensor = None) -> Tensor: context = context if self.use_cross_attention else x B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd) q, k, v = self._forward_input_projection(x, context=context) if self.use_flash: - y = F.scaled_dot_product_attention( - query=q, - key=k, - value=v, - attn_mask=None, - dropout_p=self.dropout if self.training else 0, - is_causal=self.is_causal, + y = ( + self._flash_with_mask(query=q, key=k, value=v, mask=mask) + if mask is not None + else self._flash_without_mask(query=q, key=k, value=v) ) else: - y = self._forward_attention(query=q, key=k, value=v) + y = self._forward_attention(query=q, key=k, value=v, mask=mask) y = y.transpose(1, 2).contiguous().view(B, T, C) y = self.resid_dropout(self.c_proj(y)) return y + def _flash_with_mask(self, query: Tensor, key: Tensor, value: Tensor, mask: Tensor) -> Tensor: + return F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=(mask == 0).logical_not(), + dropout_p=self.dropout if self.training else 0, + is_causal=self.is_causal, + ) + + def _flash_without_mask(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: + return F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=None, + dropout_p=self.dropout if self.training else 0, + is_causal=self.is_causal, + ) + def _forward_input_projection(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor, Tensor]: B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd) _, Tc, Cc = context.shape # batch size, context length, context embedding dimensionality @@ -88,11 +105,13 @@ def _forward_input_projection(self, x: Tensor, context: Tensor) -> Tuple[Tensor, v = self.wv(context).view(B, 
Tc, self.n_head, Cc // self.n_head).transpose(1, 2) return q, k, v - def _forward_attention(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: + def _forward_attention(self, query: Tensor, key: Tensor, value: Tensor, mask: Tensor) -> Tensor: att = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1))) if self.is_causal: T = query.size(2) att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + if mask is not None: + att = att.masked_fill(mask == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) return att @ value From d31d0a70482dd08817bd50b8407d630c0e544909 Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 28 May 2024 12:19:00 +0000 Subject: [PATCH 085/161] feat: implement Conformer from scratch --- .../audio_transformer_model.py | 234 +++++++++++++----- src/modalities/models/coca/coca_model.py | 2 +- 2 files changed, 176 insertions(+), 60 deletions(-) diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index a265813f6..eec4dbc6e 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -3,58 +3,133 @@ import torch from pydantic import BaseModel, Field from torch import nn -from torchaudio.models import Conformer + +from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention +from modalities.nn.mlp import MLP class AudioTransformerConfig(BaseModel): sample_key: str prediction_key: str - input_dims: Annotated[int, Field(ge=1)] - pre_conformer_dropout: Annotated[float, Field(lt=1.0)] - conformer_dropout: Annotated[float, Field(lt=1.0)] - n_heads: Annotated[int, Field(ge=1)] + block_size: Annotated[int, Field(ge=1)] + n_mels: Annotated[int, Field(ge=1)] n_embd: Annotated[int, Field(ge=1)] - n_layers: Annotated[int, Field(ge=1)] + n_heads: Annotated[int, Field(ge=1)] + n_conformer_blocks: Annotated[int, Field(ge=1)] + attention_config: AttentionConfig + pointwise_conv_kernel_size: Annotated[int, Field(ge=1)] depthwise_conv_kernel_size: Annotated[int, Field(ge=1)] + ffmodule_dropout: Annotated[float, Field(lt=1.0)] = 0.1 + attn_dropout: Annotated[float, Field(lt=1.0)] = 0.1 + convmodule_dropout: Annotated[float, Field(lt=1.0)] = 0.1 -class PreConformer(nn.Module): +class ConvolutionModule(nn.Module): def __init__( self, - *, - n_input_dims: int, - dropout: float, + n_embd: int, + pointwise_conv_kernel_size: int, + depthwise_conv_kernel_size: int, + dropout: int, ): super().__init__() - self.subsampler = nn.Sequential( - nn.Conv1d( - in_channels=n_input_dims, - out_channels=n_input_dims, - kernel_size=2, - stride=2, - ), - nn.Conv1d( - in_channels=n_input_dims, - out_channels=n_input_dims, - kernel_size=2, - stride=2, - ), + self.ln = nn.LayerNorm(n_embd) + self.pointwise_conv_1 = nn.Conv1d( + n_embd, + 2 * n_embd, + pointwise_conv_kernel_size, + padding="same", + ) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + n_embd, + n_embd, + kernel_size=depthwise_conv_kernel_size, + groups=n_embd, + padding="same", + ) + self.bn = nn.BatchNorm1d( + n_embd, + ) + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d( + n_embd, + n_embd, + pointwise_conv_kernel_size, + padding="same", ) - self.linear = nn.Linear(n_input_dims, n_input_dims) - self.dropout = nn.Dropout(dropout) def forward( self, x: torch.Tensor, ) -> torch.Tensor: - x = x.transpose(1, 2) # x.shape: batch_size, n_input_dims, n_input_frames - - x 
= self.subsampler(x) # x.shape: batch_size, n_input_dims, ceil(n_input_frames / 4) + x = self.ln(x) x = x.transpose(1, 2) - x = self.linear(x) # x.shape: batch_size, ceil(n_input_frames / 4), n_input_dims - x = self.dropout(x) - return x + x = self.glu(self.pointwise_conv_1(x)) + x = self.swish(self.bn(self.depthwise_conv(x))) + x = self.pointwise_conv_2(x) + return self.dropout(x.transpose(1, 2)) # shape: B, T, D + + +class ConformerBlock(nn.Module): + def __init__( + self, + n_embd: int, + n_heads: int, + attention_config, + pointwise_conv_kernel_size: int, + depthwise_conv_kernel_size: int, + ffmodule_dropout: float, + attn_dropout: float, + convmodule_dropout: float, + ) -> None: + super().__init__() + + self.ln1 = nn.LayerNorm(n_embd) + self.entry_ffmodule = MLP( + in_features=n_embd, + act_fn=nn.SiLU, + dropout=ffmodule_dropout, + ) + self.mhsa_ln = nn.LayerNorm(n_embd) + self.mhsa = MultiHeadAttention( + attention_config=attention_config, + attention_type=AttentionType.NON_CAUSAL_SELF_ATTENTION, + n_embd=n_embd, + n_head=n_heads, + dropout=attn_dropout, + ) + self.convmodule = ConvolutionModule( + n_embd, + pointwise_conv_kernel_size, + depthwise_conv_kernel_size, + convmodule_dropout, + ) + self.ln2 = nn.LayerNorm( + n_embd, + ) + self.exit_ffmodule = MLP( + in_features=n_embd, + act_fn=nn.SiLU, + dropout=ffmodule_dropout, + ) + self.exit_ln = nn.LayerNorm( + n_embd, + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + ) -> torch.Tensor: + x = self.ln1(x) # x.shape: B, T, D + x = x + 0.5 * self.entry_ffmodule(x) + x = x + self.mhsa(self.mhsa_ln(x), mask=mask) + x = x + self.convmodule(x) + x = self.ln2(x) + x = x + 0.5 * self.exit_ffmodule(x) + return self.exit_ln(x) class AudioTransformer(nn.Module): @@ -63,46 +138,87 @@ def __init__( *, sample_key: str, prediction_key: str, - input_dims: int, - n_heads: int, + block_size: int, + n_mels: int, n_embd: int, - n_layers: int, + n_heads: int, + n_conformer_blocks: int, + attention_config: AttentionConfig, + pointwise_conv_kernel_size: int, depthwise_conv_kernel_size: int, - pre_conformer_dropout: float, - conformer_dropout: float, + ffmodule_dropout: float = 0.1, + attn_dropout: float = 0.1, + convmodule_dropout: float = 0.1, ): super().__init__() self.sample_key = sample_key self.prediction_key = prediction_key - self.pre_conformer = PreConformer( - n_input_dims=input_dims, - dropout=pre_conformer_dropout, - ) + self.block_size = block_size - self.conformer = Conformer( - input_dim=input_dims, - num_heads=n_heads, - ffn_dim=n_embd, - num_layers=n_layers, - depthwise_conv_kernel_size=depthwise_conv_kernel_size, - dropout=conformer_dropout, + self.project = nn.Conv1d(in_channels=n_mels, out_channels=n_embd, kernel_size=3, padding="same") + self.subsampler = nn.Sequential( + nn.Conv1d( + in_channels=n_embd, + out_channels=n_embd, + kernel_size=2, + stride=2, + ), + nn.Conv1d( + in_channels=n_embd, + out_channels=n_embd, + kernel_size=2, + stride=2, + ), + ) + self.post_subsampler_linear = nn.Sequential( + nn.Linear(n_embd, n_embd), + nn.Dropout(0.1), ) - self.post_conformer = nn.Sequential( - nn.Linear( - input_dims, - n_embd, - ), - nn.LayerNorm(n_embd), + self.positional_embeddings = nn.Embedding(self.block_size, n_embd) + self.conformer_blocks = nn.ModuleList( + [ + ConformerBlock( + n_embd, + n_heads, + attention_config, + pointwise_conv_kernel_size, + depthwise_conv_kernel_size, + ffmodule_dropout, + attn_dropout, + convmodule_dropout, + ) + for _ in range(n_conformer_blocks) + ] ) def forward( self, inputs: 
Dict[str, tuple[torch.Tensor, torch.Tensor]], ) -> Dict[str, tuple[torch.Tensor, torch.Tensor]]: - x = inputs[self.sample_key] # x.shape: batch_size, n_input_dims, n_input_frames - x_length = inputs["feats_len"] - x = self.pre_conformer(x) # x.shape: batch_size, ceil(n_input_frames / 4), n_input_dims - x, x_length = self.conformer(x, x_length) # x.shape: batch_size, ceil(n_input_frames / 4), n_input_dims - x = self.post_conformer(x) - return {self.prediction_key: (x, x_length)} + x = inputs[self.sample_key] # x.shape: B, T, D + attn_key_mask = self._get_attn_key_mask(inputs["feats_len"]) + # x.shape: B, T, D + x = self.project(x.transpose(1, 2)) # x.shape: B, D, T + x = self.subsampler(x) # x.shape: B, D, T/4 + x = x.transpose(1, 2) + x = self.post_subsampler_linear(x) + x = x + self.positional_embeddings.weight + + for block in self.conformer_blocks: + x = block(x, attn_key_mask) + return {self.prediction_key: x} + + def _get_attn_key_mask( + self, + lengths: torch.Tensor, + ): + return ( + torch.nn.utils.rnn.pad_sequence( + [torch.ones(length, self.block_size) for length in lengths] + + [torch.ones(self.block_size, self.block_size)], + batch_first=True, + ) + .transpose(1, 2)[:-1] + .unsqueeze_(1) + ).to(lengths.device) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index d7f7ccab0..369205dee 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -223,7 +223,7 @@ def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch return vision_embd, vision_cls_token def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - audio_embd, _ = self.audio_encoder(inputs)[self.modality_embd_prediction_key] + audio_embd = self.audio_encoder(inputs)[self.modality_embd_prediction_key] queries = repeat(self.audio_queries, "n d -> b n d", b=audio_embd.shape[0]) audio_embd = self.audio_attn_pool(queries, context=audio_embd) audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1:, :], dim=-1) From 4c42b27f1bc0ff479f9bbade2d8682f7ec675266 Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 28 May 2024 12:21:14 +0000 Subject: [PATCH 086/161] test: fix to comply to changes --- tests/models/coca/coca_config_audio.yaml | 12 +++++++----- tests/models/coca/coca_config_av.yaml | 12 +++++++----- tests/models/coca/test_coca.py | 4 ++-- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/models/coca/coca_config_audio.yaml b/tests/models/coca/coca_config_audio.yaml index 79c08df8b..3ba12f113 100644 --- a/tests/models/coca/coca_config_audio.yaml +++ b/tests/models/coca/coca_config_audio.yaml @@ -7,12 +7,14 @@ text_cls_prediction_key: text_cls modality_encoder_config: sample_key: audio prediction_key: modality_embeddings - input_dims: 128 - pre_conformer_dropout: 0.1 - conformer_dropout: 0.1 - n_heads: 8 + block_size: 500 + n_mels: 128 n_embd: 768 - n_layers: 17 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 depthwise_conv_kernel_size: 31 text_decoder_config: sample_key: input_ids diff --git a/tests/models/coca/coca_config_av.yaml b/tests/models/coca/coca_config_av.yaml index 30025d45f..eb88dcdd1 100644 --- a/tests/models/coca/coca_config_av.yaml +++ b/tests/models/coca/coca_config_av.yaml @@ -24,12 +24,14 @@ modality_encoder_config: audio_transformer_config: sample_key: audio prediction_key: modality_embeddings - input_dims: 128 - 
pre_conformer_dropout: 0.1 - conformer_dropout: 0.1 - n_heads: 8 + block_size: 500 + n_mels: 128 n_embd: 768 - n_layers: 17 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 depthwise_conv_kernel_size: 31 text_decoder_config: sample_key: input_ids diff --git a/tests/models/coca/test_coca.py b/tests/models/coca/test_coca.py index 3f2bd7ba6..e45f15172 100644 --- a/tests/models/coca/test_coca.py +++ b/tests/models/coca/test_coca.py @@ -25,8 +25,8 @@ def dummy_image_sample(): def dummy_audio_sample(): - audio_features = torch.randn(1, 1000, 128) - audio_len = torch.Tensor([1000 / 4]) + audio_features = torch.randn(1, 500 * 4, 128) + audio_len = torch.tensor([1000 / 4]).type(torch.int16) text_decoder_vocab_size = 50304 text_decoder_block_size = 1024 input_text = torch.randint(0, text_decoder_vocab_size, (1, text_decoder_block_size)) From 1e09561a5ddfbb5d612c2cf14b7e1355db6bc9bb Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 28 May 2024 12:22:35 +0000 Subject: [PATCH 087/161] test: remove deprecated test --- tests/models/audio_transformer/__init__.py | 0 .../test_audio_transformer_model.py | 77 ------------------- 2 files changed, 77 deletions(-) delete mode 100644 tests/models/audio_transformer/__init__.py delete mode 100644 tests/models/audio_transformer/test_audio_transformer_model.py diff --git a/tests/models/audio_transformer/__init__.py b/tests/models/audio_transformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/models/audio_transformer/test_audio_transformer_model.py b/tests/models/audio_transformer/test_audio_transformer_model.py deleted file mode 100644 index 7131c8de6..000000000 --- a/tests/models/audio_transformer/test_audio_transformer_model.py +++ /dev/null @@ -1,77 +0,0 @@ -import pytest -import torch - -from modalities.models.audio_transformer.audio_transformer_model import AudioTransformer - - -@pytest.fixture -def pre_conformer_config(): - return { - "input_dims": 80, - "dropout": 0.1, - } - - -@pytest.fixture -def audio_transformer_config(): - return { - "sample_key": "audio_feats", - "prediction_key": "audio_embeddings", - "n_heads": 4, - "n_embd": 512, - "n_layers": 2, - "depthwise_conv_kernel_size": 3, - "dropout": 0.1, - } - - -@pytest.fixture -def audio_transformer( - pre_conformer_config, - audio_transformer_config, -): - return AudioTransformer( - sample_key=audio_transformer_config["sample_key"], - prediction_key=audio_transformer_config["prediction_key"], - input_dims=pre_conformer_config["input_dims"], - n_heads=audio_transformer_config["n_heads"], - n_embd=audio_transformer_config["n_embd"], - n_layers=audio_transformer_config["n_layers"], - depthwise_conv_kernel_size=audio_transformer_config["depthwise_conv_kernel_size"], - pre_conformer_dropout=pre_conformer_config["dropout"], - conformer_dropout=audio_transformer_config["dropout"], - ) - - -@pytest.fixture -def dummy_input_div4(): - return {"audio_feats": torch.randn(4, 1000, 80), "feats_len": torch.Tensor([1000 / 4] * 4)} - - -@pytest.fixture -def dummy_input_notdiv4(): - return {"audio_feats": torch.randn(4, 750, 80), "feats_len": torch.Tensor([750 / 4] * 4)} - - -def test_audio_transformer_output_shape_div4( - dummy_input_div4, - audio_transformer, - audio_transformer_config, -): - output = audio_transformer(dummy_input_div4) - audio_embeddings, audio_lengths = output[audio_transformer_config["prediction_key"]] - assert audio_embeddings.shape[0] == 4 - assert audio_embeddings.shape[1] == 
1000 / 4 - assert audio_embeddings.shape[2] == 512 - - -def test_audio_transformer_output_shape_notdiv4( - dummy_input_notdiv4, - audio_transformer, - audio_transformer_config, -): - output = audio_transformer(dummy_input_notdiv4) - audio_embeddings, audio_lengths = output[audio_transformer_config["prediction_key"]] - assert audio_embeddings.shape[0] == 4 - assert audio_embeddings.shape[1] == 750 // 4 - assert audio_embeddings.shape[2] == 512 From 68f6e5ea90bdee7ae251ce447c032050185a3870 Mon Sep 17 00:00:00 2001 From: mmaurya Date: Tue, 28 May 2024 12:41:44 +0000 Subject: [PATCH 088/161] chore: fix configs to comply to changes --- config_files/config_example_coca_audio.yaml | 97 +++++++++---------- .../config_example_coca_audio_vision.yaml | 29 ++++-- config_files/config_example_coca_vision.yaml | 27 +++--- 3 files changed, 77 insertions(+), 76 deletions(-) diff --git a/config_files/config_example_coca_audio.yaml b/config_files/config_example_coca_audio.yaml index 4e7af6437..7356bc2b8 100644 --- a/config_files/config_example_coca_audio.yaml +++ b/config_files/config_example_coca_audio.yaml @@ -9,7 +9,7 @@ settings: global_num_seen_samples: 0 do_apply_activation_checkpointing: true gradient_acc_steps: 1 - local_train_micro_batch_size: 124 + local_train_micro_batch_size: 256 sequence_length: 512 gradient_clipping: mode: p2_norm @@ -42,69 +42,64 @@ collate_fn: train_dataset: component_key: dataset - variant_key: simple_dataset + variant_key: arrow_dataset_audio config: type_: train audio_dataset_arrows: gertv-arrow-remove-na/train - vision_dataset_arrows: coco_captions/val - bpe_to_ind: - bpecodes: bpecodesgertv500 - num_feats: 80 + bpe_to_ind: bpe_to_ind_audio.pkl + bpecodes: bpecodes_audio + n_mels: 128 + block_size_audio_encoder: 500 + block_size_text_decoder: 512 freq_domain_mask_length: 30 time_domain_mask_length: 100 val_dataset: component_key: dataset - variant_key: simple_dataset + variant_key: arrow_dataset_audio config: type_: val audio_dataset_arrows: gertv-arrow-remove-na/test - vision_dataset_arrows: coco_captions/val - bpe_to_ind: - bpecodes: bpecodesgertv500 - num_feats: 80 + bpe_to_ind: bpe_to_ind_audio.pkl + bpecodes: bpecodes_audio + n_mels: 128 + block_size_audio_encoder: 500 + block_size_text_decoder: 512 freq_domain_mask_length: 30 time_domain_mask_length: 100 train_dataloader: component_key: data_loader - variant_key: repeating_data_loader + variant_key: default config: - dataloader: - component_key: data_loader + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler variant_key: default config: - num_workers: 2 - pin_memory: true - shuffle: true - dataloader_tag: "train" - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler config: - batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: false - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: true - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - - num_epochs: 1 - + rank: ${settings.cuda_env.global_rank} + 
num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE val_dataloader: component_key: data_loader @@ -208,20 +203,22 @@ model: modality_encoder_config: sample_key: feats prediction_key: modality_embeddings - input_dims: 80 - pre_conformer_dropout: 0.1 - conformer_dropout: 0.1 - n_heads: 4 + block_size: 500 + n_mels: 128 n_embd: 512 - n_layers: 3 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 depthwise_conv_kernel_size: 31 text_decoder_config: sample_key: ${settings.referencing_keys.sample_key} prediction_key: ${model.config.prediction_key} block_size: 512 - vocab_size: 634 - n_layer_text: 1 - n_layer_multimodal_text: 2 + vocab_size: 141 + n_layer_text: 3 + n_layer_multimodal_text: 3 attention_config: attention_engine_type: default_attention n_head: 4 diff --git a/config_files/config_example_coca_audio_vision.yaml b/config_files/config_example_coca_audio_vision.yaml index 0716a2311..79207cee7 100644 --- a/config_files/config_example_coca_audio_vision.yaml +++ b/config_files/config_example_coca_audio_vision.yaml @@ -42,28 +42,35 @@ collate_fn: train_dataset: component_key: dataset - variant_key: simple_dataset + variant_key: arrow_dataset_av config: type_: train + batch_size: 124 audio_dataset_arrows: gertv-arrow-remove-na/train vision_dataset_arrows: coco_captions_arrow/train bpe_to_ind: bpe_to_ind_test.pkl bpecodes: bpecodes_test - num_feats: 80 + n_mels: 128 + img_size: 224 + block_size_audio_encoder: 2000 + block_size_text_decoder: 512 freq_domain_mask_length: 30 time_domain_mask_length: 100 - val_dataset: component_key: dataset - variant_key: simple_dataset + variant_key: arrow_dataset_av config: type_: val + batch_size: 124 audio_dataset_arrows: gertv-arrow-remove-na/test vision_dataset_arrows: coco_captions_arrow/val bpe_to_ind: bpe_to_ind_test.pkl bpecodes: bpecodes_test - num_feats: 80 + n_mels: 128 + img_size: 224 + block_size_audio_encoder: 2000 + block_size_text_decoder: 512 freq_domain_mask_length: 30 time_domain_mask_length: 100 @@ -256,12 +263,14 @@ model: audio_transformer_config: sample_key: feats prediction_key: modality_embeddings - input_dims: 80 - pre_conformer_dropout: 0.1 - conformer_dropout: 0.1 - n_heads: 4 + block_size: 2000 + n_mels: 128 n_embd: 512 - n_layers: 3 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 depthwise_conv_kernel_size: 31 text_decoder_config: sample_key: ${settings.referencing_keys.sample_key} diff --git a/config_files/config_example_coca_vision.yaml b/config_files/config_example_coca_vision.yaml index b93a3da0d..bcc00dc13 100644 --- a/config_files/config_example_coca_vision.yaml +++ b/config_files/config_example_coca_vision.yaml @@ -42,30 +42,25 @@ collate_fn: train_dataset: component_key: dataset - variant_key: simple_dataset + variant_key: arrow_dataset_vision config: - type_: train - audio_dataset_arrows: gertv-arrow-remove-na/train vision_dataset_arrows: coco_captions_arrow/train - bpe_to_ind: - bpecodes: bpecodescococaptions500 - num_feats: 80 - freq_domain_mask_length: 30 - time_domain_mask_length: 100 + bpe_to_ind: bpe_to_ind_test.pkl + bpecodes: bpecodes_test + img_size: 224 + block_size_text_decoder: 512 + val_dataset: component_key: dataset - variant_key: simple_dataset + variant_key: arrow_dataset_vision config: - type_: val - 
audio_dataset_arrows: gertv-arrow-remove-na/test vision_dataset_arrows: coco_captions_arrow/val - bpe_to_ind: - bpecodes: bpecodescococaptions500 - num_feats: 80 - freq_domain_mask_length: 30 - time_domain_mask_length: 100 + bpe_to_ind: bpe_to_ind_test.pkl + bpecodes: bpecodes_test + img_size: 224 + block_size_text_decoder: 512 # train_dataloader: From 94efdb811e15ff717d854674f9bedac78d6740f7 Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Mon, 3 Jun 2024 11:54:53 +0000 Subject: [PATCH 089/161] fix: accelerate import --- src/modalities/running_env/fsdp/fsdp_auto_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py index 4019b4169..38c1c1056 100644 --- a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py +++ b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py @@ -32,6 +32,7 @@ def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: List[str]) block_type = FullyShardedDataParallelPlugin.get_module_class_from_name(model, cls_block_name) except AttributeError: from accelerate.utils.dataclasses import get_module_class_from_name + block_type = get_module_class_from_name(model, cls_block_name) if block_type is None: raise ValueError(f"Could not find block with name {cls_block_name} in model") From b13c8cb68f48f53bf95e1f8924e0afe1b600eff0 Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Mon, 3 Jun 2024 13:55:24 +0000 Subject: [PATCH 090/161] refactor: introduce global constants --- tests/models/coca/coca_config_av.yaml | 6 +-- tests/models/coca/test_coca.py | 63 ++++++++++++++++----------- 2 files changed, 41 insertions(+), 28 deletions(-) diff --git a/tests/models/coca/coca_config_av.yaml b/tests/models/coca/coca_config_av.yaml index eb88dcdd1..466277a9e 100644 --- a/tests/models/coca/coca_config_av.yaml +++ b/tests/models/coca/coca_config_av.yaml @@ -36,14 +36,14 @@ modality_encoder_config: text_decoder_config: sample_key: input_ids prediction_key: text_embeddings - block_size: 1024 - vocab_size: 50304 + block_size: 1_024 + vocab_size: 50_304 n_layer_text: 6 n_layer_multimodal_text: 6 attention_config: attention_engine_type: pytorch_flash_attention n_head: 12 - ffn_hidden: 2048 + ffn_hidden: 2_048 n_embd: 768 dropout: 0.0 bias: true diff --git a/tests/models/coca/test_coca.py b/tests/models/coca/test_coca.py index e45f15172..2a137672f 100644 --- a/tests/models/coca/test_coca.py +++ b/tests/models/coca/test_coca.py @@ -10,12 +10,27 @@ from modalities.running_env.cuda_env import CudaEnv from tests.conftest import _ROOT_DIR +# shared config +N_EMBD = 768 + +# text_decoder_config +TEXT_DECODER_VOCAB_SIZE = 50_304 +TEXT_DECODER_BLOCK_SIZE = 1_024 + +# vision_transformer_config +N_IMAGE_CLASSES = 1_000 +IMG_SIZE = 224 +N_IMG_CHANNELS = 3 + +# audio_transformer_config +AUDIO_BLOCK_SIZE = 500 +N_MELS = 128 +N_HEADS = 4 + def dummy_image_sample(): - input_image = torch.randn(1, 3, 224, 224) - text_decoder_vocab_size = 50304 - text_decoder_block_size = 1024 - input_text = torch.randint(0, text_decoder_vocab_size, (1, text_decoder_block_size)) + input_image = torch.randn(1, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (1, TEXT_DECODER_BLOCK_SIZE)) VISION = torch.tensor([1]) return dict( images=input_image, @@ -25,11 +40,9 @@ def dummy_image_sample(): def dummy_audio_sample(): - audio_features = torch.randn(1, 500 * 4, 128) - audio_len = torch.tensor([1000 / 4]).type(torch.int16) - text_decoder_vocab_size = 50304 - 
text_decoder_block_size = 1024 - input_text = torch.randint(0, text_decoder_vocab_size, (1, text_decoder_block_size)) + audio_features = torch.randn(1, AUDIO_BLOCK_SIZE * N_HEADS, N_MELS) + audio_len = torch.tensor([N_IMAGE_CLASSES / N_HEADS]).type(torch.int16) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (1, TEXT_DECODER_BLOCK_SIZE)) AUDIO = torch.tensor([0]) return dict( audio=audio_features, @@ -67,9 +80,9 @@ def test_coca(yaml, dummy_sample): assert "logits" in out assert "modality_cls" in out assert "text_cls" in out - assert out["logits"].shape == (1, 1024, 50304) - assert out["modality_cls"].shape == (1, 768) - assert out["text_cls"].shape == (1, 768) + assert out["logits"].shape == (1, TEXT_DECODER_BLOCK_SIZE, TEXT_DECODER_VOCAB_SIZE) + assert out["modality_cls"].shape == (1, N_EMBD) + assert out["text_cls"].shape == (1, N_EMBD) def test_coca_audio_vision_together(): @@ -85,19 +98,19 @@ def test_coca_audio_vision_together(): audio_sample = dummy_audio_sample() image_sample = dummy_image_sample() - # Run for image - optimizer.zero_grad() - out = model(image_sample) - loss = out["logits"].sum() - loss.backward() - optimizer.step() - - # Run for audio - optimizer.zero_grad() - out = model(audio_sample) - loss = out["logits"].sum() - loss.backward() - optimizer.step() + for dummy_samples in [audio_sample, image_sample]: + optimizer.zero_grad() + out = model(dummy_samples) + loss = out["logits"].sum() + loss.backward() + optimizer.step() + + assert "logits" in out + assert "modality_cls" in out + assert "text_cls" in out + assert out["logits"].shape == (1, TEXT_DECODER_BLOCK_SIZE, TEXT_DECODER_VOCAB_SIZE) + assert out["modality_cls"].shape == (1, N_EMBD) + assert out["text_cls"].shape == (1, N_EMBD) @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This e2e test requires 1 GPU.") From 70652f136bed0351aefc94db4f3db2f83b11cf95 Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Mon, 10 Jun 2024 12:52:15 +0000 Subject: [PATCH 091/161] fix: constant renaming --- tests/models/coca/test_coca.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/coca/test_coca.py b/tests/models/coca/test_coca.py index 2a137672f..071b7e365 100644 --- a/tests/models/coca/test_coca.py +++ b/tests/models/coca/test_coca.py @@ -25,7 +25,7 @@ # audio_transformer_config AUDIO_BLOCK_SIZE = 500 N_MELS = 128 -N_HEADS = 4 +SUB_SAMPLING_FACTOR = 4 def dummy_image_sample(): @@ -40,8 +40,8 @@ def dummy_image_sample(): def dummy_audio_sample(): - audio_features = torch.randn(1, AUDIO_BLOCK_SIZE * N_HEADS, N_MELS) - audio_len = torch.tensor([N_IMAGE_CLASSES / N_HEADS]).type(torch.int16) + audio_features = torch.randn(1, AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) + audio_len = torch.tensor([N_IMAGE_CLASSES / SUB_SAMPLING_FACTOR]).type(torch.int16) input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (1, TEXT_DECODER_BLOCK_SIZE)) AUDIO = torch.tensor([0]) return dict( @@ -87,7 +87,7 @@ def test_coca(yaml, dummy_sample): def test_coca_audio_vision_together(): # Create model - config_file_path = _ROOT_DIR / Path("tests/models/coca/coca_config_av.yaml") + config_file_path = _ROOT_DIR / Path("coca/coca_config_av.yaml") config_dict = load_app_config_dict(config_file_path=config_file_path) coca_config = CoCaConfig.model_validate(config_dict) model = CoCa(**dict(coca_config)) From f07cfdaea45e0b826a89d4f2b922b335c8a85931 Mon Sep 17 00:00:00 2001 From: mmaurya Date: Wed, 12 Jun 2024 14:04:16 +0000 Subject: [PATCH 092/161] fix: disable mamba imports --- 
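Note: this patch only comments out the Mamba imports and the corresponding registry
entry, so importing the registry no longer requires the optional Mamba dependency.
Sketched below, purely as an illustration and not part of this change (the
_MAMBA_AVAILABLE flag is hypothetical), a guarded import could achieve the same
effect while keeping the component registered whenever the dependency is installed:

    try:
        from modalities.models.mamba.mamba_config import MambaLLMConfig
        from modalities.models.mamba.mamba_model import MambaLLM
        _MAMBA_AVAILABLE = True
    except ImportError:  # optional Mamba dependency not installed
        _MAMBA_AVAILABLE = False

    # inside the COMPONENTS list construction:
    if _MAMBA_AVAILABLE:
        COMPONENTS.append(ComponentEntity("model", "mamba", MambaLLM, MambaLLMConfig))
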
src/modalities/registry/components.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 6acfcceea..298b4d348 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -88,8 +88,9 @@ HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig, ) -from modalities.models.mamba.mamba_config import MambaLLMConfig -from modalities.models.mamba.mamba_model import MambaLLM + +# from modalities.models.mamba.mamba_config import MambaLLMConfig +# from modalities.models.mamba.mamba_model import MambaLLM from modalities.models.model_factory import ModelFactory from modalities.optimizers.lr_schedulers import DummyLRScheduler from modalities.optimizers.optimizer_factory import OptimizerFactory @@ -117,7 +118,7 @@ class ComponentEntity: COMPONENTS = [ # models ComponentEntity("model", "gpt2", GPT2LLM, GPT2LLMConfig), - ComponentEntity("model", "mamba", MambaLLM, MambaLLMConfig), + # ComponentEntity("model", "mamba", MambaLLM, MambaLLMConfig), ComponentEntity( "model", "huggingface_pretrained_model", HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig ), From 11a28e299a3bb6f797e504ac78a88497223357e7 Mon Sep 17 00:00:00 2001 From: mmaurya Date: Wed, 12 Jun 2024 14:04:57 +0000 Subject: [PATCH 093/161] chore: update audio coca arrow dataset config --- config_files/config_example_coca_audio.yaml | 73 ++++++++++++--------- 1 file changed, 41 insertions(+), 32 deletions(-) diff --git a/config_files/config_example_coca_audio.yaml b/config_files/config_example_coca_audio.yaml index 7356bc2b8..334f7928f 100644 --- a/config_files/config_example_coca_audio.yaml +++ b/config_files/config_example_coca_audio.yaml @@ -1,19 +1,19 @@ settings: experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} referencing_keys: sample_key: input_ids target_key: target_ids training: - callback_interval_in_samples: 10000 + global_training_log_interval_in_steps: 10000 # Needs to be a multiple of gradient_acc_steps + global_checkpointing_interval_in_steps: 10000 + global_evaluation_interval_in_steps: 10000 global_num_training_samples: 1925398 - global_num_seen_samples: 0 + global_num_seen_steps: 0 do_apply_activation_checkpointing: true gradient_acc_steps: 1 local_train_micro_batch_size: 256 sequence_length: 512 - gradient_clipping: - mode: p2_norm - threshold: 1.0 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} @@ -21,11 +21,6 @@ settings: paths: checkpointing_path: data/checkpoints -tokenizer: - component_key: tokenizer - variant_key: gpt2_tokenizer_fast - config: - tokenizer_file: data/tokenizer/tokenizer_gpt2.json collate_fn: component_key: collate_fn @@ -46,8 +41,8 @@ train_dataset: config: type_: train audio_dataset_arrows: gertv-arrow-remove-na/train - bpe_to_ind: bpe_to_ind_audio.pkl - bpecodes: bpecodes_audio + bpe_to_ind: bpe_to_ind_gertv_2000.pkl + bpecodes: bpecodes_gertv_2000 n_mels: 128 block_size_audio_encoder: 500 block_size_text_decoder: 512 @@ -61,8 +56,8 @@ val_dataset: config: type_: val audio_dataset_arrows: gertv-arrow-remove-na/test - bpe_to_ind: bpe_to_ind_audio.pkl - bpecodes: bpecodes_audio + bpe_to_ind: bpe_to_ind_gertv_2000.pkl + bpecodes: bpecodes_gertv_2000 n_mels: 128 block_size_audio_encoder: 500 block_size_text_decoder: 512 @@ -86,7 +81,7 @@ train_dataloader: variant_key: default config: batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: false + drop_last: 
true sampler: component_key: sampler variant_key: distributed_sampler @@ -117,7 +112,7 @@ val_dataloader: variant_key: default config: batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: false + drop_last: true sampler: component_key: sampler variant_key: distributed_sampler @@ -136,40 +131,42 @@ eval_dataloaders: - instance_key: val_dataloader pass_type: BY_REFERENCE -checkpointing: - component_key: checkpointing +checkpoint_saving: + component_key: checkpoint_saving variant_key: default config: - checkpointing_strategy: - component_key: checkpointing_strategy + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy variant_key: save_k_most_recent_checkpoints_strategy config: k: -1 # -1 to save all checkpoints - checkpointing_execution: - component_key: checkpointing_execution - variant_key: fsdp_to_disc_checkpointing + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: fsdp config: checkpoint_path: ${settings.paths.checkpointing_path} global_rank: ${settings.cuda_env.global_rank} experiment_id: ${settings.experiment_id} - mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD - block_names: [TransformerBlock, AudioTransformer] + # mixed_precision_settings: FP_16 + # sharding_strategy: FULL_SHARD + # block_names: [TransformerBlock, AudioTransformer] captioning_loss: component_key: loss - variant_key: clm_cross_entropy_loss + variant_key: cross_entropy_loss config: target_key: ${settings.referencing_keys.target_key} - prediction_key: logits + prediction_key: ${model.config.prediction_key} + tag: captioning_loss contrastive_loss: component_key: loss - variant_key: nce_loss + variant_key: clip_loss config: prediction_key1: ${model.config.modality_cls_prediction_key} prediction_key2: ${model.config.text_cls_prediction_key} + logit_scale_key: ${model.config.logit_scale_prediction_key} tag: contrastive_loss loss_fn: @@ -200,6 +197,7 @@ model: text_embd_prediction_key: text_embeddings modality_cls_prediction_key: modality_cls text_cls_prediction_key: text_cls + logit_scale_prediction_key: logit_scale modality_encoder_config: sample_key: feats prediction_key: modality_embeddings @@ -216,7 +214,7 @@ model: sample_key: ${settings.referencing_keys.sample_key} prediction_key: ${model.config.prediction_key} block_size: 512 - vocab_size: 141 + vocab_size: 2134 n_layer_text: 3 n_layer_multimodal_text: 3 attention_config: @@ -263,13 +261,23 @@ optimizer: instance_key: wrapped_model pass_type: BY_REFERENCE +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp + config: + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + batch_progress_subscriber: component_key: progress_subscriber variant_key: rich config: local_rank: ${settings.cuda_env.local_rank} - world_size: ${settings.cuda_env.world_size} - global_num_seen_samples: ${settings.training.global_num_seen_samples} + # world_size: ${settings.cuda_env.world_size} + global_num_seen_steps: ${settings.training.global_num_seen_steps} train_dataloader: instance_key: train_dataloader pass_type: BY_REFERENCE @@ -286,3 +294,4 @@ evaluation_subscriber: mode: OFFLINE experiment_id: ${settings.experiment_id} directory: "." 
+ config_file_path: ${settings.config_file_path} From ced54c86c787398c762b4179cdddafe8ef6422db Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Thu, 13 Jun 2024 13:34:06 +0000 Subject: [PATCH 094/161] feat: add audio transform --- src/modalities/dataloader/dataset.py | 55 +++++++++++++++++++++++++-- src/modalities/registry/components.py | 3 ++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index ce7180d5f..9303fb1f0 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -314,6 +314,51 @@ def __call__(self, text): return batch_encoding +class AudioTransformConfig(TransformConfig): + is_training: bool = False + n_mels: int = 128 + freq_domain_mask_length: int = 30 + time_domain_mask_length: int = 100 + block_size_audio_encoder: int + + +class AudioTransform(Transform): + def __init__( + self, + block_size_audio_encoder: int, + is_training: bool = False, + n_mels: int = 128, + freq_domain_mask_length: int = 30, + time_domain_mask_length: int = 100, + ): + self.block_size_audio_encoder = block_size_audio_encoder + self.is_training = is_training + self.n_mels = n_mels + self.freq_domain_mask_length = freq_domain_mask_length + self.time_domain_mask_length = time_domain_mask_length + + def __call__(self, raw_audio: tuple[torch.Tensor, int]) -> tuple[torch.Tensor, int]: + SUB_SAMPLING_FACTOR = 4 + + self.extract_features = torchaudio.transforms.MelSpectrogram(n_mels=self.n_mels) + + if self.is_training: + self.masking = torch.nn.Sequential( + torchaudio.transforms.FrequencyMasking(freq_mask_param=self.freq_domain_mask_length), + torchaudio.transforms.TimeMasking(time_mask_param=self.time_domain_mask_length), + ) + + log_mel_spec = torch.clamp(self.extract_features(raw_audio[0]), 1e-10).log10().squeeze(0) + log_mel_spec = self.masking(log_mel_spec) if self.is_training else log_mel_spec + feats_len = log_mel_spec.shape[-1] // SUB_SAMPLING_FACTOR + + assert feats_len * SUB_SAMPLING_FACTOR <= SUB_SAMPLING_FACTOR * self.block_size_audio_encoder + log_mel_spec = torch.nn.functional.pad( + log_mel_spec, (0, SUB_SAMPLING_FACTOR * self.block_size_audio_encoder - log_mel_spec.shape[-1]) + ).transpose(0, 1) + return log_mel_spec, feats_len + + class RandomTemporalCrop: def __init__(self, num_frames): self.num_frames = num_frames @@ -396,9 +441,13 @@ def __init__( } self.additional_extreacted_keys = [] + self.additional_extreacted_keys.append("modality") if ModalityEnum.TEXT in self.modality_transforms: self.additional_extreacted_keys.append("attention_mask") + if ModalityEnum.AUDIO in self.modality_transforms: + self.additional_extreacted_keys.append("feats_len") + # Mapping between modality and transform self.modality_to_transform_fn = { ModalityEnum.TEXT: self._transform_text, @@ -466,10 +515,10 @@ def _transform_video(self, sample): def _transform_audio(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.AUDIO] - # config: AudioTransformConfig = self.modality_transforms_configs[ModalityEnum.AUDIO] - # TODO add audio transform - sample[target_key] = sample[source_key] + transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] + sample[target_key], sample["feats_len"] = transform(sample[source_key]) del sample[source_key] + sample["modality"] = [0] return sample def _flatten_sample(self, sample): diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 298b4d348..43d104405 100644 --- 
a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -54,6 +54,8 @@ ) from modalities.dataloader.dataloader_factory import DataloaderFactory from modalities.dataloader.dataset import ( + AudioTransform, + AudioTransformConfig, DummyDatasetConfig, ImageTransform, ImageTransformConfig, @@ -176,6 +178,7 @@ class ComponentEntity: # Data transforms & augmentations ComponentEntity("transform", "text_transform", TextTransform, TextTransformConfig), ComponentEntity("transform", "image_transform", ImageTransform, ImageTransformConfig), + ComponentEntity("transform", "audio_transform", AudioTransform, AudioTransformConfig), ComponentEntity("transform", "video_transform", VideoTransform, VideoTransformConfig), # samplers ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), From 8ad8e865e19522ded38c337246c06f439f679cdd Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Thu, 13 Jun 2024 13:34:52 +0000 Subject: [PATCH 095/161] fix: cross entropy loss ignore index --- src/modalities/loss_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 957fa4021..a5e6a8c0b 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -42,7 +42,6 @@ def __init__(self, target_key: str, prediction_key: str, weight: float, tag: str # Mean over the tokens in the local-batch (batch per rank) self.loss_fun = TorchCrossEntropyLoss( reduction="mean", - ignore_index=0, ) def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: From 31d72d0754ea91ec71d1039935d06c3bdc2b0645 Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Thu, 13 Jun 2024 13:35:16 +0000 Subject: [PATCH 096/161] fix: prepare_sample --- src/modalities/models/coca/collator.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index c4570c2e5..e220aa677 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -32,18 +32,14 @@ def __init__(self, sample_keys: List[str], target_keys: List[str], text_sample_k def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: samples = { - sample_key: torch.stack( - [torch.tensor(d[sample_key]) if type(d[sample_key]) != torch.Tensor else d[sample_key] for d in batch] - ) + sample_key: torch.stack([self._prepare_sample(d[sample_key]) for d in batch]) for sample_key in self.sample_keys } if "attention_mask" in batch[0]: samples["attention_mask"] = torch.stack([self._prepare_sample(d["attention_mask"]) for d in batch]) targets = { - target_key: torch.stack( - [torch.tensor(d[target_key]) if type(d[target_key]) != torch.Tensor else d[target_key] for d in batch] - ) + target_key: torch.stack([self._prepare_sample(d[target_key]) for d in batch]) for target_key in self.target_keys } @@ -56,3 +52,9 @@ def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: samples["attention_mask"] = samples["attention_mask"][:, :-1] return DatasetBatch(targets=targets, samples=samples) + + @staticmethod + def _prepare_sample(x): + if isinstance(x, torch.Tensor): + return x + return torch.tensor(x) From 3a14b9526e41d51c852aae321b8eb3682174d398 Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Thu, 13 Jun 2024 13:37:20 +0000 Subject: [PATCH 097/161] chore: apply changes from origin/feat/audio_coca --- src/modalities/models/coca/coca_model.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 369205dee..86ee5c4f1 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -226,7 +226,7 @@ def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch. audio_embd = self.audio_encoder(inputs)[self.modality_embd_prediction_key] queries = repeat(self.audio_queries, "n d -> b n d", b=audio_embd.shape[0]) audio_embd = self.audio_attn_pool(queries, context=audio_embd) - audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1:, :], dim=-1) + audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1, :], dim=-1) return audio_embd, audio_cls_token def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: From 2335dcddb0729dd5b04c3ebce2af49d6caac62a0 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 20 Jun 2024 14:33:13 +0200 Subject: [PATCH 098/161] chore: add separate vision and audio configs --- config_files/config_example_coca_audio.yaml | 11 +- .../config_example_coca_webdataset.yaml | 4 +- .../config_example_video_coca_webdataset.yaml | 6 +- src/modalities/models/coca/coca_model.py | 109 ++++++++++-------- 4 files changed, 67 insertions(+), 63 deletions(-) diff --git a/config_files/config_example_coca_audio.yaml b/config_files/config_example_coca_audio.yaml index 334f7928f..d61fcdc14 100644 --- a/config_files/config_example_coca_audio.yaml +++ b/config_files/config_example_coca_audio.yaml @@ -48,7 +48,6 @@ train_dataset: block_size_text_decoder: 512 freq_domain_mask_length: 30 time_domain_mask_length: 100 - val_dataset: component_key: dataset @@ -192,13 +191,12 @@ model: variant_key: coca config: prediction_key: logits - modality_key: modality - modality_embd_prediction_key: modality_embeddings + audio_embd_prediction_key: modality_embeddings text_embd_prediction_key: text_embeddings - modality_cls_prediction_key: modality_cls + audio_cls_prediction_key: modality_cls text_cls_prediction_key: text_cls logit_scale_prediction_key: logit_scale - modality_encoder_config: + audio_encoder_config: sample_key: feats prediction_key: modality_embeddings block_size: 500 @@ -227,8 +225,7 @@ model: activation: fused_swiglu epsilon: 1e-5 n_pool_head: 8 - n_vision_queries: Null - n_audio_queries: 256 + n_queries: 256 bias_attn_pool: False epsilon_attn_pool: 1e-5 weight_init: diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index 4161d4fec..0817c160e 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -229,7 +229,7 @@ model: activation: fused_swiglu epsilon: 1e-5 n_pool_head: 12 - n_vision_queries: 256 + n_queries: 256 bias_attn_pool: False epsilon_attn_pool: 1e-5 weight_init: @@ -287,7 +287,7 @@ evaluation_subscriber: config: local_rank: ${settings.cuda_env.global_rank} project: modalities - mode: ONLINE + mode: OFFLINE experiment_id: ${settings.experiment_id} directory: "." 
config_file_path: ${settings.config_file_path} diff --git a/config_files/training/config_example_video_coca_webdataset.yaml b/config_files/training/config_example_video_coca_webdataset.yaml index 58d2225d6..1466b9dc4 100644 --- a/config_files/training/config_example_video_coca_webdataset.yaml +++ b/config_files/training/config_example_video_coca_webdataset.yaml @@ -231,8 +231,8 @@ model: bias: true activation: fused_swiglu epsilon: 1e-5 - n_pool_head: 12 - n_vision_queries: 256 + n_pool_head: 16 + n_queries: 256 bias_attn_pool: False epsilon_attn_pool: 1e-5 weight_init: @@ -290,7 +290,7 @@ evaluation_subscriber: config: local_rank: ${settings.cuda_env.global_rank} project: modalities - mode: ONLINE + mode: OFFLINE experiment_id: ${settings.experiment_id} directory: "." config_file_path: ${settings.config_file_path} diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index b93647d11..bc066e131 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -1,6 +1,6 @@ import math from functools import partial -from typing import Annotated, Dict, Tuple +from typing import Annotated, Dict, Optional, Tuple import numpy as np import torch @@ -18,11 +18,6 @@ from modalities.nn.attention import AttentionConfig -class AVConfig(BaseModel): - audio_transformer_config: AudioTransformerConfig - vision_transformer_config: VisionTransformerConfig - - class TextDecoderConfig(BaseModel): """ Configuration class for the TextDecoder. @@ -81,17 +76,19 @@ class CoCaConfig(BaseModel): """ prediction_key: str = "logits" - modality_key: str = "modality" - modality_embd_prediction_key: str text_embd_prediction_key: str - modality_cls_prediction_key: str text_cls_prediction_key: str logit_scale_prediction_key: str - modality_encoder_config: AudioTransformerConfig | VisionTransformerConfig | AVConfig + audio_embd_prediction_key: Optional[str] = None + vision_embd_prediction_key: Optional[str] = None + audio_cls_prediction_key: Optional[str] = None + vision_cls_prediction_key: Optional[str] = None + audio_encoder_config: Optional[AudioTransformerConfig] = None + vision_encoder_config: Optional[VisionTransformerConfig] = None text_decoder_config: TextDecoderConfig n_pool_head: Annotated[int, Field(ge=1)] - n_vision_queries: Annotated[int, Field(ge=1)] | None - n_audio_queries: Annotated[int, Field(ge=1)] | None + n_vision_queries: Optional[Annotated[int, Field(ge=1)]] + n_audio_queries: Optional[Annotated[int, Field(ge=1)]] bias_attn_pool: bool epsilon_attn_pool: Annotated[float, Field(ge=0.0)] @@ -110,15 +107,19 @@ class CoCa(NNModel): def __init__( self, prediction_key: str, - modality_key: str, - modality_embd_prediction_key: str, text_embd_prediction_key: str, - logit_scale_prediction_key: str, - modality_cls_prediction_key: str, text_cls_prediction_key: str, - n_vision_queries: int, - n_audio_queries: int, + logit_scale_prediction_key: str, + audio_embd_prediction_key: Optional[str], + vision_embd_prediction_key: Optional[str], + audio_cls_prediction_key: Optional[str], + vision_cls_prediction_key: Optional[str], + audio_encoder_config: Optional[AudioTransformerConfig], + vision_encoder_config: Optional[VisionTransformerConfig], + text_decoder_config: TextDecoderConfig, n_pool_head: int, + n_vision_queries: Optional[int], + n_audio_queries: Optional[int], bias_attn_pool: bool, epsilon_attn_pool: float, modality_encoder_config: VisionTransformerConfig | AudioTransformerConfig | AVConfig, @@ -146,44 +147,36 @@ def __init__( """ 
super().__init__() - self.AUDIO = 0 - self.VISION = 1 - self.prediction_key = prediction_key - self.modality_key = modality_key - self.modality_embd_prediction_key = modality_embd_prediction_key self.text_embd_prediction_key = text_embd_prediction_key self.logit_scale_prediction_key = logit_scale_prediction_key - - self.modality_cls_prediction_key = modality_cls_prediction_key self.text_cls_prediction_key = text_cls_prediction_key + self.audio_embd_prediction_key = audio_embd_prediction_key + self.vision_embd_prediction_key = vision_embd_prediction_key + self.audio_cls_prediction_key = audio_cls_prediction_key + self.vision_cls_prediction_key = vision_cls_prediction_key + self.n_pool_head = n_pool_head self.bias_attn_pool = bias_attn_pool self.epsilon_attn_pool = epsilon_attn_pool self.text_decoder_config = text_decoder_config - if isinstance(modality_encoder_config, VisionTransformerConfig): + self.vision_sample_key = None + if vision_encoder_config is not None: + self.vision_sample_key = vision_encoder_config.sample_key self.vision_encoder, self.vision_queries, self.vision_attn_pool = self._init_modality( VisionTransformer, - modality_encoder_config, - n_vision_queries, - ) - elif isinstance(modality_encoder_config, AudioTransformerConfig): - self.audio_encoder, self.audio_queries, self.audio_attn_pool = self._init_modality( - AudioTransformer, - modality_encoder_config, - n_audio_queries, - ) - else: - self.vision_encoder, self.vision_queries, self.vision_attn_pool = self._init_modality( - VisionTransformer, - modality_encoder_config.vision_transformer_config, + vision_encoder_config, n_vision_queries, ) + + self.audio_sample_key = None + if audio_encoder_config is not None: + self.audio_sample_key = audio_encoder_config.sample_key self.audio_encoder, self.audio_queries, self.audio_attn_pool = self._init_modality( AudioTransformer, - modality_encoder_config.audio_transformer_config, + audio_encoder_config, n_audio_queries, ) @@ -257,18 +250,32 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Output dictionary. 
""" - if inputs[self.modality_key][0] == self.AUDIO: - modality_embd, modality_cls_token = self._forward_encode_audio(inputs) - if inputs[self.modality_key][0] == self.VISION: - modality_embd, modality_cls_token = self._forward_encode_vision(inputs) + output = {} + # TODO stack features from different modalities (ensure correct alignment with the text features) + modality_embd = None + if self.audio_sample_key is not None and self.audio_sample_key in inputs: + audio_embd, audio_cls_token = self._forward_encode_audio(inputs) + output[self.audio_cls_prediction_key] = audio_cls_token + modality_embd = audio_embd + + elif self.vision_sample_key is not None and self.vision_sample_key in inputs: + vision_embd, vision_cls_token = self._forward_encode_vision(inputs) + output[self.vision_cls_prediction_key] = audio_cls_token + modality_embd = vision_embd + + else: + raise NotImplementedError("Parallel vision audio in the same batch is currently not supported!") + text_embd, text_cls_token = self._forward_encode_text(inputs) logits = self._forward_decode(text_embd, modality_embd) - return { - self.prediction_key: logits, - self.modality_cls_prediction_key: modality_cls_token, - self.text_cls_prediction_key: text_cls_token, - self.logit_scale_prediction_key: self.logit_scale.exp(), - } + output.update( + { + self.prediction_key: logits, + self.text_cls_prediction_key: text_cls_token, + self.logit_scale_prediction_key: self.logit_scale.exp(), + } + ) + return output def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -287,7 +294,7 @@ def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch return vision_embd, vision_cls_token def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - audio_embd = self.audio_encoder(inputs)[self.modality_embd_prediction_key] + audio_embd = self.audio_encoder(inputs)[self.audio_embd_prediction_key] queries = repeat(self.audio_queries, "n d -> b n d", b=audio_embd.shape[0]) audio_embd = self.audio_attn_pool(queries, context=audio_embd) audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1, :], dim=-1) From 420041b9ea6e87e3c6a23400db4648f333174e3f Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Thu, 20 Jun 2024 15:58:23 +0200 Subject: [PATCH 099/161] fix: coca n_query parameter --- src/modalities/models/coca/coca_model.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index bc066e131..00badd6fc 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -87,8 +87,7 @@ class CoCaConfig(BaseModel): vision_encoder_config: Optional[VisionTransformerConfig] = None text_decoder_config: TextDecoderConfig n_pool_head: Annotated[int, Field(ge=1)] - n_vision_queries: Optional[Annotated[int, Field(ge=1)]] - n_audio_queries: Optional[Annotated[int, Field(ge=1)]] + n_queries: Optional[Annotated[int, Field(ge=1)]] bias_attn_pool: bool epsilon_attn_pool: Annotated[float, Field(ge=0.0)] @@ -118,8 +117,7 @@ def __init__( vision_encoder_config: Optional[VisionTransformerConfig], text_decoder_config: TextDecoderConfig, n_pool_head: int, - n_vision_queries: Optional[int], - n_audio_queries: Optional[int], + n_queries: Optional[int], bias_attn_pool: bool, epsilon_attn_pool: float, modality_encoder_config: VisionTransformerConfig | AudioTransformerConfig | AVConfig, @@ -168,7 
+166,7 @@ def __init__( self.vision_encoder, self.vision_queries, self.vision_attn_pool = self._init_modality( VisionTransformer, vision_encoder_config, - n_vision_queries, + n_queries, ) self.audio_sample_key = None @@ -177,7 +175,7 @@ def __init__( self.audio_encoder, self.audio_queries, self.audio_attn_pool = self._init_modality( AudioTransformer, audio_encoder_config, - n_audio_queries, + n_queries, ) self.text_decoder = TextDecoder( From 9236a4664c49fa1f97967a8c3b69016989b1243e Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Tue, 2 Jul 2024 11:01:58 +0200 Subject: [PATCH 100/161] fix: copy paste error --- src/modalities/models/coca/coca_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 00badd6fc..ee92e6d6e 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -258,7 +258,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: elif self.vision_sample_key is not None and self.vision_sample_key in inputs: vision_embd, vision_cls_token = self._forward_encode_vision(inputs) - output[self.vision_cls_prediction_key] = audio_cls_token + output[self.vision_cls_prediction_key] = vision_cls_token modality_embd = vision_embd else: From edbc1ac8c6c90d2dae46afd89283c498611e88db Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Thu, 1 Aug 2024 11:57:07 +0000 Subject: [PATCH 101/161] feat: allow for training all modalities --- .../config_example_coca_webdataset.yaml | 34 +++++++++- src/modalities/dataloader/dataset.py | 1 - src/modalities/loss_functions.py | 66 +++++++++++-------- src/modalities/models/coca/coca_model.py | 22 ++++++- .../models/coca/multi_modal_decoder.py | 20 +++++- 5 files changed, 106 insertions(+), 37 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index 0817c160e..5eb177e04 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -47,6 +47,13 @@ train_image_transform: is_training: True input_size: 256 +train_image_transform: + component_key: transform + variant_key: image_transform + config: + is_training: True + input_size: 288 + text_transform: component_key: transform variant_key: text_transform @@ -70,13 +77,32 @@ train_coco_dataset_builder: TEXT: instance_key: text_transform pass_type: BY_REFERENCE - num_samples: 566_748 + num_samples: 30000 + +train_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: /p/scratch/jureap63/multimodal_data/audio/commonvoice_17_test_wav_000000.tar + modality_key_mapping: + TEXT: ["transcript.txt", "input_ids"] # source and target keys + AUDIO: ["wav", "feats"] + modality_transforms: + AUDIO: + instance_key: train_audio_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10000 train_dataset: component_key: dataset variant_key: web_dataset config: builders: + - instance_key: train_dataset_builder + pass_type: BY_REFERENCE - instance_key: train_coco_dataset_builder pass_type: BY_REFERENCE shardshuffle: 100 @@ -162,8 +188,10 @@ contrastive_loss: component_key: loss variant_key: clip_loss config: - prediction_key1: ${model.config.vision_cls_prediction_key} - prediction_key2: ${model.config.text_cls_prediction_key} + prediction_keys: + - ${model.config.audio_cls_prediction_key} + - 
${model.config.vision_cls_prediction_key} + - ${model.config.text_cls_prediction_key} logit_scale_key: ${model.config.logit_scale_prediction_key} tag: contrastive_loss weight: 1.0 diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index a01982edd..44dd3b724 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -775,7 +775,6 @@ def __init__( """ super().__init__() self.builders = builders - assert len(builders) == 1, "Multiple dataset builders are not supported yet" # TODO self.output_keys_by_modality = {} for b in builders: diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index a5e6a8c0b..c1478181e 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -154,8 +154,7 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: class ClipLossConfig(BaseModel): logit_scale_key: str - prediction_key1: str - prediction_key2: str + prediction_keys: list[str] weight: float = 1 local_loss: bool = True tag: str = "ClipLoss" @@ -165,8 +164,7 @@ class ClipLoss(Loss): def __init__( self, logit_scale_key: str, - prediction_key1: str, - prediction_key2: str, + prediction_keys: list[str], weight: float, local_loss: bool, tag: str = "ClipLoss", @@ -176,16 +174,17 @@ def __init__( Args: logit_scale_key (str): Value of a learnable logit scale parameter. - prediction_key1 (str): Key to access embedding 1. - prediction_key2 (str): Key to access embedding 2. + prediction_keys (list[str]): Keys to access embeddings. tag (str, optional): Defaults to "ClipLoss". """ super().__init__(tag, weight) self.logit_scale_key = logit_scale_key - self.prediction_key1 = prediction_key1 - self.prediction_key2 = prediction_key2 + self.prediction_keys = prediction_keys self.local_loss = local_loss + if not (2 <= len(prediction_keys) <= 3): + raise ValueError("ClipLoss requires either 2 or 3 prediction keys.") + def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: """ Args: @@ -195,44 +194,53 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: torch.Tensor: loss tensor. 
""" logit_scale = forward_batch.get_predictions(self.logit_scale_key) - embedding1 = forward_batch.get_predictions(self.prediction_key1).contiguous() - embedding2 = forward_batch.get_predictions(self.prediction_key2).contiguous() - device = embedding1.device + + embeddings = [forward_batch.get_predictions(key).contiguous() for key in self.prediction_keys] + device = embeddings[0].device # Gather all embeddings from each rank world_size = dist.get_world_size() rank = dist.get_rank() - gathered_embedding1 = [torch.zeros_like(embedding1) for _ in range(world_size)] - gathered_embedding2 = [torch.zeros_like(embedding2) for _ in range(world_size)] - dist.all_gather(gathered_embedding1, embedding1) - dist.all_gather(gathered_embedding2, embedding2) + + gathered_embeddings = [torch.zeros_like(embedding) for embedding in embeddings for _ in range(world_size)] + + for gathered_embedding, embedding in zip(gathered_embeddings, embeddings): + dist.all_gather(gathered_embedding, embedding) # Make sure we have gradients for the "local" embeddings if not self.local_loss: - gathered_embedding1[rank] = embedding1 - gathered_embedding2[rank] = embedding2 + for gathered_embedding, embedding in zip(gathered_embeddings, embeddings): + gathered_embedding[rank] = embedding # Combine embeddings - gathered_embedding1 = torch.cat(gathered_embedding1, dim=0) - gathered_embedding2 = torch.cat(gathered_embedding2, dim=0) + gathered_embeddings = [torch.cat(gathered_embedding, dim=0) for gathered_embedding in gathered_embeddings] # Calculate logits - if self.local_loss: - logits_per_embedding1 = logit_scale * embedding1 @ gathered_embedding2.T - logits_per_embedding2 = logit_scale * embedding2 @ gathered_embedding1.T - else: - logits_per_embedding1 = logit_scale * gathered_embedding1 @ gathered_embedding2.T - logits_per_embedding2 = logits_per_embedding1.T + logits_per_embeddings = [] + for i, embedding in enumerate(embeddings): + for j, gathered_embedding in enumerate(gathered_embeddings): + if i != j: + if self.local_loss: + logits = logit_scale * embedding @ gathered_embedding.T + else: + logits = logit_scale * gathered_embeddings[i] @ gathered_embeddings[j].T + logits_per_embeddings.append(logits) # Build gt labels for diagonal - num_logits = logits_per_embedding1.shape[0] + num_logits = logits_per_embeddings[0].shape[0] labels = torch.arange(num_logits, device=device, dtype=torch.long) if world_size > 1 and self.local_loss: labels = labels + num_logits * rank + ## MODIFIED # Calculate loss - clip_loss = ( - F.cross_entropy(logits_per_embedding1, labels) + F.cross_entropy(logits_per_embedding2, labels) - ) / 2 + losses = None + for logits in logits_per_embeddings: + if losses is None: + losses = F.cross_entropy(logits, labels) + else: + losses += F.cross_entropy(logits, labels) + + clip_loss = losses.mean() return clip_loss diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index ee92e6d6e..8e14e2434 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -261,6 +261,13 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: output[self.vision_cls_prediction_key] = vision_cls_token modality_embd = vision_embd + ## MODIFIED + elif self.audio_sample_key and self.vision_sample_key in inputs: + audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_vision(inputs) + output[self.audio_cls_prediction_key] = audio_cls_token + output[self.vision_cls_prediction_key] = 
vision_cls_token + modality_embd = [audio_embd, vision_embd] + else: raise NotImplementedError("Parallel vision audio in the same batch is currently not supported!") @@ -298,6 +305,16 @@ def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch. audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1, :], dim=-1) return audio_embd, audio_cls_token + ## MODIFIED + def _forward_encode_audio_vision( + self, inputs: Dict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + audio_inputs, vision_inputs = inputs + audio_embd, audio_cls_token = self._forward_encode_audio(audio_inputs) + vision_embd, vision_cls_token = self._forward_encode_vision(vision_inputs) + + return audio_embd, audio_cls_token, vision_embd, vision_cls_token + def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """ Encodes the input text using the text decoder. @@ -313,7 +330,10 @@ def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.T text_embd, text_cls_token = text_embd[:, :-1, :], F.normalize(text_embd[:, -1, :], dim=-1) return text_embd, text_cls_token - def _forward_decode(self, text_embd: torch.Tensor, modality_embd: torch.Tensor) -> torch.Tensor: + ## MODIFIED + def _forward_decode( + self, text_embd: torch.Tensor, modality_embd: list[torch.Tensor] | torch.Tensor + ) -> torch.Tensor: """ Perform forward decoding using the given text and vision embeddings. diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index 09997c2f3..e8042fc99 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -3,6 +3,7 @@ import torch from torch import nn +from transformers import PreTrainedTokenizer from modalities.models.gpt2.gpt2_model import ActivationType from modalities.models.model import NNModel, SwiGLU @@ -76,13 +77,22 @@ def __init__( self.ln_4 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) self.mlp_2 = mlp() - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + if isinstance(self.with_context, list): + self.cross_attn2 = MultiHeadAttention( + n_embd=n_embd, + n_head=n_head, + bias=bias, + attention_config=attention_config, + attention_type=AttentionType.CROSS_ATTENTION, + ) + + def forward(self, x: torch.Tensor, context: list[torch.Tensor] | torch.Tensor | None = None) -> torch.Tensor: """ Forward pass of the TransformerBlock module. Args: x (torch.Tensor): Input tensor. - context (torch.Tensor, optional): Context tensor. Defaults to None. + context (list[torch.Tensor] | torch.Tensor, optional): Context tensor. Defaults to None. Returns: torch.Tensor: Output tensor. 
@@ -90,7 +100,11 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor x = x + self.attn(self.ln_1(x)) if not self.with_context or self.add_extra_mlp: x = x + self.mlp(self.ln_2(x)) - if self.with_context: + if self.with_context and isinstance(self.with_context, List): + x = self.ln_3(x) + x = x + self.cross_attn(x, context=context[0]) + self.cross_attn2(x, context=context[1]) + x = x + self.mlp_2(self.ln_4(x)) + else: x = x + self.cross_attn(self.ln_3(x), context=context) x = x + self.mlp_2(self.ln_4(x)) return x From 845f961e1581a98f1772881e79ec64b74d6385f7 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Tue, 6 Aug 2024 10:09:53 +0000 Subject: [PATCH 102/161] fix: revert back to multiple builders assertion --- src/modalities/dataloader/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 44dd3b724..a01982edd 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -775,6 +775,7 @@ def __init__( """ super().__init__() self.builders = builders + assert len(builders) == 1, "Multiple dataset builders are not supported yet" # TODO self.output_keys_by_modality = {} for b in builders: From 35e68128c3459100608ee8b1d688f7764bccff9f Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Tue, 6 Aug 2024 10:17:57 +0000 Subject: [PATCH 103/161] refactor: copy decord from github exp/vision_languauge_coca branch --- src/modalities/dataloader/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index a01982edd..b6ee998af 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -583,7 +583,7 @@ def decord_video(key, data): file_obj = io.BytesIO(data) # we could replace this with torchaudio.load(data) - ar = decord.AudioReader(file_obj, mono=False) + ar = decord.AudioReader(file_obj, sample_rate=16000, mono=True) audio = ar[:] # reset to start of file From f6f7023a0d9757a249c2b1fd54fd86eb64995b75 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Tue, 6 Aug 2024 10:45:43 +0000 Subject: [PATCH 104/161] feat: add video-audio-text sample generation --- src/modalities/dataloader/dataset.py | 12 +++++------- .../audio_transformer/audio_transformer_model.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index b6ee998af..00ba3f9da 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -514,7 +514,7 @@ def __call__(self, raw_audio: tuple[torch.Tensor, int]) -> tuple[torch.Tensor, i torchaudio.transforms.TimeMasking(time_mask_param=self.time_domain_mask_length), ) - log_mel_spec = torch.clamp(self.extract_features(raw_audio[0]), 1e-10).log10().squeeze(0) + log_mel_spec = torch.clamp(self.extract_features(raw_audio[1]), 1e-10).log10().squeeze(0) log_mel_spec = self.masking(log_mel_spec) if self.is_training else log_mel_spec feats_len = log_mel_spec.shape[-1] // SUB_SAMPLING_FACTOR @@ -628,16 +628,15 @@ def __init__( ModalityEnum.TEXT: None, ModalityEnum.IMAGE: "pil", ModalityEnum.VIDEO: decord_video, - ModalityEnum.AUDIO: wds.torch_audio, + ModalityEnum.AUDIO: decord_video, } self.additional_extreacted_keys = [] - self.additional_extreacted_keys.append("modality") if ModalityEnum.TEXT in self.modality_transforms: self.additional_extreacted_keys.append("attention_mask") if ModalityEnum.AUDIO 
in self.modality_transforms: - self.additional_extreacted_keys.append("feats_len") + self.additional_extreacted_keys.append("audio_len") # Mapping between modality and transform self.modality_to_transform_fn = { @@ -701,15 +700,14 @@ def _transform_video(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.VIDEO] transform: VideoTransform = self.modality_transforms[ModalityEnum.VIDEO] sample[target_key] = transform(sample[source_key]) - del sample[source_key] + # del sample[source_key] return sample def _transform_audio(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.AUDIO] transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] - sample[target_key], sample["feats_len"] = transform(sample[source_key]) + sample[target_key], sample["audio_len"] = transform(sample[source_key]) del sample[source_key] - sample["modality"] = [0] return sample def _flatten_sample(self, sample): diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index eec4dbc6e..2244e4f55 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -197,7 +197,7 @@ def forward( inputs: Dict[str, tuple[torch.Tensor, torch.Tensor]], ) -> Dict[str, tuple[torch.Tensor, torch.Tensor]]: x = inputs[self.sample_key] # x.shape: B, T, D - attn_key_mask = self._get_attn_key_mask(inputs["feats_len"]) + attn_key_mask = self._get_attn_key_mask(inputs["audio_len"]) # x.shape: B, T, D x = self.project(x.transpose(1, 2)) # x.shape: B, D, T x = self.subsampler(x) # x.shape: B, D, T/4 From 21c9ee478c822d7aaf9e40766a2ed3f92ca4d6bd Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Tue, 6 Aug 2024 10:46:11 +0000 Subject: [PATCH 105/161] fix: video-audio-text forward pass --- src/modalities/models/coca/coca_model.py | 18 +++++----- .../models/coca/multi_modal_decoder.py | 35 +++++++++++-------- src/modalities/models/coca/text_decoder.py | 2 ++ 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 8e14e2434..5c3b004a7 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -160,6 +160,8 @@ def __init__( self.epsilon_attn_pool = epsilon_attn_pool self.text_decoder_config = text_decoder_config + num_input_modalities = 0 + self.vision_sample_key = None if vision_encoder_config is not None: self.vision_sample_key = vision_encoder_config.sample_key @@ -168,6 +170,7 @@ def __init__( vision_encoder_config, n_queries, ) + num_input_modalities += 1 self.audio_sample_key = None if audio_encoder_config is not None: @@ -177,6 +180,7 @@ def __init__( audio_encoder_config, n_queries, ) + num_input_modalities += 1 self.text_decoder = TextDecoder( sample_key=text_decoder_config.sample_key, @@ -202,6 +206,7 @@ def __init__( n_head=text_decoder_config.n_head, n_embd=text_decoder_config.n_embd, ffn_hidden=text_decoder_config.ffn_hidden, + is_two_input_modalities=num_input_modalities == 2, dropout=text_decoder_config.dropout, bias=text_decoder_config.bias, attention_config=text_decoder_config.attention_config, @@ -251,26 +256,22 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: output = {} # TODO stack features from different modalities (ensure correct alignment with the text features) modality_embd = None - if self.audio_sample_key is 
not None and self.audio_sample_key in inputs: + if self.audio_sample_key is not None and self.vision_sample_key is None: audio_embd, audio_cls_token = self._forward_encode_audio(inputs) output[self.audio_cls_prediction_key] = audio_cls_token modality_embd = audio_embd - elif self.vision_sample_key is not None and self.vision_sample_key in inputs: + elif self.vision_sample_key is not None and self.audio_sample_key is None: vision_embd, vision_cls_token = self._forward_encode_vision(inputs) output[self.vision_cls_prediction_key] = vision_cls_token modality_embd = vision_embd - ## MODIFIED - elif self.audio_sample_key and self.vision_sample_key in inputs: + else: audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_vision(inputs) output[self.audio_cls_prediction_key] = audio_cls_token output[self.vision_cls_prediction_key] = vision_cls_token modality_embd = [audio_embd, vision_embd] - else: - raise NotImplementedError("Parallel vision audio in the same batch is currently not supported!") - text_embd, text_cls_token = self._forward_encode_text(inputs) logits = self._forward_decode(text_embd, modality_embd) output.update( @@ -309,7 +310,8 @@ def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch. def _forward_encode_audio_vision( self, inputs: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - audio_inputs, vision_inputs = inputs + audio_inputs = {k: inputs[k] for k in inputs if k in ["audio", "audio_len"]} + vision_inputs = {k: inputs[k] for k in inputs if k in ["video"]} audio_embd, audio_cls_token = self._forward_encode_audio(audio_inputs) vision_embd, vision_cls_token = self._forward_encode_vision(vision_inputs) diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index e8042fc99..ed1c87c09 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -24,6 +24,7 @@ def __init__( dropout: float, ffn_hidden: int, with_context: bool, + is_two_input_modalities: bool, attention_type: AttentionType, attention_config: AttentionConfig = None, add_extra_mlp: bool = False, @@ -47,6 +48,7 @@ def __init__( """ super().__init__() self.with_context = with_context + self.is_two_input_modalities = is_two_input_modalities self.add_extra_mlp = add_extra_mlp if activation == ActivationType.GELU: @@ -77,14 +79,14 @@ def __init__( self.ln_4 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) self.mlp_2 = mlp() - if isinstance(self.with_context, list): - self.cross_attn2 = MultiHeadAttention( - n_embd=n_embd, - n_head=n_head, - bias=bias, - attention_config=attention_config, - attention_type=AttentionType.CROSS_ATTENTION, - ) + if self.is_two_input_modalities: + self.cross_attn2 = MultiHeadAttention( + n_embd=n_embd, + n_head=n_head, + bias=bias, + attention_config=attention_config, + attention_type=AttentionType.CROSS_ATTENTION, + ) def forward(self, x: torch.Tensor, context: list[torch.Tensor] | torch.Tensor | None = None) -> torch.Tensor: """ @@ -100,13 +102,14 @@ def forward(self, x: torch.Tensor, context: list[torch.Tensor] | torch.Tensor | x = x + self.attn(self.ln_1(x)) if not self.with_context or self.add_extra_mlp: x = x + self.mlp(self.ln_2(x)) - if self.with_context and isinstance(self.with_context, List): - x = self.ln_3(x) - x = x + self.cross_attn(x, context=context[0]) + self.cross_attn2(x, context=context[1]) - x = x + self.mlp_2(self.ln_4(x)) - else: - x = x + 
self.cross_attn(self.ln_3(x), context=context) - x = x + self.mlp_2(self.ln_4(x)) + if self.with_context: + if isinstance(context, List): + x = self.ln_3(x) + x = x + self.cross_attn(x, context=context[0]) + self.cross_attn2(x, context=context[1]) + x = x + self.mlp_2(self.ln_4(x)) + else: + x = x + self.cross_attn(self.ln_3(x), context=context) + x = x + self.mlp_2(self.ln_4(x)) return x @@ -123,6 +126,7 @@ def __init__( n_head: int, n_embd: int, ffn_hidden: int, + is_two_input_modalities: bool, dropout: float, bias: bool, activation: ActivationType, @@ -168,6 +172,7 @@ def __init__( dropout=dropout, ffn_hidden=ffn_hidden, with_context=True, + is_two_input_modalities=is_two_input_modalities, attention_type=AttentionType.CAUSAL_SELF_ATTENTION, attention_config=attention_config, add_extra_mlp=False, diff --git a/src/modalities/models/coca/text_decoder.py b/src/modalities/models/coca/text_decoder.py index 4e69bed26..26e5a569f 100644 --- a/src/modalities/models/coca/text_decoder.py +++ b/src/modalities/models/coca/text_decoder.py @@ -2,6 +2,7 @@ import torch from torch import nn +from transformers import PreTrainedTokenizer from modalities.models.coca.multi_modal_decoder import TransformerBlock from modalities.models.gpt2.gpt2_model import ActivationType @@ -68,6 +69,7 @@ def __init__( dropout=dropout, ffn_hidden=ffn_hidden, with_context=False, + is_two_input_modalities=False, attention_type=AttentionType.CAUSAL_SELF_ATTENTION, attention_config=attention_config, ) From 9a7823165089c049bfcbf6d3bc585608d35ce560 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Tue, 6 Aug 2024 10:46:39 +0000 Subject: [PATCH 106/161] fix: gathered embeddings --- src/modalities/loss_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index c1478181e..36fc09d94 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -202,7 +202,7 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: world_size = dist.get_world_size() rank = dist.get_rank() - gathered_embeddings = [torch.zeros_like(embedding) for embedding in embeddings for _ in range(world_size)] + gathered_embeddings = [[torch.zeros_like(embedding) for embedding in embeddings] for _ in range(world_size)] for gathered_embedding, embedding in zip(gathered_embeddings, embeddings): dist.all_gather(gathered_embedding, embedding) From 44fa30659b66f240448656b66e5c88fff42cd9b5 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Wed, 7 Aug 2024 17:09:53 +0200 Subject: [PATCH 107/161] fix: use torchaudio to load audio from videos too --- src/modalities/dataloader/dataset.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 00ba3f9da..142c65b51 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -574,7 +574,7 @@ class MultimodalWebDatasetBuilderConfig(BaseModel): def decord_video(key, data): """Based on the torch_video decoder in webdataset - https://github.com/webdataset/webdataset/blob/5b12e0ba78bfb64741add2533c5d1e4cf088ffff/webdataset/autodecode.py#L394 + https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py#L394 """ extension = re.sub(r".*[.]", "", key) if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): @@ -597,6 +597,18 @@ def decord_video(key, data): return (frames, audio) +def torch_audio(key, data): + """Based on the 
torch_audio decoder in webdataset + https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py#L418 + """ + extension = re.sub(r".*[.]", "", key) + valid_extensions = "mp4 ogv mjpeg avi mov h264 mpg webm wmv flac mp3 sox wav m4a ogg wma".split() + if extension not in valid_extensions: + return None + + return torchaudio.load(data) + + # @register_component("dataset", "web_dataset_builder", MultimodalWebDatasetBuilderConfig) class MultimodalWebDatasetBuilder: def __init__( @@ -628,7 +640,7 @@ def __init__( ModalityEnum.TEXT: None, ModalityEnum.IMAGE: "pil", ModalityEnum.VIDEO: decord_video, - ModalityEnum.AUDIO: decord_video, + ModalityEnum.AUDIO: torch_audio, } self.additional_extreacted_keys = [] From a2112dacbffb70af829058f9ca3623238b6dbc89 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Fri, 9 Aug 2024 07:26:35 +0000 Subject: [PATCH 108/161] refactor: simplify if statement --- src/modalities/models/coca/coca_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 5c3b004a7..5f7348a86 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -256,12 +256,12 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: output = {} # TODO stack features from different modalities (ensure correct alignment with the text features) modality_embd = None - if self.audio_sample_key is not None and self.vision_sample_key is None: + if self.audio_sample_key and self.vision_sample_key is None: audio_embd, audio_cls_token = self._forward_encode_audio(inputs) output[self.audio_cls_prediction_key] = audio_cls_token modality_embd = audio_embd - elif self.vision_sample_key is not None and self.audio_sample_key is None: + elif self.vision_sample_key and self.audio_sample_key is None: vision_embd, vision_cls_token = self._forward_encode_vision(inputs) output[self.vision_cls_prediction_key] = vision_cls_token modality_embd = vision_embd From 8c12b8a88e13d84aad5e87341bdc79781004e356 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Fri, 9 Aug 2024 07:30:05 +0000 Subject: [PATCH 109/161] refactor: improve readability --- src/modalities/models/coca/coca_model.py | 2 +- src/modalities/models/coca/multi_modal_decoder.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 5f7348a86..e3bc30118 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -270,7 +270,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_vision(inputs) output[self.audio_cls_prediction_key] = audio_cls_token output[self.vision_cls_prediction_key] = vision_cls_token - modality_embd = [audio_embd, vision_embd] + modality_embd = {"audio": audio_embd, "video": vision_embd} text_embd, text_cls_token = self._forward_encode_text(inputs) logits = self._forward_decode(text_embd, modality_embd) diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index ed1c87c09..5c60595cb 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -103,9 +103,9 @@ def forward(self, x: torch.Tensor, context: list[torch.Tensor] | torch.Tensor | if not self.with_context or 
self.add_extra_mlp: x = x + self.mlp(self.ln_2(x)) if self.with_context: - if isinstance(context, List): + if isinstance(context, Dict): x = self.ln_3(x) - x = x + self.cross_attn(x, context=context[0]) + self.cross_attn2(x, context=context[1]) + x = x + self.cross_attn(x, context=context["audio"]) + self.cross_attn2(x, context=context["video"]) x = x + self.mlp_2(self.ln_4(x)) else: x = x + self.cross_attn(self.ln_3(x), context=context) From efae85577340751447c435c0243d4b688134d679 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Fri, 9 Aug 2024 07:32:16 +0000 Subject: [PATCH 110/161] feat: add first draft of single batch mixed modality --- config_files/example_separate_datasets.yaml | 362 ++++++++++++++++++++ src/modalities/models/coca/coca_model.py | 27 +- 2 files changed, 383 insertions(+), 6 deletions(-) create mode 100644 config_files/example_separate_datasets.yaml diff --git a/config_files/example_separate_datasets.yaml b/config_files/example_separate_datasets.yaml new file mode 100644 index 000000000..d199c4b94 --- /dev/null +++ b/config_files/example_separate_datasets.yaml @@ -0,0 +1,362 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + global_training_log_interval_in_steps: 30 # Needs to be a multiple of gradient_acc_steps + global_checkpointing_interval_in_steps: 9_990 + global_evaluation_interval_in_steps: 4_980 + global_num_training_samples: 100_000 # 491 steps with 8 gpus and global bs of 1152 + global_num_seen_steps: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 30 + local_train_micro_batch_size: 1 + sequence_length: 64 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +tokenizer: + component_key: tokenizer + variant_key: pretrained_hf_tokenizer + config: + pretrained_model_name_or_path: tokenizer/ + padding: true + max_length: ${settings.training.sequence_length} + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - video + - audio + - audio_len + - ${settings.referencing_keys.sample_key} + target_keys: [] + text_sample_key: ${settings.referencing_keys.sample_key} + text_target_key: ${settings.referencing_keys.target_key} + +train_audio_transform: + component_key: transform + variant_key: audio_transform + config: + is_training: True + block_size_audio_encoder: 300 + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + +val_audio_transform: + component_key: transform + variant_key: audio_transform + config: + is_training: false + block_size_audio_encoder: 300 + +train_video_transform: + component_key: transform + variant_key: video_transform + config: + is_training: True + input_size: 288 + num_frames: ${model.config.vision_encoder_config.num_video_frames} + +val_video_transform: + component_key: transform + variant_key: video_transform + config: + is_training: True + input_size: 288 + num_frames: ${model.config.vision_encoder_config.num_video_frames} + +text_transform: + component_key: transform + variant_key: text_transform + config: + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + +train_video_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "videodata/validation/000000.tar" + modality_key_mapping: + TEXT: ["json", "input_ids"] + VIDEO: ["mp4", "video"] + AUDIO: ["mp4","audio"] 
+ modality_transforms: + AUDIO: + instance_key: train_audio_transform + pass_type: BY_REFERENCE + VIDEO: + instance_key: train_video_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 100_000 + +val_video_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "videodata/validation/000000.tar" + modality_key_mapping: + TEXT: ["json", "input_ids"] + VIDEO: ["mp4", "video"] + AUDIO: ["mp4","audio"] + modality_transforms: + AUDIO: + instance_key: val_audio_transform + pass_type: BY_REFERENCE + VIDEO: + instance_key: val_video_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10_000 + +train_dataset: + component_key: dataset + variant_key: web_dataset + config: + builders: + - instance_key: train_video_builder + pass_type: BY_REFERENCE + shardshuffle: 100 + repeat: true + resample: true + shuffle_buffer: 10_000 + +val_dataset: + component_key: dataset + variant_key: web_dataset + config: + builders: + - instance_key: val_video_builder + pass_type: BY_REFERENCE + shardshuffle: 1000 + repeat: true + resample: true + shuffle_buffer: 10_000 + +train_dataloader: + component_key: data_loader + variant_key: web_loader + config: + num_workers: 2 + pin_memory: true + drop_last: true + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_size: ${settings.training.local_train_micro_batch_size} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: web_loader + config: + num_workers: 8 + pin_memory: true + drop_last: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_size: ${settings.training.local_train_micro_batch_size} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: fsdp + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +captioning_loss: + component_key: loss + variant_key: cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${model.config.prediction_key} + tag: captioning_loss + weight: 2.0 + +contrastive_loss: + component_key: loss + variant_key: clip_loss + config: + prediction_keys: + - ${model.config.individual_datasets_cls_prediction_key} + - ${model.config.text_cls_prediction_key} + logit_scale_key: ${model.config.logit_scale_prediction_key} + tag: contrastive_loss + +loss_fn: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss + pass_type: BY_REFERENCE + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: FP_16 + sharding_strategy: HYBRID_SHARD + block_names: [TransformerBlock, PerceiverTransformerBlock] + +model: + component_key: model + variant_key: coca + config: + prediction_key: 
logits + audio_embd_prediction_key: audio_embeddings + vision_embd_prediction_key: vision_embeddings + text_embd_prediction_key: text_embeddings + vision_cls_prediction_key: vision_cls + individual_datasets_cls_prediction_key: modalities_cls + audio_cls_prediction_key: audio_cls + text_cls_prediction_key: text_cls + logit_scale_prediction_key: logit_scale + audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 300 + n_mels: 128 + n_embd: 768 + n_heads: 12 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 + vision_encoder_config: + sample_key: video + prediction_key: vision_embeddings + img_size: 288 # 288 in the original coca + n_classes: Null # Disable vision transformer head + n_layer: 3 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 0.0 + patch_size: 18 # 18 in the original coca + patch_stride: 18 # 18 in the original coca + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 + text_decoder_config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${model.config.prediction_key} + block_size: 77 + vocab_size: 49_408 # 64k in the original coca + n_layer_text: 2 + n_layer_multimodal_text: 1 + attention_config: + attention_engine_type: default_attention + n_head: 12 + ffn_hidden: 3072 + n_embd: 768 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 + n_pool_head: 12 + n_queries: 256 + bias_attn_pool: False + epsilon_attn_pool: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 + +scheduler: + component_key: scheduler + variant_key: cosine_annealing_with_warmup_lr # COCA uses linear decay + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + num_warmup_steps: 2_000 + num_training_steps: 500_000 + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 8e-4 + betas: [0.9, 0.999] + eps: 1e-8 + weight_decay: 0.01 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp + config: + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.global_rank} + global_num_seen_steps: ${settings.training.global_num_seen_steps} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.global_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." 
+ config_file_path: ${settings.config_file_path} diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index e3bc30118..f355859c7 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -83,6 +83,7 @@ class CoCaConfig(BaseModel): vision_embd_prediction_key: Optional[str] = None audio_cls_prediction_key: Optional[str] = None vision_cls_prediction_key: Optional[str] = None + individual_datasets_cls_prediction_key: Optional[str] = None audio_encoder_config: Optional[AudioTransformerConfig] = None vision_encoder_config: Optional[VisionTransformerConfig] = None text_decoder_config: TextDecoderConfig @@ -113,6 +114,7 @@ def __init__( vision_embd_prediction_key: Optional[str], audio_cls_prediction_key: Optional[str], vision_cls_prediction_key: Optional[str], + individual_datasets_cls_prediction_key: Optional[str], audio_encoder_config: Optional[AudioTransformerConfig], vision_encoder_config: Optional[VisionTransformerConfig], text_decoder_config: TextDecoderConfig, @@ -154,6 +156,7 @@ def __init__( self.vision_embd_prediction_key = vision_embd_prediction_key self.audio_cls_prediction_key = audio_cls_prediction_key self.vision_cls_prediction_key = vision_cls_prediction_key + self.individual_datasets_cls_prediction_key = individual_datasets_cls_prediction_key self.n_pool_head = n_pool_head self.bias_attn_pool = bias_attn_pool @@ -254,7 +257,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: dict[str, torch.Tensor]: Output dictionary. """ output = {} - # TODO stack features from different modalities (ensure correct alignment with the text features) + modality_embd = None if self.audio_sample_key and self.vision_sample_key is None: audio_embd, audio_cls_token = self._forward_encode_audio(inputs) @@ -267,13 +270,25 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: modality_embd = vision_embd else: - audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_vision(inputs) - output[self.audio_cls_prediction_key] = audio_cls_token - output[self.vision_cls_prediction_key] = vision_cls_token - modality_embd = {"audio": audio_embd, "video": vision_embd} + if self.individual_datasets_cls_prediction_key: # audio / vision / text BUT separate datasets + audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_vision(inputs) + output[self.individual_datasets_cls_prediction_key] = torch.cat([vision_cls_token, audio_cls_token]) + modality_embd = {"audio": audio_embd, "image": vision_embd} + else: # audio + vision from one single dataset + audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_vision(inputs) + output[self.audio_cls_prediction_key] = audio_cls_token + output[self.vision_cls_prediction_key] = vision_cls_token + modality_embd = {"audio": audio_embd, "video": vision_embd} text_embd, text_cls_token = self._forward_encode_text(inputs) - logits = self._forward_decode(text_embd, modality_embd) + if self.vision_sample_key and self.audio_sample_key and self.individual_datasets_cls_prediction_key: + image_text_embd, audio_text_embd = text_embd[: len(vision_embd), :], text_embd[len(vision_embd) :, :] + image_logits = self._forward_decode(image_text_embd, modality_embd["image"]) + audio_logits = self._forward_decode(audio_text_embd, modality_embd["audio"]) + logits = torch.cat([image_logits, audio_logits]) + else: + logits = self._forward_decode(text_embd, 
modality_embd) + output.update( { self.prediction_key: logits, From c1145062e1b14802f5bfa1dd9314b5bd9d1b0b5a Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Fri, 9 Aug 2024 11:46:01 +0200 Subject: [PATCH 111/161] fix: load audio from video file only if it exists, and behave the same way as audio-only datasets --- config_files/example_separate_datasets.yaml | 2 - src/modalities/dataloader/dataset.py | 43 ++++++++++++++------- src/modalities/models/coca/collator.py | 2 +- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/config_files/example_separate_datasets.yaml b/config_files/example_separate_datasets.yaml index d199c4b94..8601e0290 100644 --- a/config_files/example_separate_datasets.yaml +++ b/config_files/example_separate_datasets.yaml @@ -90,7 +90,6 @@ train_video_builder: modality_key_mapping: TEXT: ["json", "input_ids"] VIDEO: ["mp4", "video"] - AUDIO: ["mp4","audio"] modality_transforms: AUDIO: instance_key: train_audio_transform @@ -111,7 +110,6 @@ val_video_builder: modality_key_mapping: TEXT: ["json", "input_ids"] VIDEO: ["mp4", "video"] - AUDIO: ["mp4","audio"] modality_transforms: AUDIO: instance_key: val_audio_transform diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 142c65b51..edb79c0ad 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -514,7 +514,7 @@ def __call__(self, raw_audio: tuple[torch.Tensor, int]) -> tuple[torch.Tensor, i torchaudio.transforms.TimeMasking(time_mask_param=self.time_domain_mask_length), ) - log_mel_spec = torch.clamp(self.extract_features(raw_audio[1]), 1e-10).log10().squeeze(0) + log_mel_spec = torch.clamp(self.extract_features(raw_audio[0]), 1e-10).log10().squeeze(0) log_mel_spec = self.masking(log_mel_spec) if self.is_training else log_mel_spec feats_len = log_mel_spec.shape[-1] // SUB_SAMPLING_FACTOR @@ -580,21 +580,24 @@ def decord_video(key, data): if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): return None - file_obj = io.BytesIO(data) - - # we could replace this with torchaudio.load(data) - ar = decord.AudioReader(file_obj, sample_rate=16000, mono=True) - audio = ar[:] + audio = None + audio_sample_rate = -1 + stream = torchaudio.io.StreamReader(data) + for idx in range(stream.num_src_streams): + if stream.get_src_stream_info(idx).media_type == "audio": + audio, audio_sample_rate = torchaudio.load(data) + if audio.shape[0] > 1: # more than one audio channel + audio = torch.mean(audio, dim=0) + break - # reset to start of file - file_obj.seek(0) + file_obj = io.BytesIO(data) vr = decord.VideoReader(file_obj) clip_num_frames = 64 # sample clip_num_frames uniformly from the full video frame_ids = torch.linspace(0, len(vr) - 1, clip_num_frames, dtype=torch.int64) frames = vr.get_batch(frame_ids.tolist()) # T x H x W x C - return (frames, audio) + return (frames, audio, audio_sample_rate) # audio can be None if no audio stream exists def torch_audio(key, data): @@ -606,7 +609,11 @@ def torch_audio(key, data): if extension not in valid_extensions: return None - return torchaudio.load(data) + # torchaudio.load returns (torch.Tensor, int) + audio, sample_rate = torchaudio.load(data) + if audio.shape[0] > 1: # more than one channel + audio = torch.mean(audio, dim=0) + return (audio, sample_rate) # @register_component("dataset", "web_dataset_builder", MultimodalWebDatasetBuilderConfig) @@ -630,7 +637,10 @@ def __init__( self.urls = urls self.modality_key_mapping = modality_key_mapping self.modality_transforms = 
modality_transforms - assert self.modality_key_mapping.keys() == self.modality_transforms.keys() + # transforms should be specified for all modality_key mappings, + # but we can also specify more transforms than necessary + # so modality_key_mappings should be a subset of modality_transforms + assert set(self.modality_key_mapping.keys()).issubset(self.modality_transforms.keys()) self.modalities = list(self.modality_key_mapping.keys()) self.num_samples = num_samples self.web_dataset = None @@ -647,7 +657,7 @@ def __init__( if ModalityEnum.TEXT in self.modality_transforms: self.additional_extreacted_keys.append("attention_mask") - if ModalityEnum.AUDIO in self.modality_transforms: + if ModalityEnum.AUDIO in self.modality_transforms or ModalityEnum.VIDEO in self.modality_transforms: self.additional_extreacted_keys.append("audio_len") # Mapping between modality and transform @@ -712,7 +722,13 @@ def _transform_video(self, sample): source_key, target_key = self.modality_key_mapping[ModalityEnum.VIDEO] transform: VideoTransform = self.modality_transforms[ModalityEnum.VIDEO] sample[target_key] = transform(sample[source_key]) - # del sample[source_key] + # if the video contains audio + if sample[source_key][1] is not None and ModalityEnum.AUDIO in self.modality_transforms: + transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] + sample["audio"], sample["audio_len"] = transform((sample[source_key][1], sample[source_key][2])) + if "audio" not in self.additional_extreacted_keys: + self.additional_extreacted_keys.append("audio") + del sample[source_key] return sample def _transform_audio(self, sample): @@ -785,7 +801,6 @@ def __init__( """ super().__init__() self.builders = builders - assert len(builders) == 1, "Multiple dataset builders are not supported yet" # TODO self.output_keys_by_modality = {} for b in builders: diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index 4dd11e10c..9104df605 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -73,7 +73,7 @@ def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: None. 
""" samples = { - sample_key: torch.stack([self._prepare_sample(d[sample_key]) for d in batch]) + sample_key: torch.stack([self._prepare_sample(d[sample_key]) for d in batch if sample_key in d]) for sample_key in self.sample_keys } if "attention_mask" in batch[0]: From 414031a11db093c3e81ea2f04f0b5ee4e58da2a6 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 12 Aug 2024 14:20:29 +0200 Subject: [PATCH 112/161] fix: keep original dimension for audio when averaging channels --- src/modalities/dataloader/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index edb79c0ad..b836291bf 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -587,7 +587,7 @@ def decord_video(key, data): if stream.get_src_stream_info(idx).media_type == "audio": audio, audio_sample_rate = torchaudio.load(data) if audio.shape[0] > 1: # more than one audio channel - audio = torch.mean(audio, dim=0) + audio = torch.mean(audio, dim=0, keepdim=True) break file_obj = io.BytesIO(data) @@ -612,7 +612,7 @@ def torch_audio(key, data): # torchaudio.load returns (torch.Tensor, int) audio, sample_rate = torchaudio.load(data) if audio.shape[0] > 1: # more than one channel - audio = torch.mean(audio, dim=0) + audio = torch.mean(audio, dim=0, keepdim=True) return (audio, sample_rate) From 7770b1c678b8a75db609cf53dca023104f9cb863 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Thu, 15 Aug 2024 13:29:42 +0200 Subject: [PATCH 113/161] fix: group input_ids based on modality; the order is determined by the order of modalities in the config file. e.g.: samples['input_ids'].shape = (10, max_length) images = samples['input_ids'][:4], audio = samples['input_ids'][4:] samples['images'].shape = (4, C, W, H), samples['audio'].shape = (6, L, N) samples['attention_mask'] = (10, max_length) --- src/modalities/models/coca/collator.py | 36 +++++++++++++++++++++----- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index 9104df605..4ab88a3eb 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -72,13 +72,35 @@ def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: Raises: None. """ - samples = { - sample_key: torch.stack([self._prepare_sample(d[sample_key]) for d in batch if sample_key in d]) - for sample_key in self.sample_keys - } - if "attention_mask" in batch[0]: - samples["attention_mask"] = torch.stack([self._prepare_sample(d["attention_mask"]) for d in batch]) - + # only keys related to the other modalities (e.g. 
images, audio, video) + modality_keys = [key for key in self.sample_keys if key not in ["audio_len", self.text_sample_key]] + + samples = {sample_key: [] for sample_key in self.sample_keys if sample_key != self.text_sample_key} + text_samples = {sample_key: [] for sample_key in modality_keys} + attention_masks = {sample_key: [] for sample_key in modality_keys} + # gather samples by modality + for sample in batch: + for sample_key in self.sample_keys: + if sample_key in sample: + if sample_key in samples: + samples[sample_key].append(self._prepare_sample(sample[sample_key])) + if sample_key in text_samples: + text_samples[sample_key].append(self._prepare_sample(sample[self.text_sample_key])) + if "attention_mask" in sample and sample_key in attention_masks: + attention_masks[sample_key].append(self._prepare_sample(sample["attention_mask"])) + # stack samples by modality + for sample_key in self.sample_keys: + if sample_key in samples: + samples[sample_key] = torch.stack(samples[sample_key]) + if sample_key in text_samples: + text_samples[sample_key] = torch.stack(text_samples[sample_key]) + if sample_key in attention_masks: + attention_masks[sample_key] = torch.stack(attention_masks[sample_key]) + # stack input_ids and attention masks for all modalities + samples[self.text_sample_key] = torch.cat([text_samples[sample_key] for sample_key in modality_keys]) + samples["attention_mask"] = torch.cat([attention_masks[sample_key] for sample_key in modality_keys]) + + ## TODO: this will not work when there is data from multiple datasets per batch targets = { target_key: torch.stack([self._prepare_sample(d[target_key]) for d in batch]) for target_key in self.target_keys From d025dd447f6530632bba1b5a9a75e65988a27344 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Fri, 16 Aug 2024 10:23:39 +0000 Subject: [PATCH 114/161] refactor: separate forward pass for audio-image and video-audio --- src/modalities/models/coca/coca_model.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index f355859c7..6fddd31fd 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -270,8 +270,8 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: modality_embd = vision_embd else: - if self.individual_datasets_cls_prediction_key: # audio / vision / text BUT separate datasets - audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_vision(inputs) + if self.individual_datasets: # audio / vision / text BUT separate datasets + audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_image(inputs) output[self.individual_datasets_cls_prediction_key] = torch.cat([vision_cls_token, audio_cls_token]) modality_embd = {"audio": audio_embd, "image": vision_embd} else: # audio + vision from one single dataset @@ -322,6 +322,16 @@ def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch. 
return audio_embd, audio_cls_token ## MODIFIED + def _forward_encode_audio_image( + self, inputs: Dict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + audio_inputs = {k: inputs[k] for k in inputs if k in ["audio", "audio_len"]} + vision_inputs = {k: inputs[k] for k in inputs if k in ["images"]} + audio_embd, audio_cls_token = self._forward_encode_audio(audio_inputs) + vision_embd, vision_cls_token = self._forward_encode_vision(vision_inputs) + + return audio_embd, audio_cls_token, vision_embd, vision_cls_token + def _forward_encode_audio_vision( self, inputs: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: From 8826758866ef1c2ede3b5ec2164fc61d5d62c7ca Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Fri, 16 Aug 2024 10:25:37 +0000 Subject: [PATCH 115/161] fix: audio-image forward o/p for contrastive loss --- src/modalities/models/coca/coca_model.py | 30 +++++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 6fddd31fd..cbcb78ed5 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -82,8 +82,10 @@ class CoCaConfig(BaseModel): audio_embd_prediction_key: Optional[str] = None vision_embd_prediction_key: Optional[str] = None audio_cls_prediction_key: Optional[str] = None + audio_text_cls_prediction_key: Optional[str] = None vision_cls_prediction_key: Optional[str] = None - individual_datasets_cls_prediction_key: Optional[str] = None + image_text_cls_prediction_key: Optional[str] = None + individual_datasets: Optional[bool] = None audio_encoder_config: Optional[AudioTransformerConfig] = None vision_encoder_config: Optional[VisionTransformerConfig] = None text_decoder_config: TextDecoderConfig @@ -113,8 +115,10 @@ def __init__( audio_embd_prediction_key: Optional[str], vision_embd_prediction_key: Optional[str], audio_cls_prediction_key: Optional[str], + audio_text_cls_prediction_key: Optional[str], vision_cls_prediction_key: Optional[str], - individual_datasets_cls_prediction_key: Optional[str], + image_text_cls_prediction_key: Optional[str], + individual_datasets: Optional[bool], audio_encoder_config: Optional[AudioTransformerConfig], vision_encoder_config: Optional[VisionTransformerConfig], text_decoder_config: TextDecoderConfig, @@ -155,8 +159,10 @@ def __init__( self.audio_embd_prediction_key = audio_embd_prediction_key self.vision_embd_prediction_key = vision_embd_prediction_key self.audio_cls_prediction_key = audio_cls_prediction_key + self.audio_text_cls_prediction_key = audio_text_cls_prediction_key self.vision_cls_prediction_key = vision_cls_prediction_key - self.individual_datasets_cls_prediction_key = individual_datasets_cls_prediction_key + self.image_text_cls_prediction_key = image_text_cls_prediction_key + self.individual_datasets = individual_datasets self.n_pool_head = n_pool_head self.bias_attn_pool = bias_attn_pool @@ -272,7 +278,8 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: else: if self.individual_datasets: # audio / vision / text BUT separate datasets audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_image(inputs) - output[self.individual_datasets_cls_prediction_key] = torch.cat([vision_cls_token, audio_cls_token]) + output[self.audio_cls_prediction_key] = audio_cls_token + output[self.vision_cls_prediction_key] = vision_cls_token modality_embd = 
{"audio": audio_embd, "image": vision_embd} else: # audio + vision from one single dataset audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_vision(inputs) @@ -281,8 +288,19 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: modality_embd = {"audio": audio_embd, "video": vision_embd} text_embd, text_cls_token = self._forward_encode_text(inputs) - if self.vision_sample_key and self.audio_sample_key and self.individual_datasets_cls_prediction_key: - image_text_embd, audio_text_embd = text_embd[: len(vision_embd), :], text_embd[len(vision_embd) :, :] + if self.vision_sample_key and self.audio_sample_key and self.individual_datasets: + image_text_cls_token, audio_text_cls_token = ( + text_cls_token[: len(vision_embd)], + text_cls_token[len(vision_embd) :], + ) + output.update( + { + self.image_text_cls_prediction_key: image_text_cls_token, + self.audio_text_cls_prediction_key: audio_text_cls_token, + } + ) + + image_text_embd, audio_text_embd = text_embd[: len(vision_embd)], text_embd[len(vision_embd) :] image_logits = self._forward_decode(image_text_embd, modality_embd["image"]) audio_logits = self._forward_decode(audio_text_embd, modality_embd["audio"]) logits = torch.cat([image_logits, audio_logits]) From 5435d4b31baa96d33a89f22abed8fd64b0734e03 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Fri, 16 Aug 2024 10:32:24 +0000 Subject: [PATCH 116/161] refactor: revert back to wds.torch_audio --- src/modalities/dataloader/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index b836291bf..94fd480cf 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -650,7 +650,7 @@ def __init__( ModalityEnum.TEXT: None, ModalityEnum.IMAGE: "pil", ModalityEnum.VIDEO: decord_video, - ModalityEnum.AUDIO: torch_audio, + ModalityEnum.AUDIO: wds.torch_audio, } self.additional_extreacted_keys = [] From 7c3c4bad85646f0a8a9285b3cfdd47a54dade823 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 19 Aug 2024 13:03:47 +0200 Subject: [PATCH 117/161] fix: only collect text samples once for case where a single dataset has video and audio in the coca collator --- src/modalities/models/coca/collator.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index 4ab88a3eb..42b4fcdde 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -80,14 +80,22 @@ def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: attention_masks = {sample_key: [] for sample_key in modality_keys} # gather samples by modality for sample in batch: + text_sample_added = False # make sure text is only added once per sample for sample_key in self.sample_keys: if sample_key in sample: if sample_key in samples: samples[sample_key].append(self._prepare_sample(sample[sample_key])) - if sample_key in text_samples: - text_samples[sample_key].append(self._prepare_sample(sample[self.text_sample_key])) - if "attention_mask" in sample and sample_key in attention_masks: + if "attention_mask" in sample and sample_key in attention_masks and not text_sample_added: attention_masks[sample_key].append(self._prepare_sample(sample["attention_mask"])) + if sample_key in text_samples and not text_sample_added: + 
text_samples[sample_key].append(self._prepare_sample(sample[self.text_sample_key])) + text_sample_added = True + # remove keys with no samples + for sample_key in modality_keys: + if len(text_samples[sample_key]) == 0: + del text_samples[sample_key] + if len(attention_masks[sample_key]) == 0: + del attention_masks[sample_key] # stack samples by modality for sample_key in self.sample_keys: if sample_key in samples: @@ -97,8 +105,8 @@ def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: if sample_key in attention_masks: attention_masks[sample_key] = torch.stack(attention_masks[sample_key]) # stack input_ids and attention masks for all modalities - samples[self.text_sample_key] = torch.cat([text_samples[sample_key] for sample_key in modality_keys]) - samples["attention_mask"] = torch.cat([attention_masks[sample_key] for sample_key in modality_keys]) + samples[self.text_sample_key] = torch.cat([text_samples[sample_key] for sample_key in text_samples]) + samples["attention_mask"] = torch.cat([attention_masks[sample_key] for sample_key in attention_masks]) ## TODO: this will not work when there is data from multiple datasets per batch targets = { From bc28749e813a8614d177597c699f38742ca8a8fb Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 19 Aug 2024 13:05:42 +0200 Subject: [PATCH 118/161] refactor: coca: separate vision into image and video, and refactor forward function to be simpler --- src/modalities/models/coca/coca_model.py | 173 +++++++++++++---------- 1 file changed, 95 insertions(+), 78 deletions(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index cbcb78ed5..5bce3a057 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -80,14 +80,18 @@ class CoCaConfig(BaseModel): text_cls_prediction_key: str logit_scale_prediction_key: str audio_embd_prediction_key: Optional[str] = None - vision_embd_prediction_key: Optional[str] = None + image_embd_prediction_key: Optional[str] = None + video_embd_prediction_key: Optional[str] = None audio_cls_prediction_key: Optional[str] = None audio_text_cls_prediction_key: Optional[str] = None - vision_cls_prediction_key: Optional[str] = None + image_cls_prediction_key: Optional[str] = None image_text_cls_prediction_key: Optional[str] = None + video_cls_prediction_key: Optional[str] = None + video_text_cls_prediction_key: Optional[str] = None individual_datasets: Optional[bool] = None audio_encoder_config: Optional[AudioTransformerConfig] = None - vision_encoder_config: Optional[VisionTransformerConfig] = None + image_encoder_config: Optional[VisionTransformerConfig] = None + video_encoder_config: Optional[VisionTransformerConfig] = None text_decoder_config: TextDecoderConfig n_pool_head: Annotated[int, Field(ge=1)] n_queries: Optional[Annotated[int, Field(ge=1)]] @@ -113,14 +117,18 @@ def __init__( text_cls_prediction_key: str, logit_scale_prediction_key: str, audio_embd_prediction_key: Optional[str], - vision_embd_prediction_key: Optional[str], + image_embd_prediction_key: Optional[str], + video_embd_prediction_key: Optional[str], audio_cls_prediction_key: Optional[str], audio_text_cls_prediction_key: Optional[str], - vision_cls_prediction_key: Optional[str], + image_cls_prediction_key: Optional[str], image_text_cls_prediction_key: Optional[str], + video_cls_prediction_key: Optional[str], + video_text_cls_prediction_key: Optional[str], individual_datasets: Optional[bool], audio_encoder_config: Optional[AudioTransformerConfig], 
- vision_encoder_config: Optional[VisionTransformerConfig], + image_encoder_config: Optional[VisionTransformerConfig], + video_encoder_config: Optional[VisionTransformerConfig], text_decoder_config: TextDecoderConfig, n_pool_head: int, n_queries: Optional[int], @@ -157,11 +165,14 @@ def __init__( self.text_cls_prediction_key = text_cls_prediction_key self.audio_embd_prediction_key = audio_embd_prediction_key - self.vision_embd_prediction_key = vision_embd_prediction_key + self.image_embd_prediction_key = image_embd_prediction_key + self.video_embd_prediction_key = video_embd_prediction_key self.audio_cls_prediction_key = audio_cls_prediction_key self.audio_text_cls_prediction_key = audio_text_cls_prediction_key - self.vision_cls_prediction_key = vision_cls_prediction_key + self.image_cls_prediction_key = image_cls_prediction_key self.image_text_cls_prediction_key = image_text_cls_prediction_key + self.video_cls_prediction_key = video_cls_prediction_key + self.video_text_cls_prediction_key = video_text_cls_prediction_key self.individual_datasets = individual_datasets self.n_pool_head = n_pool_head @@ -171,12 +182,22 @@ def __init__( num_input_modalities = 0 - self.vision_sample_key = None - if vision_encoder_config is not None: - self.vision_sample_key = vision_encoder_config.sample_key - self.vision_encoder, self.vision_queries, self.vision_attn_pool = self._init_modality( + self.image_sample_key = None + if image_encoder_config is not None: + self.image_sample_key = image_encoder_config.sample_key + self.image_encoder, self.image_queries, self.image_attn_pool = self._init_modality( VisionTransformer, - vision_encoder_config, + image_encoder_config, + n_queries, + ) + num_input_modalities += 1 + + self.video_sample_key = None + if video_encoder_config is not None: + self.video_sample_key = video_encoder_config.sample_key + self.video_encoder, self.video_queries, self.video_attn_pool = self._init_modality( + VisionTransformer, + video_encoder_config, n_queries, ) num_input_modalities += 1 @@ -264,59 +285,70 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ output = {} - modality_embd = None - if self.audio_sample_key and self.vision_sample_key is None: + # encode modalities + image_embd = audio_embd = video_embd = None + if self.image_sample_key: + image_embd, image_cls_token = self._forward_encode_image(inputs) + output[self.image_cls_prediction_key] = image_cls_token + + if self.audio_sample_key: audio_embd, audio_cls_token = self._forward_encode_audio(inputs) output[self.audio_cls_prediction_key] = audio_cls_token - modality_embd = audio_embd - - elif self.vision_sample_key and self.audio_sample_key is None: - vision_embd, vision_cls_token = self._forward_encode_vision(inputs) - output[self.vision_cls_prediction_key] = vision_cls_token - modality_embd = vision_embd - - else: - if self.individual_datasets: # audio / vision / text BUT separate datasets - audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_image(inputs) - output[self.audio_cls_prediction_key] = audio_cls_token - output[self.vision_cls_prediction_key] = vision_cls_token - modality_embd = {"audio": audio_embd, "image": vision_embd} - else: # audio + vision from one single dataset - audio_embd, audio_cls_token, vision_embd, vision_cls_token = self._forward_encode_audio_vision(inputs) - output[self.audio_cls_prediction_key] = audio_cls_token - output[self.vision_cls_prediction_key] = vision_cls_token - modality_embd = {"audio": audio_embd, "video": 
vision_embd} + if self.video_sample_key: + video_embd, video_cls_token = self._forward_encode_video(inputs) + output[self.video_cls_prediction_key] = video_cls_token + + # encode text text_embd, text_cls_token = self._forward_encode_text(inputs) - if self.vision_sample_key and self.audio_sample_key and self.individual_datasets: - image_text_cls_token, audio_text_cls_token = ( - text_cls_token[: len(vision_embd)], - text_cls_token[len(vision_embd) :], - ) - output.update( - { - self.image_text_cls_prediction_key: image_text_cls_token, - self.audio_text_cls_prediction_key: audio_text_cls_token, - } - ) - image_text_embd, audio_text_embd = text_embd[: len(vision_embd)], text_embd[len(vision_embd) :] - image_logits = self._forward_decode(image_text_embd, modality_embd["image"]) - audio_logits = self._forward_decode(audio_text_embd, modality_embd["audio"]) - logits = torch.cat([image_logits, audio_logits]) - else: + # decode modality + text + if self.individual_datasets: # multiple modalities (from different datasets) + start = 0 + modality_logits = [] + if image_embd is not None: + image_text_cls_token = text_cls_token[: len(image_embd)] + output.update({self.image_text_cls_prediction_key: image_text_cls_token}) + image_text_embd = text_embd[: len(image_embd)] + image_logits = self._forward_decode(image_text_embd, image_embd) + modality_logits.append(image_logits) + start = start + len(image_embd) + if audio_embd is not None: + audio_text_cls_token = text_cls_token[start : start + len(audio_embd)] + output.update({self.audio_text_cls_prediction_key: audio_text_cls_token}) + audio_text_embd = text_embd[start : start + len(audio_embd)] + audio_logits = self._forward_decode(audio_text_embd, audio_embd) + modality_logits.append(audio_logits) + start = start + len(audio_embd) + if video_embd is not None: + video_text_cls_token = text_cls_token[start:] + output.update({self.video_text_cls_prediction_key: video_text_cls_token}) + video_text_embd = text_embd[start:] + video_logits = self._forward_decode(video_text_embd, video_embd) + modality_logits.append(video_logits) + logits = torch.cat(modality_logits) + elif audio_embd is not None and video_embd is not None: # video dataset that contains audio + modality_embd = {"audio": audio_embd, "video": video_embd} logits = self._forward_decode(text_embd, modality_embd) + output.update({self.text_cls_prediction_key: text_cls_token}) + else: # single modality + output.update({self.text_cls_prediction_key: text_cls_token}) + if image_embd is not None: + logits = self._forward_decode(text_embd, image_embd) + elif audio_embd is not None: + logits = self._forward_decode(text_embd, audio_embd) + elif video_embd is not None: + logits = self._forward_decode(text_embd, video_embd) output.update( { self.prediction_key: logits, - self.text_cls_prediction_key: text_cls_token, self.logit_scale_prediction_key: self.logit_scale.exp(), } ) return output - def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_encode_image(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """ Encodes the input image using the vision encoder. @@ -326,11 +358,18 @@ def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch Returns: Tuple[torch.Tensor, torch.Tensor]: Tuple containing encoded vision embeddings and classification token. 
""" - vision_embd = self.vision_encoder(inputs)[self.modality_embd_prediction_key] - queries = repeat(self.vision_queries, "n d -> b n d", b=vision_embd.shape[0]) - vision_embd = self.vision_attn_pool(queries, context=vision_embd) - vision_embd, vision_cls_token = vision_embd[:, :-1, :], F.normalize(vision_embd[:, -1, :], dim=-1) - return vision_embd, vision_cls_token + image_embd = self.image_encoder(inputs)[self.image_embd_prediction_key] + queries = repeat(self.image_queries, "n d -> b n d", b=image_embd.shape[0]) + image_embd = self.image_attn_pool(queries, context=image_embd) + image_embd, image_cls_token = image_embd[:, :-1, :], F.normalize(image_embd[:, -1, :], dim=-1) + return image_embd, image_cls_token + + def _forward_encode_video(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + video_embd = self.video_encoder(inputs)[self.video_embd_prediction_key] + queries = repeat(self.video_queries, "n d -> b n d", b=video_embd.shape[0]) + video_embd = self.video_attn_pool(queries, context=video_embd) + video_embd, video_cls_token = video_embd[:, :-1, :], F.normalize(video_embd[:, -1, :], dim=-1) + return video_embd, video_cls_token def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: audio_embd = self.audio_encoder(inputs)[self.audio_embd_prediction_key] @@ -339,27 +378,6 @@ def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch. audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1, :], dim=-1) return audio_embd, audio_cls_token - ## MODIFIED - def _forward_encode_audio_image( - self, inputs: Dict[str, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - audio_inputs = {k: inputs[k] for k in inputs if k in ["audio", "audio_len"]} - vision_inputs = {k: inputs[k] for k in inputs if k in ["images"]} - audio_embd, audio_cls_token = self._forward_encode_audio(audio_inputs) - vision_embd, vision_cls_token = self._forward_encode_vision(vision_inputs) - - return audio_embd, audio_cls_token, vision_embd, vision_cls_token - - def _forward_encode_audio_vision( - self, inputs: Dict[str, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - audio_inputs = {k: inputs[k] for k in inputs if k in ["audio", "audio_len"]} - vision_inputs = {k: inputs[k] for k in inputs if k in ["video"]} - audio_embd, audio_cls_token = self._forward_encode_audio(audio_inputs) - vision_embd, vision_cls_token = self._forward_encode_vision(vision_inputs) - - return audio_embd, audio_cls_token, vision_embd, vision_cls_token - def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """ Encodes the input text using the text decoder. 
@@ -375,7 +393,6 @@ def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.T text_embd, text_cls_token = text_embd[:, :-1, :], F.normalize(text_embd[:, -1, :], dim=-1) return text_embd, text_cls_token - ## MODIFIED def _forward_decode( self, text_embd: torch.Tensor, modality_embd: list[torch.Tensor] | torch.Tensor ) -> torch.Tensor: From c8b9f65dca5633e9b71d590054af312496a6ed5f Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 19 Aug 2024 14:08:42 +0200 Subject: [PATCH 119/161] fix: webdataset builder: remove constraint that all dataset builders have the same type (required if we want to train with several datasets of different modalities) --- src/modalities/dataloader/dataset.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 94fd480cf..55bd831f9 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -802,16 +802,6 @@ def __init__( super().__init__() self.builders = builders - self.output_keys_by_modality = {} - for b in builders: - for k, v in b.modality_key_mapping.items(): - if k not in self.output_keys_by_modality: - self.output_keys_by_modality[k] = v[1] - else: - assert ( - self.output_keys_by_modality[k] == v[1] - ), "Output keys for the same modality of all builders should be the same." - # Build datasets [ b.prepare(shardshuffle=shardshuffle, resample=resample, repeat=repeat, shuffle_buffer=shuffle_buffer) From ea8910c4b810f5e8c5eae70bae8d07372c124559 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 19 Aug 2024 16:57:21 +0200 Subject: [PATCH 120/161] fix: add config parameter for video-audio-text dataset as a special case --- .../config_example_coca_webdataset.yaml | 2 + .../config_example_video_coca_webdataset.yaml | 2 + src/modalities/dataloader/dataset.py | 24 ++++--- src/modalities/models/coca/coca_model.py | 62 ++++++++++--------- .../models/coca/multi_modal_decoder.py | 11 ++-- src/modalities/models/coca/text_decoder.py | 3 +- 6 files changed, 61 insertions(+), 43 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index 5eb177e04..e76858cc7 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -223,6 +223,8 @@ model: text_embd_prediction_key: text_embeddings vision_cls_prediction_key: vision_cls text_cls_prediction_key: text_cls + individual_datasets: false + modality_keys: ${collate_fn.config.sample_keys} logit_scale_prediction_key: logit_scale vision_encoder_config: sample_key: images diff --git a/config_files/training/config_example_video_coca_webdataset.yaml b/config_files/training/config_example_video_coca_webdataset.yaml index 1466b9dc4..240f3c053 100644 --- a/config_files/training/config_example_video_coca_webdataset.yaml +++ b/config_files/training/config_example_video_coca_webdataset.yaml @@ -196,6 +196,8 @@ model: text_embd_prediction_key: text_embeddings vision_cls_prediction_key: vision_cls text_cls_prediction_key: text_cls + individual_datasets: false + modality_keys: ${collate_fn.config.sample_keys} logit_scale_prediction_key: logit_scale vision_encoder_config: sample_key: videos diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 55bd831f9..23cdf4c43 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -565,13 +565,6 @@ def 
__call__(self, video): return self.spatial_transform(video) -class MultimodalWebDatasetBuilderConfig(BaseModel): - urls: Union[List[str], str] - modality_key_mapping: Dict[ModalityEnum, Tuple[str, str]] - modality_transforms: Dict[ModalityEnum, PydanticTransformIFType] - num_samples: Annotated[int, Field(ge=1)] - - def decord_video(key, data): """Based on the torch_video decoder in webdataset https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py#L394 @@ -616,6 +609,14 @@ def torch_audio(key, data): return (audio, sample_rate) +class MultimodalWebDatasetBuilderConfig(BaseModel): + urls: Union[List[str], str] + modality_key_mapping: Dict[ModalityEnum, Tuple[str, str]] + modality_transforms: Dict[ModalityEnum, PydanticTransformIFType] + is_audio_video: Optional[bool] = False + num_samples: Annotated[int, Field(ge=1)] + + # @register_component("dataset", "web_dataset_builder", MultimodalWebDatasetBuilderConfig) class MultimodalWebDatasetBuilder: def __init__( @@ -623,6 +624,7 @@ def __init__( urls: Union[List[str], str], modality_key_mapping: Dict[str, Tuple[str, str]], modality_transforms: Dict[str, Transform], + is_audio_video: bool, num_samples: int, ): """A multimodal dataset instance for the WebDataset. @@ -633,8 +635,10 @@ def __init__( For example: {ModalityEnum.IMAGE: ("jpg", "image"), ModalityEnum.TEXT: ("text", "caption")}} modality_transforms: The transforms for each modality. num_samples: The number of samples for each modality combination. + is_audio_video: Whether the dataset is a video dataset which contains audio """ self.urls = urls + self.is_audio_video = is_audio_video self.modality_key_mapping = modality_key_mapping self.modality_transforms = modality_transforms # transforms should be specified for all modality_key mappings, @@ -802,6 +806,12 @@ def __init__( super().__init__() self.builders = builders + for builder in self.builders: + if builder.is_audio_video and len(self.builders) > 1: + raise NotImplementedError( + "It is not yet possible to include a video-audio dataset with other types of modalities" + ) + # Build datasets [ b.prepare(shardshuffle=shardshuffle, resample=resample, repeat=repeat, shuffle_buffer=shuffle_buffer) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 5bce3a057..dc7200dca 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -1,6 +1,6 @@ import math from functools import partial -from typing import Annotated, Dict, Optional, Tuple +from typing import Annotated, Dict, List, Optional, Tuple import numpy as np import torch @@ -88,7 +88,9 @@ class CoCaConfig(BaseModel): image_text_cls_prediction_key: Optional[str] = None video_cls_prediction_key: Optional[str] = None video_text_cls_prediction_key: Optional[str] = None - individual_datasets: Optional[bool] = None + modality_keys: List[str] + individual_datasets: Optional[bool] = False + is_audio_video: Optional[bool] = False audio_encoder_config: Optional[AudioTransformerConfig] = None image_encoder_config: Optional[VisionTransformerConfig] = None video_encoder_config: Optional[VisionTransformerConfig] = None @@ -125,7 +127,9 @@ def __init__( image_text_cls_prediction_key: Optional[str], video_cls_prediction_key: Optional[str], video_text_cls_prediction_key: Optional[str], + modality_keys: List[str], individual_datasets: Optional[bool], + is_audio_video: Optional[bool], audio_encoder_config: Optional[AudioTransformerConfig], image_encoder_config: 
Optional[VisionTransformerConfig], video_encoder_config: Optional[VisionTransformerConfig], @@ -173,15 +177,16 @@ def __init__( self.image_text_cls_prediction_key = image_text_cls_prediction_key self.video_cls_prediction_key = video_cls_prediction_key self.video_text_cls_prediction_key = video_text_cls_prediction_key + + self.modality_keys = modality_keys self.individual_datasets = individual_datasets + self.is_audio_video = is_audio_video self.n_pool_head = n_pool_head self.bias_attn_pool = bias_attn_pool self.epsilon_attn_pool = epsilon_attn_pool self.text_decoder_config = text_decoder_config - num_input_modalities = 0 - self.image_sample_key = None if image_encoder_config is not None: self.image_sample_key = image_encoder_config.sample_key @@ -190,7 +195,6 @@ def __init__( image_encoder_config, n_queries, ) - num_input_modalities += 1 self.video_sample_key = None if video_encoder_config is not None: @@ -200,7 +204,6 @@ def __init__( video_encoder_config, n_queries, ) - num_input_modalities += 1 self.audio_sample_key = None if audio_encoder_config is not None: @@ -210,7 +213,6 @@ def __init__( audio_encoder_config, n_queries, ) - num_input_modalities += 1 self.text_decoder = TextDecoder( sample_key=text_decoder_config.sample_key, @@ -236,7 +238,7 @@ def __init__( n_head=text_decoder_config.n_head, n_embd=text_decoder_config.n_embd, ffn_hidden=text_decoder_config.ffn_hidden, - is_two_input_modalities=num_input_modalities == 2, + is_audio_video=self.is_audio_video, dropout=text_decoder_config.dropout, bias=text_decoder_config.bias, attention_config=text_decoder_config.attention_config, @@ -306,26 +308,30 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if self.individual_datasets: # multiple modalities (from different datasets) start = 0 modality_logits = [] - if image_embd is not None: - image_text_cls_token = text_cls_token[: len(image_embd)] - output.update({self.image_text_cls_prediction_key: image_text_cls_token}) - image_text_embd = text_embd[: len(image_embd)] - image_logits = self._forward_decode(image_text_embd, image_embd) - modality_logits.append(image_logits) - start = start + len(image_embd) - if audio_embd is not None: - audio_text_cls_token = text_cls_token[start : start + len(audio_embd)] - output.update({self.audio_text_cls_prediction_key: audio_text_cls_token}) - audio_text_embd = text_embd[start : start + len(audio_embd)] - audio_logits = self._forward_decode(audio_text_embd, audio_embd) - modality_logits.append(audio_logits) - start = start + len(audio_embd) - if video_embd is not None: - video_text_cls_token = text_cls_token[start:] - output.update({self.video_text_cls_prediction_key: video_text_cls_token}) - video_text_embd = text_embd[start:] - video_logits = self._forward_decode(video_text_embd, video_embd) - modality_logits.append(video_logits) + # this ensures that we select the text input_ids corresponding to each modality_key in the order + # they are stacked by the collator + for modality_key in self.modality_keys: + if modality_key == "images" and image_embd is not None: + image_text_cls_token = text_cls_token[start : start + len(image_embd)] + image_text_embd = text_embd[start : start + len(image_embd)] + image_logits = self._forward_decode(image_text_embd, image_embd) + output.update({self.image_text_cls_prediction_key: image_text_cls_token}) + modality_logits.append(image_logits) + start = start + len(image_embd) + if modality_key == "audio" and audio_embd is not None: + audio_text_cls_token = text_cls_token[start : start + 
len(audio_embd)] + audio_text_embd = text_embd[start : start + len(audio_embd)] + audio_logits = self._forward_decode(audio_text_embd, audio_embd) + output.update({self.audio_text_cls_prediction_key: audio_text_cls_token}) + modality_logits.append(audio_logits) + start = start + len(audio_embd) + if modality_key == "video" and video_embd is not None: + video_text_cls_token = text_cls_token[start : start + len(video_embd)] + video_text_embd = text_embd[start : start + len(video_embd)] + video_logits = self._forward_decode(video_text_embd, video_embd) + output.update({self.video_text_cls_prediction_key: video_text_cls_token}) + modality_logits.append(video_logits) + start = start + len(video_embd) logits = torch.cat(modality_logits) elif audio_embd is not None and video_embd is not None: # video dataset that contains audio modality_embd = {"audio": audio_embd, "video": video_embd} diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index 5c60595cb..dbf5409b9 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -3,7 +3,6 @@ import torch from torch import nn -from transformers import PreTrainedTokenizer from modalities.models.gpt2.gpt2_model import ActivationType from modalities.models.model import NNModel, SwiGLU @@ -24,7 +23,7 @@ def __init__( dropout: float, ffn_hidden: int, with_context: bool, - is_two_input_modalities: bool, + is_audio_video: bool, attention_type: AttentionType, attention_config: AttentionConfig = None, add_extra_mlp: bool = False, @@ -48,7 +47,7 @@ def __init__( """ super().__init__() self.with_context = with_context - self.is_two_input_modalities = is_two_input_modalities + self.is_audio_video = is_audio_video self.add_extra_mlp = add_extra_mlp if activation == ActivationType.GELU: @@ -79,7 +78,7 @@ def __init__( self.ln_4 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) self.mlp_2 = mlp() - if self.is_two_input_modalities: + if self.is_audio_video: self.cross_attn2 = MultiHeadAttention( n_embd=n_embd, n_head=n_head, @@ -126,7 +125,7 @@ def __init__( n_head: int, n_embd: int, ffn_hidden: int, - is_two_input_modalities: bool, + is_audio_video: bool, dropout: float, bias: bool, activation: ActivationType, @@ -172,7 +171,7 @@ def __init__( dropout=dropout, ffn_hidden=ffn_hidden, with_context=True, - is_two_input_modalities=is_two_input_modalities, + is_audio_video=is_audio_video, attention_type=AttentionType.CAUSAL_SELF_ATTENTION, attention_config=attention_config, add_extra_mlp=False, diff --git a/src/modalities/models/coca/text_decoder.py b/src/modalities/models/coca/text_decoder.py index 26e5a569f..c21ef6871 100644 --- a/src/modalities/models/coca/text_decoder.py +++ b/src/modalities/models/coca/text_decoder.py @@ -2,7 +2,6 @@ import torch from torch import nn -from transformers import PreTrainedTokenizer from modalities.models.coca.multi_modal_decoder import TransformerBlock from modalities.models.gpt2.gpt2_model import ActivationType @@ -69,7 +68,7 @@ def __init__( dropout=dropout, ffn_hidden=ffn_hidden, with_context=False, - is_two_input_modalities=False, + is_audio_video=False, attention_type=AttentionType.CAUSAL_SELF_ATTENTION, attention_config=attention_config, ) From aa66bc6a6ac1ff164285ad08a8ade552d1ef18f9 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 19 Aug 2024 17:07:30 +0200 Subject: [PATCH 121/161] fix: only use audio from video-audio sample if specified --- src/modalities/dataloader/dataset.py | 2 +- 
1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 23cdf4c43..8c17ce1df 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -727,7 +727,7 @@ def _transform_video(self, sample): transform: VideoTransform = self.modality_transforms[ModalityEnum.VIDEO] sample[target_key] = transform(sample[source_key]) # if the video contains audio - if sample[source_key][1] is not None and ModalityEnum.AUDIO in self.modality_transforms: + if sample[source_key][1] is not None and ModalityEnum.AUDIO in self.modality_transforms and self.is_audio_video: transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] sample["audio"], sample["audio_len"] = transform((sample[source_key][1], sample[source_key][2])) if "audio" not in self.additional_extreacted_keys: From 9efecdb192b35bce52352716f54592fdad324895 Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Tue, 20 Aug 2024 05:53:34 +0000 Subject: [PATCH 122/161] chore: remove comment --- src/modalities/loss_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 36fc09d94..a59e8503a 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -232,7 +232,6 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: if world_size > 1 and self.local_loss: labels = labels + num_logits * rank - ## MODIFIED # Calculate loss losses = None for logits in logits_per_embeddings: From 6d53b3941bae98274f77ecbcd6b95a346eae8955 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 26 Aug 2024 14:42:01 +0200 Subject: [PATCH 123/161] fix: set correct type for mixing ratios (float) --- src/modalities/dataloader/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 8c17ce1df..19e4018cd 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -784,7 +784,7 @@ class MultimodalWebDataset(wds.DataPipeline, wds.compat.FluidInterface): def __init__( self, builders: List[MultimodalWebDatasetBuilder], - mixing_ratios: Optional[List[int]] = None, + mixing_ratios: Optional[List[float]] = None, shardshuffle: int = 100, repeat: bool = False, resample: bool = True, From 21d175117e10b335cc68b2d9484d09f323e433d9 Mon Sep 17 00:00:00 2001 From: Julian Spravil Date: Tue, 27 Aug 2024 16:48:47 +0200 Subject: [PATCH 124/161] fix: gather all in clip loss --- src/modalities/loss_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index a59e8503a..ebc976117 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -202,7 +202,7 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: world_size = dist.get_world_size() rank = dist.get_rank() - gathered_embeddings = [[torch.zeros_like(embedding) for embedding in embeddings] for _ in range(world_size)] + gathered_embeddings = [[torch.zeros_like(embedding) for _ in range(world_size)] for embedding in embeddings] for gathered_embedding, embedding in zip(gathered_embeddings, embeddings): dist.all_gather(gathered_embedding, embedding) From a014b25f5d8cd3f6922af1538c653b241a51e891 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Thu, 29 Aug 2024 13:40:01 +0200 Subject: [PATCH 125/161] fix: use maximum batch size 
of samples as batch length --- src/modalities/batch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/modalities/batch.py b/src/modalities/batch.py index 6a8b12e61..746cbdc2b 100644 --- a/src/modalities/batch.py +++ b/src/modalities/batch.py @@ -50,8 +50,8 @@ def device(self) -> torch.device: return self.samples[key].device def __len__(self) -> int: - key = list(self.samples.keys())[0] - return self.samples[key].shape[self.batch_dim] + lengths = [self.samples[key].shape[self.batch_dim] for key in self.samples.keys()] + return max(lengths) @dataclass @@ -89,8 +89,8 @@ def get_targets(self, key: str) -> torch.Tensor: return self.targets[key] def __len__(self) -> int: - key = list(self.predictions.keys())[0] - return self.predictions[key].shape[self.batch_dim] + lengths = [self.predictions[key].shape[self.batch_dim] for key in self.predictions.keys()] + return max(lengths) @dataclass From 1788b15bcc9057b6ddee4bca3a8e7f40fc81d262 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 2 Sep 2024 10:56:22 +0200 Subject: [PATCH 126/161] fix: webdataset: use a fixed round robin sampling strategy to get a fixed number of samples per modality in every batch --- .../config_example_coca_webdataset.yaml | 4 ++ src/modalities/dataloader/dataset.py | 47 ++++++++++++++++++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml index e76858cc7..52506c8c5 100644 --- a/config_files/training/config_example_coca_webdataset.yaml +++ b/config_files/training/config_example_coca_webdataset.yaml @@ -109,6 +109,8 @@ train_dataset: repeat: true resample: true shuffle_buffer: 10_000 + mixing_ratios: [0.35, 0.65] + batch_size: ${settings.training.local_train_micro_batch_size} val_dataset: component_key: dataset @@ -121,6 +123,8 @@ val_dataset: repeat: true resample: true shuffle_buffer: 10_000 + mixing_ratios: [0.35, 0.65] + batch_size: ${settings.training.local_train_micro_batch_size} train_dataloader: component_key: data_loader diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 19e4018cd..2b91d5a29 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -15,6 +15,7 @@ from pydantic import BaseModel, Field from timm.data import create_transform from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import IterableDataset from torch.utils.data.dataset import Dataset as TorchdataSet from torchvision.transforms import v2 as transforms from tqdm import tqdm @@ -609,6 +610,46 @@ def torch_audio(key, data): return (audio, sample_rate) +def fixed_ratio_round_robin(*sources, samples_per_batch): + sources = list(sources) + remaining_samples_in_batch = samples_per_batch.copy() + i = 0 + while len(sources) > 0: + try: + sample = next(sources[i]) + remaining_samples_in_batch[i] -= 1 + + # reset + if sum(remaining_samples_in_batch) == 0: + remaining_samples_in_batch = samples_per_batch.copy() + + # go to next source which has some remaining samples + i = (i + 1) % len(sources) + while remaining_samples_in_batch[i] == 0: + i = (i + 1) % len(sources) + yield sample + except StopIteration: + del sources[i] + + +class FixedRatioRoundRobinMix(IterableDataset): + """ + returns an iterator for a list of datasets; samples are yielded in a round robin manner + with a fixed ratio of samples per dataset. 
There is no random sampling, so the number of + samples per modality is guaranteed to be fixed per batch. + """ + + def __init__(self, datasets, mixing_ratios, batch_size): + self.datasets = datasets + self.samples_per_batch = [int(batch_size * ratio) for ratio in mixing_ratios] + self.samples_per_batch[0] += batch_size - sum(self.samples_per_batch) + + def __iter__(self): + """Return an iterator over the sources.""" + sources = [iter(d) for d in self.datasets] + return fixed_ratio_round_robin(*sources, samples_per_batch=self.samples_per_batch) + + class MultimodalWebDatasetBuilderConfig(BaseModel): urls: Union[List[str], str] modality_key_mapping: Dict[ModalityEnum, Tuple[str, str]] @@ -772,7 +813,8 @@ def dummy_nodesplitter(src, group=None): class MultimodalWebDatasetConfig(BaseModel): builders: List[PydanticMultimodalWebDatasetBuilderIFType] - mixing_ratios: Optional[List[int]] = None + batch_size: int + mixing_ratios: Optional[List[float]] = None shardshuffle: int = 100 repeat: bool = False resample: bool = True @@ -784,6 +826,7 @@ class MultimodalWebDataset(wds.DataPipeline, wds.compat.FluidInterface): def __init__( self, builders: List[MultimodalWebDatasetBuilder], + batch_size: int, mixing_ratios: Optional[List[float]] = None, shardshuffle: int = 100, repeat: bool = False, @@ -829,7 +872,7 @@ def __init__( datasets = [] for b in self.builders: datasets.append(b.web_dataset) - dataset = wds.RandomMix(datasets, self.mixing_ratios) # Apply mixing at sample level + dataset = FixedRatioRoundRobinMix(datasets, self.mixing_ratios, batch_size) # Apply mixing at sample level self.pipeline.append(dataset) else: self.pipeline.extend(self.builders[0].web_dataset.pipeline) From 7f31684f4731668c6832668c4694dd17a8332a4a Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 9 Sep 2024 12:57:56 +0200 Subject: [PATCH 127/161] refactor: make batch_size mandatory only if using multiple builders --- src/modalities/dataloader/dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 2b91d5a29..63a32ff72 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -813,7 +813,7 @@ def dummy_nodesplitter(src, group=None): class MultimodalWebDatasetConfig(BaseModel): builders: List[PydanticMultimodalWebDatasetBuilderIFType] - batch_size: int + batch_size: Optional[int] = None mixing_ratios: Optional[List[float]] = None shardshuffle: int = 100 repeat: bool = False @@ -826,7 +826,7 @@ class MultimodalWebDataset(wds.DataPipeline, wds.compat.FluidInterface): def __init__( self, builders: List[MultimodalWebDatasetBuilder], - batch_size: int, + batch_size: int = None, mixing_ratios: Optional[List[float]] = None, shardshuffle: int = 100, repeat: bool = False, @@ -869,6 +869,8 @@ def __init__( assert len(self.mixing_ratios) == len(self.builders) if len(self.builders) > 1: + if batch_size is None: + raise ValueError("batch_size cannot be None if multiple builders are used") datasets = [] for b in self.builders: datasets.append(b.web_dataset) From 2e84732da2cf70a6beab274582b4bf5876f82885 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Tue, 17 Sep 2024 16:45:28 +0200 Subject: [PATCH 128/161] chore: remove unused configs; add new config --- config_files/config_example_coca_audio.yaml | 294 ---------- .../config_example_coca_audio_vision.yaml | 348 ----------- config_files/config_example_coca_vision.yaml | 323 ----------- config_files/example_separate_datasets.yaml | 360 
------------ .../config_coca_img_aud_vid_dataset.yaml | 549 ++++++++++++++++++ .../training/config_example_coca.yaml | 342 ----------- .../config_example_coca_webdataset.yaml | 327 ----------- .../config_example_video_coca_webdataset.yaml | 298 ---------- 8 files changed, 549 insertions(+), 2292 deletions(-) delete mode 100644 config_files/config_example_coca_audio.yaml delete mode 100644 config_files/config_example_coca_audio_vision.yaml delete mode 100644 config_files/config_example_coca_vision.yaml delete mode 100644 config_files/example_separate_datasets.yaml create mode 100644 config_files/training/config_coca_img_aud_vid_dataset.yaml delete mode 100644 config_files/training/config_example_coca.yaml delete mode 100644 config_files/training/config_example_coca_webdataset.yaml delete mode 100644 config_files/training/config_example_video_coca_webdataset.yaml diff --git a/config_files/config_example_coca_audio.yaml b/config_files/config_example_coca_audio.yaml deleted file mode 100644 index d61fcdc14..000000000 --- a/config_files/config_example_coca_audio.yaml +++ /dev/null @@ -1,294 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - training: - global_training_log_interval_in_steps: 10000 # Needs to be a multiple of gradient_acc_steps - global_checkpointing_interval_in_steps: 10000 - global_evaluation_interval_in_steps: 10000 - global_num_training_samples: 1925398 - global_num_seen_steps: 0 - do_apply_activation_checkpointing: true - gradient_acc_steps: 1 - local_train_micro_batch_size: 256 - sequence_length: 512 - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - checkpointing_path: data/checkpoints - - -collate_fn: - component_key: collate_fn - variant_key: coca_collator - config: - sample_keys: - - feats - - feats_len - - ${settings.referencing_keys.sample_key} - - modality - target_keys: [] - text_sample_key: ${settings.referencing_keys.sample_key} - text_target_key: ${settings.referencing_keys.target_key} - -train_dataset: - component_key: dataset - variant_key: arrow_dataset_audio - config: - type_: train - audio_dataset_arrows: gertv-arrow-remove-na/train - bpe_to_ind: bpe_to_ind_gertv_2000.pkl - bpecodes: bpecodes_gertv_2000 - n_mels: 128 - block_size_audio_encoder: 500 - block_size_text_decoder: 512 - freq_domain_mask_length: 30 - time_domain_mask_length: 100 - -val_dataset: - component_key: dataset - variant_key: arrow_dataset_audio - config: - type_: val - audio_dataset_arrows: gertv-arrow-remove-na/test - bpe_to_ind: bpe_to_ind_gertv_2000.pkl - bpecodes: bpecodes_gertv_2000 - n_mels: 128 - block_size_audio_encoder: 500 - block_size_text_decoder: 512 - freq_domain_mask_length: 30 - time_domain_mask_length: 100 - - -train_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "train" - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: true - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: false - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - collate_fn: 
- instance_key: collate_fn - pass_type: BY_REFERENCE - -val_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "val" - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: true - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: false - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: - - instance_key: val_dataloader - pass_type: BY_REFERENCE - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: fsdp - config: - checkpoint_path: ${settings.paths.checkpointing_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - # mixed_precision_settings: FP_16 - # sharding_strategy: FULL_SHARD - # block_names: [TransformerBlock, AudioTransformer] - - -captioning_loss: - component_key: loss - variant_key: cross_entropy_loss - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${model.config.prediction_key} - tag: captioning_loss - -contrastive_loss: - component_key: loss - variant_key: clip_loss - config: - prediction_key1: ${model.config.modality_cls_prediction_key} - prediction_key2: ${model.config.text_cls_prediction_key} - logit_scale_key: ${model.config.logit_scale_prediction_key} - tag: contrastive_loss - -loss_fn: - - instance_key: captioning_loss - pass_type: BY_REFERENCE - - instance_key: contrastive_loss - pass_type: BY_REFERENCE - -wrapped_model: - component_key: model - variant_key: fsdp_wrapped - config: - model: - instance_key: model - pass_type: BY_REFERENCE - sync_module_states: true - mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD - block_names: [TransformerBlock, AudioTransformer] - -model: - component_key: model - variant_key: coca - config: - prediction_key: logits - audio_embd_prediction_key: modality_embeddings - text_embd_prediction_key: text_embeddings - audio_cls_prediction_key: modality_cls - text_cls_prediction_key: text_cls - logit_scale_prediction_key: logit_scale - audio_encoder_config: - sample_key: feats - prediction_key: modality_embeddings - block_size: 500 - n_mels: 128 - n_embd: 512 - n_heads: 4 - n_conformer_blocks: 3 - attention_config: - attention_engine_type: default_attention - pointwise_conv_kernel_size: 1 - depthwise_conv_kernel_size: 31 - text_decoder_config: - sample_key: ${settings.referencing_keys.sample_key} - prediction_key: ${model.config.prediction_key} - block_size: 512 - vocab_size: 2134 - n_layer_text: 3 - n_layer_multimodal_text: 3 - attention_config: - attention_engine_type: default_attention - n_head: 4 - ffn_hidden: 1024 - n_embd: 512 - dropout: 0.0 - bias: true - activation: fused_swiglu - epsilon: 1e-5 - n_pool_head: 8 - n_queries: 256 - bias_attn_pool: False - epsilon_attn_pool: 1e-5 - weight_init: - mean: 0.0 - std: 0.02 - -scheduler: - component_key: scheduler - 
variant_key: onecycle_lr - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - max_lr: 6e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: 962699 - pct_start: 0.01 - anneal_strategy: cos - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 0.00001 - betas: [0.9, 0.95] - eps: 1e-8 - weight_decay: 1e-1 - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: fsdp - config: - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - max_norm: 1.0 - -batch_progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - local_rank: ${settings.cuda_env.local_rank} - # world_size: ${settings.cuda_env.world_size} - global_num_seen_steps: ${settings.training.global_num_seen_steps} - train_dataloader: - instance_key: train_dataloader - pass_type: BY_REFERENCE - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: wandb - config: - local_rank: ${settings.cuda_env.local_rank} - project: modalities - mode: OFFLINE - experiment_id: ${settings.experiment_id} - directory: "." - config_file_path: ${settings.config_file_path} diff --git a/config_files/config_example_coca_audio_vision.yaml b/config_files/config_example_coca_audio_vision.yaml deleted file mode 100644 index 79207cee7..000000000 --- a/config_files/config_example_coca_audio_vision.yaml +++ /dev/null @@ -1,348 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - referencing_keys: - sample_key: input_ids - target_key: target_ids - training: - callback_interval_in_samples: 10000 - global_num_training_samples: 1925398 - global_num_seen_samples: 0 - do_apply_activation_checkpointing: true - gradient_acc_steps: 1 - local_train_micro_batch_size: 124 - sequence_length: 512 - gradient_clipping: - mode: p2_norm - threshold: 1.0 - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - checkpointing_path: data/checkpoints - -tokenizer: - component_key: tokenizer - variant_key: gpt2_tokenizer_fast - config: - tokenizer_file: data/tokenizer/tokenizer_gpt2.json - -collate_fn: - component_key: collate_fn - variant_key: coca_collator - config: - sample_keys: - - feats - - feats_len - - ${settings.referencing_keys.sample_key} - - modality - target_keys: [] - text_sample_key: ${settings.referencing_keys.sample_key} - text_target_key: ${settings.referencing_keys.target_key} - -train_dataset: - component_key: dataset - variant_key: arrow_dataset_av - config: - type_: train - batch_size: 124 - audio_dataset_arrows: gertv-arrow-remove-na/train - vision_dataset_arrows: coco_captions_arrow/train - bpe_to_ind: bpe_to_ind_test.pkl - bpecodes: bpecodes_test - n_mels: 128 - img_size: 224 - block_size_audio_encoder: 2000 - block_size_text_decoder: 512 - freq_domain_mask_length: 30 - time_domain_mask_length: 100 - -val_dataset: - component_key: dataset - variant_key: arrow_dataset_av - config: - type_: val - batch_size: 124 - audio_dataset_arrows: gertv-arrow-remove-na/test - vision_dataset_arrows: coco_captions_arrow/val - bpe_to_ind: bpe_to_ind_test.pkl - bpecodes: bpecodes_test - n_mels: 128 - img_size: 224 - block_size_audio_encoder: 2000 - block_size_text_decoder: 512 - freq_domain_mask_length: 30 - time_domain_mask_length: 100 - - -# train_dataloader: -# component_key: data_loader -# 
variant_key: repeating_data_loader -# config: -# dataloader: -# component_key: data_loader -# variant_key: default -# config: -# num_workers: 2 -# pin_memory: true -# shuffle: false -# dataloader_tag: "train" -# dataset: -# instance_key: train_dataset -# pass_type: BY_REFERENCE -# batch_sampler: -# component_key: batch_sampler -# variant_key: default -# config: -# batch_size: ${settings.training.local_train_micro_batch_size} -# drop_last: false -# sampler: -# component_key: sampler -# variant_key: distributed_sampler -# config: -# rank: ${settings.cuda_env.global_rank} -# num_replicas: ${settings.cuda_env.world_size} -# shuffle: true -# dataset: -# instance_key: train_dataset -# pass_type: BY_REFERENCE -# collate_fn: -# instance_key: collate_fn -# pass_type: BY_REFERENCE - -# num_epochs: 1 - - -train_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "train" - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: false - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: false - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -val_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "val" - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: false - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: false - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: - - instance_key: val_dataloader - pass_type: BY_REFERENCE - -checkpointing: - component_key: checkpointing - variant_key: default - config: - checkpointing_strategy: - component_key: checkpointing_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpointing_execution: - component_key: checkpointing_execution - variant_key: fsdp_to_disc_checkpointing - config: - checkpoint_path: ${settings.paths.checkpointing_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD - block_names: [TransformerBlock, AudioTransformer, VisionTransformerBlock] - - -captioning_loss: - component_key: loss - variant_key: clm_cross_entropy_loss - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: logits - -contrastive_loss: - component_key: loss - variant_key: nce_loss - config: - prediction_key1: ${model.config.modality_cls_prediction_key} - prediction_key2: ${model.config.text_cls_prediction_key} - tag: contrastive_loss - -loss_fn: - - instance_key: captioning_loss - pass_type: BY_REFERENCE - - instance_key: contrastive_loss - pass_type: BY_REFERENCE - -wrapped_model: - component_key: model - variant_key: 
fsdp_wrapped - config: - model: - instance_key: model - pass_type: BY_REFERENCE - sync_module_states: true - mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD - block_names: [TransformerBlock, AudioTransformer, VisionTransformerBlock] - -model: - component_key: model - variant_key: coca - config: - prediction_key: logits - modality_key: modality - modality_embd_prediction_key: modality_embeddings - text_embd_prediction_key: text_embeddings - modality_cls_prediction_key: modality_cls - text_cls_prediction_key: text_cls - modality_encoder_config: - vision_transformer_config: - sample_key: feats # need to fix these - prediction_key: modality_embeddings - img_size: 224 - n_classes: Null # Disable vision transformer head - n_layer: 3 - attention_config: - attention_engine_type: default_attention - n_head: 4 - n_embd: 512 - dropout: 0.0 - patch_size: 16 - patch_stride: 16 - n_img_channels: 3 - add_cls_token: False - bias: True - audio_transformer_config: - sample_key: feats - prediction_key: modality_embeddings - block_size: 2000 - n_mels: 128 - n_embd: 512 - n_heads: 4 - n_conformer_blocks: 3 - attention_config: - attention_engine_type: default_attention - pointwise_conv_kernel_size: 1 - depthwise_conv_kernel_size: 31 - text_decoder_config: - sample_key: ${settings.referencing_keys.sample_key} - prediction_key: ${model.config.prediction_key} - block_size: 512 - vocab_size: 634 - n_layer_text: 1 - n_layer_multimodal_text: 2 - attention_config: - attention_engine_type: default_attention - n_head: 4 - ffn_hidden: 1024 - n_embd: 512 - dropout: 0.0 - bias: true - activation: fused_swiglu - epsilon: 1e-5 - n_pool_head: 8 - n_vision_queries: 256 - n_audio_queries: 256 - bias_attn_pool: False - epsilon_attn_pool: 1e-5 - weight_init: - mean: 0.0 - std: 0.02 - -scheduler: - component_key: scheduler - variant_key: onecycle_lr - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - max_lr: 6e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: 962699 - pct_start: 0.01 - anneal_strategy: cos - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 0.00001 - betas: [0.9, 0.95] - eps: 1e-8 - weight_decay: 1e-1 - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - -batch_progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - local_rank: ${settings.cuda_env.local_rank} - world_size: ${settings.cuda_env.world_size} - global_num_seen_samples: ${settings.training.global_num_seen_samples} - train_dataloader: - instance_key: train_dataloader - pass_type: BY_REFERENCE - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: wandb - config: - local_rank: ${settings.cuda_env.local_rank} - project: modalities - mode: OFFLINE - experiment_id: ${settings.experiment_id} - directory: "." 
diff --git a/config_files/config_example_coca_vision.yaml b/config_files/config_example_coca_vision.yaml deleted file mode 100644 index bcc00dc13..000000000 --- a/config_files/config_example_coca_vision.yaml +++ /dev/null @@ -1,323 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - referencing_keys: - sample_key: input_ids - target_key: target_ids - training: - callback_interval_in_samples: 10000 - global_num_training_samples: 1925398 - global_num_seen_samples: 0 - do_apply_activation_checkpointing: true - gradient_acc_steps: 1 - local_train_micro_batch_size: 512 - sequence_length: 512 - gradient_clipping: - mode: p2_norm - threshold: 1.0 - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - checkpointing_path: data/checkpoints - -tokenizer: - component_key: tokenizer - variant_key: gpt2_tokenizer_fast - config: - tokenizer_file: data/tokenizer/tokenizer_gpt2.json - -collate_fn: - component_key: collate_fn - variant_key: coca_collator - config: - sample_keys: - - feats - - feats_len - - ${settings.referencing_keys.sample_key} - - modality - target_keys: [] - text_sample_key: ${settings.referencing_keys.sample_key} - text_target_key: ${settings.referencing_keys.target_key} - -train_dataset: - component_key: dataset - variant_key: arrow_dataset_vision - config: - vision_dataset_arrows: coco_captions_arrow/train - bpe_to_ind: bpe_to_ind_test.pkl - bpecodes: bpecodes_test - img_size: 224 - block_size_text_decoder: 512 - - - -val_dataset: - component_key: dataset - variant_key: arrow_dataset_vision - config: - vision_dataset_arrows: coco_captions_arrow/val - bpe_to_ind: bpe_to_ind_test.pkl - bpecodes: bpecodes_test - img_size: 224 - block_size_text_decoder: 512 - - -# train_dataloader: -# component_key: data_loader -# variant_key: repeating_data_loader -# config: -# dataloader: -# component_key: data_loader -# variant_key: default -# config: -# num_workers: 2 -# pin_memory: true -# shuffle: false -# dataloader_tag: "train" -# dataset: -# instance_key: train_dataset -# pass_type: BY_REFERENCE -# batch_sampler: -# component_key: batch_sampler -# variant_key: default -# config: -# batch_size: ${settings.training.local_train_micro_batch_size} -# drop_last: false -# sampler: -# component_key: sampler -# variant_key: distributed_sampler -# config: -# rank: ${settings.cuda_env.global_rank} -# num_replicas: ${settings.cuda_env.world_size} -# shuffle: true -# dataset: -# instance_key: train_dataset -# pass_type: BY_REFERENCE -# collate_fn: -# instance_key: collate_fn -# pass_type: BY_REFERENCE - -# num_epochs: 1 - - -train_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "train" - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: false - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: false - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -val_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "val" - dataset: - instance_key: val_dataset - 
pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.training.local_train_micro_batch_size} - drop_last: false - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: false - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: - - instance_key: val_dataloader - pass_type: BY_REFERENCE - -checkpointing: - component_key: checkpointing - variant_key: default - config: - checkpointing_strategy: - component_key: checkpointing_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpointing_execution: - component_key: checkpointing_execution - variant_key: fsdp_to_disc_checkpointing - config: - checkpoint_path: ${settings.paths.checkpointing_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD - block_names: [TransformerBlock, VisionTransformerBlock] - - -captioning_loss: - component_key: loss - variant_key: clm_cross_entropy_loss - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: logits - -contrastive_loss: - component_key: loss - variant_key: nce_loss - config: - prediction_key1: ${model.config.modality_cls_prediction_key} - prediction_key2: ${model.config.text_cls_prediction_key} - tag: contrastive_loss - -loss_fn: - - instance_key: captioning_loss - pass_type: BY_REFERENCE - - instance_key: contrastive_loss - pass_type: BY_REFERENCE - -wrapped_model: - component_key: model - variant_key: fsdp_wrapped - config: - model: - instance_key: model - pass_type: BY_REFERENCE - sync_module_states: true - mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD - block_names: [TransformerBlock, VisionTransformerBlock] - -model: - component_key: model - variant_key: coca - config: - prediction_key: logits - modality_key: modality - modality_embd_prediction_key: modality_embeddings - text_embd_prediction_key: text_embeddings - modality_cls_prediction_key: modality_cls - text_cls_prediction_key: text_cls - modality_encoder_config: - sample_key: feats - prediction_key: modality_embeddings - img_size: 224 # 288 in the original coca - n_classes: Null # Disable vision transformer head - n_layer: 3 - attention_config: - attention_engine_type: default_attention - n_head: 12 - n_embd: 768 - dropout: 0.0 - patch_size: 16 # 18 in the original coca - patch_stride: 16 # 18 in the original coca - n_img_channels: 3 - add_cls_token: False - bias: True - text_decoder_config: - sample_key: ${settings.referencing_keys.sample_key} - prediction_key: ${model.config.prediction_key} - block_size: 512 - vocab_size: 610 - n_layer_text: 3 - n_layer_multimodal_text: 1 - attention_config: - attention_engine_type: default_attention - n_head: 12 - ffn_hidden: 2048 - n_embd: 768 - dropout: 0.0 - bias: true - activation: fused_swiglu - epsilon: 1e-5 - n_pool_head: 8 - n_vision_queries: 256 - n_audio_queries: Null - bias_attn_pool: False - epsilon_attn_pool: 1e-5 - weight_init: - mean: 0.0 - std: 0.02 - -scheduler: - component_key: scheduler - variant_key: onecycle_lr - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - max_lr: 6e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: 962699 - pct_start: 0.01 - 
anneal_strategy: cos - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 0.00001 - betas: [0.9, 0.95] - eps: 1e-8 - weight_decay: 1e-1 - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - -batch_progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - local_rank: ${settings.cuda_env.local_rank} - world_size: ${settings.cuda_env.world_size} - global_num_seen_samples: ${settings.training.global_num_seen_samples} - train_dataloader: - instance_key: train_dataloader - pass_type: BY_REFERENCE - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: wandb - config: - local_rank: ${settings.cuda_env.local_rank} - project: modalities - mode: OFFLINE - experiment_id: ${settings.experiment_id} - directory: "." diff --git a/config_files/example_separate_datasets.yaml b/config_files/example_separate_datasets.yaml deleted file mode 100644 index 8601e0290..000000000 --- a/config_files/example_separate_datasets.yaml +++ /dev/null @@ -1,360 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - training: - global_training_log_interval_in_steps: 30 # Needs to be a multiple of gradient_acc_steps - global_checkpointing_interval_in_steps: 9_990 - global_evaluation_interval_in_steps: 4_980 - global_num_training_samples: 100_000 # 491 steps with 8 gpus and global bs of 1152 - global_num_seen_steps: 0 - do_apply_activation_checkpointing: true - gradient_acc_steps: 30 - local_train_micro_batch_size: 1 - sequence_length: 64 - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - checkpointing_path: data/checkpoints - -tokenizer: - component_key: tokenizer - variant_key: pretrained_hf_tokenizer - config: - pretrained_model_name_or_path: tokenizer/ - padding: true - max_length: ${settings.training.sequence_length} - -collate_fn: - component_key: collate_fn - variant_key: coca_collator - config: - sample_keys: - - video - - audio - - audio_len - - ${settings.referencing_keys.sample_key} - target_keys: [] - text_sample_key: ${settings.referencing_keys.sample_key} - text_target_key: ${settings.referencing_keys.target_key} - -train_audio_transform: - component_key: transform - variant_key: audio_transform - config: - is_training: True - block_size_audio_encoder: 300 - freq_domain_mask_length: 30 - time_domain_mask_length: 100 - -val_audio_transform: - component_key: transform - variant_key: audio_transform - config: - is_training: false - block_size_audio_encoder: 300 - -train_video_transform: - component_key: transform - variant_key: video_transform - config: - is_training: True - input_size: 288 - num_frames: ${model.config.vision_encoder_config.num_video_frames} - -val_video_transform: - component_key: transform - variant_key: video_transform - config: - is_training: True - input_size: 288 - num_frames: ${model.config.vision_encoder_config.num_video_frames} - -text_transform: - component_key: transform - variant_key: text_transform - config: - tokenizer: - instance_key: tokenizer - pass_type: BY_REFERENCE - -train_video_builder: - component_key: dataset - variant_key: web_dataset_builder - config: - urls: "videodata/validation/000000.tar" - modality_key_mapping: - TEXT: ["json", "input_ids"] - VIDEO: ["mp4", "video"] - 
modality_transforms: - AUDIO: - instance_key: train_audio_transform - pass_type: BY_REFERENCE - VIDEO: - instance_key: train_video_transform - pass_type: BY_REFERENCE - TEXT: - instance_key: text_transform - pass_type: BY_REFERENCE - num_samples: 100_000 - -val_video_builder: - component_key: dataset - variant_key: web_dataset_builder - config: - urls: "videodata/validation/000000.tar" - modality_key_mapping: - TEXT: ["json", "input_ids"] - VIDEO: ["mp4", "video"] - modality_transforms: - AUDIO: - instance_key: val_audio_transform - pass_type: BY_REFERENCE - VIDEO: - instance_key: val_video_transform - pass_type: BY_REFERENCE - TEXT: - instance_key: text_transform - pass_type: BY_REFERENCE - num_samples: 10_000 - -train_dataset: - component_key: dataset - variant_key: web_dataset - config: - builders: - - instance_key: train_video_builder - pass_type: BY_REFERENCE - shardshuffle: 100 - repeat: true - resample: true - shuffle_buffer: 10_000 - -val_dataset: - component_key: dataset - variant_key: web_dataset - config: - builders: - - instance_key: val_video_builder - pass_type: BY_REFERENCE - shardshuffle: 1000 - repeat: true - resample: true - shuffle_buffer: 10_000 - -train_dataloader: - component_key: data_loader - variant_key: web_loader - config: - num_workers: 2 - pin_memory: true - drop_last: true - dataloader_tag: "train" - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_size: ${settings.training.local_train_micro_batch_size} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -val_dataloader: - component_key: data_loader - variant_key: web_loader - config: - num_workers: 8 - pin_memory: true - drop_last: false - dataloader_tag: "val" - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - batch_size: ${settings.training.local_train_micro_batch_size} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: - - instance_key: val_dataloader - pass_type: BY_REFERENCE - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: fsdp - config: - checkpoint_path: ${settings.paths.checkpointing_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - -captioning_loss: - component_key: loss - variant_key: cross_entropy_loss - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${model.config.prediction_key} - tag: captioning_loss - weight: 2.0 - -contrastive_loss: - component_key: loss - variant_key: clip_loss - config: - prediction_keys: - - ${model.config.individual_datasets_cls_prediction_key} - - ${model.config.text_cls_prediction_key} - logit_scale_key: ${model.config.logit_scale_prediction_key} - tag: contrastive_loss - -loss_fn: - - instance_key: captioning_loss - pass_type: BY_REFERENCE - - instance_key: contrastive_loss - pass_type: BY_REFERENCE - -wrapped_model: - component_key: model - variant_key: fsdp_wrapped - config: - model: - instance_key: model - pass_type: BY_REFERENCE - sync_module_states: true - mixed_precision_settings: FP_16 - sharding_strategy: HYBRID_SHARD - block_names: [TransformerBlock, PerceiverTransformerBlock] - -model: - component_key: model - variant_key: coca - config: - prediction_key: logits - 
audio_embd_prediction_key: audio_embeddings - vision_embd_prediction_key: vision_embeddings - text_embd_prediction_key: text_embeddings - vision_cls_prediction_key: vision_cls - individual_datasets_cls_prediction_key: modalities_cls - audio_cls_prediction_key: audio_cls - text_cls_prediction_key: text_cls - logit_scale_prediction_key: logit_scale - audio_encoder_config: - sample_key: audio - prediction_key: audio_embeddings - block_size: 300 - n_mels: 128 - n_embd: 768 - n_heads: 12 - n_conformer_blocks: 3 - attention_config: - attention_engine_type: default_attention - pointwise_conv_kernel_size: 1 - depthwise_conv_kernel_size: 31 - vision_encoder_config: - sample_key: video - prediction_key: vision_embeddings - img_size: 288 # 288 in the original coca - n_classes: Null # Disable vision transformer head - n_layer: 3 - attention_config: - attention_engine_type: default_attention - n_head: 12 - n_embd: 768 - dropout: 0.0 - patch_size: 18 # 18 in the original coca - patch_stride: 18 # 18 in the original coca - n_img_channels: 3 - add_cls_token: False - bias: True - num_video_frames: 16 - n_latents: 64 - text_decoder_config: - sample_key: ${settings.referencing_keys.sample_key} - prediction_key: ${model.config.prediction_key} - block_size: 77 - vocab_size: 49_408 # 64k in the original coca - n_layer_text: 2 - n_layer_multimodal_text: 1 - attention_config: - attention_engine_type: default_attention - n_head: 12 - ffn_hidden: 3072 - n_embd: 768 - dropout: 0.0 - bias: true - activation: fused_swiglu - epsilon: 1e-5 - n_pool_head: 12 - n_queries: 256 - bias_attn_pool: False - epsilon_attn_pool: 1e-5 - weight_init: - mean: 0.0 - std: 0.02 - -scheduler: - component_key: scheduler - variant_key: cosine_annealing_with_warmup_lr # COCA uses linear decay - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - num_warmup_steps: 2_000 - num_training_steps: 500_000 - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 8e-4 - betas: [0.9, 0.999] - eps: 1e-8 - weight_decay: 0.01 - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: fsdp - config: - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - max_norm: 1.0 - -batch_progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - local_rank: ${settings.cuda_env.global_rank} - global_num_seen_steps: ${settings.training.global_num_seen_steps} - train_dataloader: - instance_key: train_dataloader - pass_type: BY_REFERENCE - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: wandb - config: - local_rank: ${settings.cuda_env.global_rank} - project: modalities - mode: OFFLINE - experiment_id: ${settings.experiment_id} - directory: "." 
- config_file_path: ${settings.config_file_path} diff --git a/config_files/training/config_coca_img_aud_vid_dataset.yaml b/config_files/training/config_coca_img_aud_vid_dataset.yaml new file mode 100644 index 000000000..ffa74df21 --- /dev/null +++ b/config_files/training/config_coca_img_aud_vid_dataset.yaml @@ -0,0 +1,549 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 2 + checkpointing_interval_in_steps: 2 + evaluation_interval_in_steps: 2 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 10 + sequence_length: 256 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_num_steps + config: + num_steps: ${settings.training_target.num_target_steps} + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_samples + config: + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_samples: ${settings.coca_example_settings.train_num_samples} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + local_num_seen_batches: 0 + last_step: -1 + coca_example_settings: + train_num_samples: 64 + val_num_samples: 32 + +tokenizer: + component_key: tokenizer + variant_key: pretrained_hf_tokenizer + config: + pretrained_model_name_or_path: tokenizer/ + padding: true + max_length: ${settings.step_profile.sequence_length} + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - images + - audio + - audio_len + - video + - ${settings.referencing_keys.sample_key} + target_keys: [] + text_sample_key: ${settings.referencing_keys.sample_key} + text_target_key: ${settings.referencing_keys.target_key} + +train_audio_transform: + component_key: transform + variant_key: audio_transform + config: + is_training: True + block_size_audio_encoder: ${model_raw.config.audio_encoder_config.block_size} + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + +train_image_transform: + component_key: transform + variant_key: image_transform + config: + is_training: True + input_size: ${model_raw.config.image_encoder_config.img_size} + +train_video_transform: + component_key: transform + variant_key: video_transform + config: + is_training: True + input_size: ${model_raw.config.video_encoder_config.img_size} + num_frames: ${model_raw.config.video_encoder_config.num_video_frames} + +val_audio_transform: + component_key: transform + variant_key: audio_transform + config: + is_training: False + 
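+    # The val_* transforms mirror the train_* transforms above but set is_training to
+    # False; the masking lengths below are assumed to take effect only when is_training is True.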
block_size_audio_encoder: ${model_raw.config.audio_encoder_config.block_size} + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + +val_image_transform: + component_key: transform + variant_key: image_transform + config: + is_training: False + input_size: ${model_raw.config.image_encoder_config.img_size} + +val_video_transform: + component_key: transform + variant_key: video_transform + config: + is_training: False + input_size: ${model_raw.config.video_encoder_config.img_size} + num_frames: ${model_raw.config.video_encoder_config.num_video_frames} + +text_transform: + component_key: transform + variant_key: text_transform + config: + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + +train_video_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "youcook2/training/000000.tar" + is_audio_video: ${model_raw.config.is_audio_video} + modality_key_mapping: + TEXT: ["json", "input_ids"] + VIDEO: ["mp4", "video"] + modality_transforms: + VIDEO: + instance_key: train_video_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 100_000 + +val_video_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "youcook2/training/000000.tar" + is_audio_video: ${model_raw.config.is_audio_video} + modality_key_mapping: + TEXT: ["json", "input_ids"] + VIDEO: ["mp4", "video"] + modality_transforms: + VIDEO: + instance_key: val_video_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + +train_audio_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "commonvoice/commonvoice_17_dev_wav_000001.tar" + modality_key_mapping: + TEXT: ["transcript.txt", "input_ids"] # source and target keys + AUDIO: ["wav", "audio"] + modality_transforms: + AUDIO: + instance_key: train_audio_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 30000 + +val_audio_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "commonvoice/commonvoice_17_dev_wav_000001.tar" + modality_key_mapping: + TEXT: ["transcript.txt", "input_ids"] # source and target keys + AUDIO: ["wav", "audio"] + modality_transforms: + AUDIO: + instance_key: val_audio_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + +train_coco_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "coco_captions/data/train/000000.tar" + modality_key_mapping: + TEXT: ["json_text0", "input_ids"] + IMAGE: ["jpg", "images"] + modality_transforms: + IMAGE: + instance_key: train_image_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + +val_coco_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "coco_captions/data/train/000000.tar" + modality_key_mapping: + TEXT: ["json_text0", "input_ids"] + IMAGE: ["jpg", "images"] + modality_transforms: + IMAGE: + instance_key: val_image_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + + +train_dataset: + component_key: dataset + variant_key: web_dataset + config: + builders: + - instance_key: train_audio_dataset_builder + pass_type: BY_REFERENCE + - instance_key: train_coco_dataset_builder + 
pass_type: BY_REFERENCE + - instance_key: train_video_builder + pass_type: BY_REFERENCE + mixing_ratios: [0.5, 0.4, 0.1] + batch_size: ${settings.step_profile.local_train_micro_batch_size} + shardshuffle: 100 + repeat: false + resample: false + shuffle_buffer: 10_000 + +val_dataset: + component_key: dataset + variant_key: web_dataset + config: + builders: + - instance_key: val_audio_dataset_builder + pass_type: BY_REFERENCE + - instance_key: val_coco_dataset_builder + pass_type: BY_REFERENCE + - instance_key: val_video_builder + pass_type: BY_REFERENCE + mixing_ratios: [0.5, 0.4, 0.1] + batch_size: ${settings.step_profile.local_train_micro_batch_size} + shardshuffle: 1000 + repeat: true + resample: true + shuffle_buffer: 10_000 + +train_dataloader: + component_key: data_loader + variant_key: web_loader + config: + num_workers: 8 + pin_memory: true + drop_last: true + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_size: ${settings.step_profile.local_train_micro_batch_size} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: web_loader + config: + num_workers: 8 + pin_memory: true + drop_last: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_size: ${settings.step_profile.local_train_micro_batch_size} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: fsdp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +captioning_loss: + component_key: loss + variant_key: cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${model_raw.config.prediction_key} + tag: captioning_loss + weight: 2.0 + +contrastive_loss_audio: + component_key: loss + variant_key: clip_loss + config: + prediction_keys: + - ${model_raw.config.audio_cls_prediction_key} + - ${model_raw.config.audio_text_cls_prediction_key} + logit_scale_key: ${model_raw.config.logit_scale_prediction_key} + tag: contrastive_loss_audio + weight: 1.0 + +contrastive_loss_image: + component_key: loss + variant_key: clip_loss + config: + prediction_keys: + - ${model_raw.config.image_cls_prediction_key} + - ${model_raw.config.image_text_cls_prediction_key} + logit_scale_key: ${model_raw.config.logit_scale_prediction_key} + tag: contrastive_loss_image + weight: 1.0 + +contrastive_loss_video: + component_key: loss + variant_key: clip_loss + config: + prediction_keys: + - ${model_raw.config.video_cls_prediction_key} + - ${model_raw.config.video_text_cls_prediction_key} + logit_scale_key: ${model_raw.config.logit_scale_prediction_key} + tag: contrastive_loss_image + weight: 1.0 + +loss_fn: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss_audio + pass_type: BY_REFERENCE + - instance_key: contrastive_loss_image + pass_type: BY_REFERENCE + - instance_key: contrastive_loss_video + pass_type: BY_REFERENCE + +wrapped_model: + 
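+  # FSDP wrapping sketch: block_names below is assumed to list the module class names
+  # (transformer, vision and conformer blocks) that become individual FSDP units, so it
+  # has to stay in sync with the encoder/decoder classes configured in model_raw.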
component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: FP_16 + sharding_strategy: HYBRID_SHARD + block_names: [TransformerBlock, VisionTransformerBlock, ConformerBlock] + +model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: coca + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.text_decoder_config.n_layer_text} + +model_raw: + component_key: model + variant_key: coca + config: + prediction_key: ${settings.referencing_keys.prediction_key} + audio_embd_prediction_key: audio_embeddings + image_embd_prediction_key: image_embeddings + video_embd_prediction_key: video_embeddings + text_embd_prediction_key: text_embeddings + image_cls_prediction_key: image_cls + image_text_cls_prediction_key: image_text_cls + audio_cls_prediction_key: audio_cls + audio_text_cls_prediction_key: audio_text_cls + video_cls_prediction_key: video_cls + video_text_cls_prediction_key: video_text_cls + text_cls_prediction_key: text_cls + modality_keys: ${collate_fn.config.sample_keys} + is_audio_video: false + individual_datasets: true + logit_scale_prediction_key: logit_scale + audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 2_000 + n_mels: 128 + n_embd: 768 + n_heads: 8 + n_conformer_blocks: 2 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 + image_encoder_config: + sample_key: images + prediction_key: image_embeddings + img_size: 256 # 288 in the original coca + n_classes: Null # Disable vision transformer head + n_layer: 2 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 0.0 + patch_size: 16 # 18 in the original coca + patch_stride: 16 # 18 in the original coca + n_img_channels: 3 + add_cls_token: False + bias: True + video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 256 # 288 in the original coca + n_classes: Null # Disable vision transformer head + n_layer: 2 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 0.0 + patch_size: 18 # 18 in the original coca + patch_stride: 18 # 18 in the original coca + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 + text_decoder_config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${model_raw.config.prediction_key} + block_size: 512 + vocab_size: 50304 # 64k in the original coca + n_layer_text: 2 + n_layer_multimodal_text: 2 + attention_config: + attention_engine_type: default_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: swiglu + epsilon: 1e-5 + n_pool_head: 12 + n_queries: 256 + bias_attn_pool: False + epsilon_attn_pool: 1e-5 + +scheduler: + component_key: scheduler + variant_key: cosine_annealing_with_warmup_lr # COCA uses linear decay + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + num_warmup_steps: 2_000 + num_training_steps: 500_000 + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 8e-4 + betas: [0.9, 0.999] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [] + wrapped_model: + 
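+      # The optimizer is assumed to be built from the FSDP-wrapped model (referenced just
+      # below) rather than model_raw, so that it optimizes the flattened, sharded parameters.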
instance_key: wrapped_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp + config: + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_coca + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml deleted file mode 100644 index 17c0adf36..000000000 --- a/config_files/training/config_example_coca.yaml +++ /dev/null @@ -1,342 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - prediction_key: logits - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - checkpoint_saving_path: data/checkpoints - train_dataset_path: ./data/lorem_ipsum.pbin - intervals: - training_log_interval_in_steps: 2 - checkpointing_interval_in_steps: 2 - evaluation_interval_in_steps: 2 - consistency_enforcement: - enforce_tokens_per_step_consistency: true - enforce_last_step_logged: false - enforce_last_step_evaluated: false - enforce_last_step_checkpointed: false - step_profile: - gradient_accumulation_steps: 1 - local_train_micro_batch_size: 1 - sequence_length: 256 - training_target: - num_target_tokens: - component_key: number_conversion - variant_key: num_tokens_from_num_steps - config: - num_steps: ${settings.training_target.num_target_steps} - num_ranks: ${settings.cuda_env.world_size} - local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} - sequence_length: ${settings.step_profile.sequence_length} - gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} - num_target_steps: # for the batch progress subscriber - component_key: number_conversion - variant_key: num_steps_from_num_samples - config: - num_ranks: ${settings.cuda_env.world_size} - local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} - global_num_samples: ${settings.coca_example_settings.train_num_samples} - gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} - training_progress: - global_num_seen_tokens: 0 - num_seen_steps: 0 - local_num_seen_batches: 0 - last_step: -1 - coca_example_settings: - train_num_samples: 64 - val_num_samples: 32 - -collate_fn: - component_key: collate_fn - variant_key: coca_collator - config: - sample_keys: - - images - - ${settings.referencing_keys.sample_key} - - modality - target_keys: [] - text_sample_key: ${settings.referencing_keys.sample_key} - text_target_key: ${settings.referencing_keys.target_key} - -train_dataset: - component_key: dataset - variant_key: dummy_dataset - config: - num_samples: 
${settings.coca_example_settings.train_num_samples} - sample_definition: - - sample_key: images - sample_shape: [3, 256, 256] - sample_type: float - - sample_key: input_ids - sample_shape: [256] - sample_type: int - - sample_key: modality - sample_shape: [1] - sample_type: const - -val_dataset: - component_key: dataset - variant_key: dummy_dataset - config: - num_samples: ${settings.coca_example_settings.val_num_samples} - sample_definition: - - sample_key: images - sample_shape: [3, 256, 256] - sample_type: float - - sample_key: input_ids - sample_shape: [256] - sample_type: int - - sample_key: modality - sample_shape: [1] - sample_type: const - -train_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - dataloader_tag: train - skip_num_batches: ${settings.training_progress.local_num_seen_batches} - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.step_profile.local_train_micro_batch_size} - drop_last: true - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: true - drop_last: true - seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -val_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - dataloader_tag: val - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.step_profile.local_train_micro_batch_size} - drop_last: true - - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: false - drop_last: true - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: - - instance_key: val_dataloader - pass_type: BY_REFERENCE - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: fsdp - config: - checkpoint_path: ${settings.paths.checkpoint_saving_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - -captioning_loss: - component_key: loss - variant_key: cross_entropy_loss - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${settings.referencing_keys.prediction_key} - tag: captioning_loss - -contrastive_loss: - component_key: loss - variant_key: clip_loss - config: - prediction_key1: ${model.config.vision_cls_prediction_key} - prediction_key2: ${model.config.text_cls_prediction_key} - logit_scale_key: ${model.config.logit_scale_prediction_key} - tag: contrastive_loss - -loss_fn: - - instance_key: captioning_loss - pass_type: BY_REFERENCE - - instance_key: contrastive_loss - pass_type: BY_REFERENCE - -wrapped_model: - component_key: model - variant_key: fsdp_wrapped - config: - model: - instance_key: model - pass_type: BY_REFERENCE - 
sync_module_states: true - mixed_precision_settings: FP_16 - sharding_strategy: HYBRID_SHARD - block_names: [TransformerBlock, VisionTransformerBlock] - -model: - component_key: model - variant_key: model_initialized - config: - model: - instance_key: model_raw - pass_type: BY_REFERENCE - model_initializer: - component_key: model_initialization - variant_key: composed - config: - model_type: coca - weight_init_type: plain - mean: 0.0 - std: 0.02 - -model_raw: - component_key: model - variant_key: coca - config: - prediction_key: logits - modality_key: modality - modality_embd_prediction_key: modality_embeddings - text_embd_prediction_key: text_embeddings - modality_cls_prediction_key: modality_cls - text_cls_prediction_key: text_cls - modality_encoder_config: - sample_key: images - prediction_key: modality_embeddings - img_size: 256 # 288 in the original coca - n_classes: Null # Disable vision transformer head - n_layer: 12 - attention_config: - attention_engine_type: default_attention - n_head: 12 - n_embd: 768 - dropout: 0.0 - patch_size: 16 # 18 in the original coca - patch_stride: 16 # 18 in the original coca - n_img_channels: 3 - add_cls_token: False - bias: True - text_decoder_config: - sample_key: ${settings.referencing_keys.sample_key} - prediction_key: ${model.config.prediction_key} - block_size: 1024 - vocab_size: 50304 # 64k in the original coca - n_layer_text: 12 - n_layer_multimodal_text: 12 - attention_config: - attention_engine_type: default_attention - n_head: 12 - ffn_hidden: 2048 - n_embd: 768 - dropout: 0.0 - bias: true - activation: swiglu - epsilon: 1e-5 - n_pool_head: 12 - n_vision_queries: 256 - n_audio_queries: Null - bias_attn_pool: False - epsilon_attn_pool: 1e-5 - -scheduler: - component_key: scheduler - variant_key: onecycle_lr # COCA uses linear decay - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - max_lr: 8e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: ${settings.training_target.num_target_steps} - pct_start: 0.01 - anneal_strategy: cos - last_epoch: ${settings.training_progress.last_step} - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 8e-4 - betas: [0.9, 0.999] - eps: 1e-8 - weight_decay: 1e-1 - weight_decay_groups_excluded: [] - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: fsdp_logging_only - config: - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - -progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - global_rank: ${settings.cuda_env.global_rank} - num_seen_steps: ${settings.training_progress.num_seen_steps} - num_target_steps: ${settings.training_target.num_target_steps} - train_dataloader_tag: ${train_dataloader.config.dataloader_tag} - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: wandb - config: - global_rank: ${settings.cuda_env.global_rank} - project: modalities - mode: ONLINE - experiment_id: ${settings.experiment_id} - directory: wandb_storage - config_file_path: ${settings.config_file_path} diff --git a/config_files/training/config_example_coca_webdataset.yaml b/config_files/training/config_example_coca_webdataset.yaml deleted file mode 100644 index 52506c8c5..000000000 --- a/config_files/training/config_example_coca_webdataset.yaml +++ /dev/null @@ -1,327 +0,0 @@ -settings: - experiment_id: 
${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - training: - global_training_log_interval_in_steps: 30 # Needs to be a multiple of gradient_acc_steps - global_checkpointing_interval_in_steps: 9_990 - global_evaluation_interval_in_steps: 4_980 - global_num_training_samples: 566748 # 491 steps with 8 gpus and global bs of 1152 - global_num_seen_steps: 0 - do_apply_activation_checkpointing: true - gradient_acc_steps: 30 - local_train_micro_batch_size: 144 - sequence_length: 64 - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - checkpointing_path: data/checkpoints - -tokenizer: - component_key: tokenizer - variant_key: pretrained_hf_tokenizer - config: - pretrained_model_name_or_path: openai/clip-vit-base-patch32 - padding: true - max_length: ${settings.training.sequence_length} - -collate_fn: - component_key: collate_fn - variant_key: coca_collator - config: - sample_keys: - - images - - ${settings.referencing_keys.sample_key} - target_keys: [] - text_sample_key: ${settings.referencing_keys.sample_key} - text_target_key: ${settings.referencing_keys.target_key} - -train_image_transform: - component_key: transform - variant_key: image_transform - config: - is_training: True - input_size: 256 - -train_image_transform: - component_key: transform - variant_key: image_transform - config: - is_training: True - input_size: 288 - -text_transform: - component_key: transform - variant_key: text_transform - config: - tokenizer: - instance_key: tokenizer - pass_type: BY_REFERENCE - -train_coco_dataset_builder: - component_key: dataset - variant_key: web_dataset_builder - config: - urls: "/nm-raid/video/multimodal_data/coco_captions/data/train/{000000..000011}.tar" - modality_key_mapping: - TEXT: ["json_text0", "input_ids"] - IMAGE: ["jpg", "images"] - modality_transforms: - IMAGE: - instance_key: train_image_transform - pass_type: BY_REFERENCE - TEXT: - instance_key: text_transform - pass_type: BY_REFERENCE - num_samples: 30000 - -train_dataset_builder: - component_key: dataset - variant_key: web_dataset_builder - config: - urls: /p/scratch/jureap63/multimodal_data/audio/commonvoice_17_test_wav_000000.tar - modality_key_mapping: - TEXT: ["transcript.txt", "input_ids"] # source and target keys - AUDIO: ["wav", "feats"] - modality_transforms: - AUDIO: - instance_key: train_audio_transform - pass_type: BY_REFERENCE - TEXT: - instance_key: text_transform - pass_type: BY_REFERENCE - num_samples: 10000 - -train_dataset: - component_key: dataset - variant_key: web_dataset - config: - builders: - - instance_key: train_dataset_builder - pass_type: BY_REFERENCE - - instance_key: train_coco_dataset_builder - pass_type: BY_REFERENCE - shardshuffle: 100 - repeat: true - resample: true - shuffle_buffer: 10_000 - mixing_ratios: [0.35, 0.65] - batch_size: ${settings.training.local_train_micro_batch_size} - -val_dataset: - component_key: dataset - variant_key: web_dataset - config: - builders: - - instance_key: train_coco_dataset_builder - pass_type: BY_REFERENCE - shardshuffle: 1000 - repeat: true - resample: true - shuffle_buffer: 10_000 - mixing_ratios: [0.35, 0.65] - batch_size: ${settings.training.local_train_micro_batch_size} - -train_dataloader: - component_key: data_loader - variant_key: web_loader - config: - num_workers: 8 - pin_memory: true - drop_last: true - dataloader_tag: "train" - dataset: - instance_key: train_dataset - 
pass_type: BY_REFERENCE - batch_size: ${settings.training.local_train_micro_batch_size} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -val_dataloader: - component_key: data_loader - variant_key: web_loader - config: - num_workers: 8 - pin_memory: true - drop_last: false - dataloader_tag: "val" - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - batch_size: ${settings.training.local_train_micro_batch_size} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: - - instance_key: val_dataloader - pass_type: BY_REFERENCE - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: fsdp - config: - checkpoint_path: ${settings.paths.checkpointing_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - -captioning_loss: - component_key: loss - variant_key: cross_entropy_loss - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${model.config.prediction_key} - tag: captioning_loss - weight: 2.0 - -contrastive_loss: - component_key: loss - variant_key: clip_loss - config: - prediction_keys: - - ${model.config.audio_cls_prediction_key} - - ${model.config.vision_cls_prediction_key} - - ${model.config.text_cls_prediction_key} - logit_scale_key: ${model.config.logit_scale_prediction_key} - tag: contrastive_loss - weight: 1.0 - -loss_fn: - - instance_key: captioning_loss - pass_type: BY_REFERENCE - - instance_key: contrastive_loss - pass_type: BY_REFERENCE - -wrapped_model: - component_key: model - variant_key: fsdp_wrapped - config: - model: - instance_key: model - pass_type: BY_REFERENCE - sync_module_states: true - mixed_precision_settings: FP_16 - sharding_strategy: HYBRID_SHARD - block_names: [TransformerBlock, VisionTransformerBlock] - -model: - component_key: model - variant_key: coca - config: - prediction_key: logits - vision_embd_prediction_key: vision_embeddings - text_embd_prediction_key: text_embeddings - vision_cls_prediction_key: vision_cls - text_cls_prediction_key: text_cls - individual_datasets: false - modality_keys: ${collate_fn.config.sample_keys} - logit_scale_prediction_key: logit_scale - vision_encoder_config: - sample_key: images - prediction_key: vision_embeddings - img_size: 256 # 288 in the original coca - n_classes: Null # Disable vision transformer head - n_layer: 12 - attention_config: - attention_engine_type: default_attention - n_head: 12 - n_embd: 768 - dropout: 0.0 - patch_size: 16 # 18 in the original coca - patch_stride: 16 # 18 in the original coca - n_img_channels: 3 - add_cls_token: False - bias: True - text_decoder_config: - sample_key: ${settings.referencing_keys.sample_key} - prediction_key: ${model.config.prediction_key} - block_size: 1024 - vocab_size: 50304 # 64k in the original coca - n_layer_text: 12 - n_layer_multimodal_text: 12 - attention_config: - attention_engine_type: default_attention - n_head: 12 - ffn_hidden: 2048 - n_embd: 768 - dropout: 0.0 - bias: true - activation: fused_swiglu - epsilon: 1e-5 - n_pool_head: 12 - n_queries: 256 - bias_attn_pool: False - epsilon_attn_pool: 1e-5 - weight_init: - mean: 0.0 - std: 0.02 - -scheduler: - component_key: scheduler - variant_key: cosine_annealing_with_warmup_lr # COCA uses 
linear decay - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - num_warmup_steps: 2_000 - num_training_steps: 500_000 - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 8e-4 - betas: [0.9, 0.999] - eps: 1e-8 - weight_decay: 0.01 - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: fsdp - config: - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - max_norm: 1.0 - -batch_progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - local_rank: ${settings.cuda_env.global_rank} - global_num_seen_steps: ${settings.training.global_num_seen_steps} - train_dataloader: - instance_key: train_dataloader - pass_type: BY_REFERENCE - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: wandb - config: - local_rank: ${settings.cuda_env.global_rank} - project: modalities - mode: OFFLINE - experiment_id: ${settings.experiment_id} - directory: "." - config_file_path: ${settings.config_file_path} diff --git a/config_files/training/config_example_video_coca_webdataset.yaml b/config_files/training/config_example_video_coca_webdataset.yaml deleted file mode 100644 index 240f3c053..000000000 --- a/config_files/training/config_example_video_coca_webdataset.yaml +++ /dev/null @@ -1,298 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - training: - global_training_log_interval_in_steps: 30 # Needs to be a multiple of gradient_acc_steps - global_checkpointing_interval_in_steps: 9_990 - global_evaluation_interval_in_steps: 4_980 - global_num_training_samples: 566748 # 491 steps with 8 gpus and global bs of 1152 - global_num_seen_steps: 0 - do_apply_activation_checkpointing: true - gradient_acc_steps: 30 - local_train_micro_batch_size: 8 - sequence_length: 64 - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - checkpointing_path: data/checkpoints - -tokenizer: - component_key: tokenizer - variant_key: pretrained_hf_tokenizer - config: - pretrained_model_name_or_path: openai/clip-vit-base-patch32 - padding: true - max_length: ${settings.training.sequence_length} - -collate_fn: - component_key: collate_fn - variant_key: coca_collator - config: - sample_keys: - - videos - - ${settings.referencing_keys.sample_key} - target_keys: [] - text_sample_key: ${settings.referencing_keys.sample_key} - text_target_key: ${settings.referencing_keys.target_key} - -train_video_transform: - component_key: transform - variant_key: video_transform - config: - is_training: True - input_size: 256 - num_frames: ${model.config.vision_encoder_config.num_video_frames} - -text_transform: - component_key: transform - variant_key: text_transform - config: - tokenizer: - instance_key: tokenizer - pass_type: BY_REFERENCE - -train_video_builder: - component_key: dataset - variant_key: web_dataset_builder - config: - urls: "/nm-raid/video/Kinetics/kinetics-dataset/k400/dummy_wds/{000000..000010}.tar" - modality_key_mapping: - TEXT: ["json", "input_ids"] - VIDEO: ["mp4", "videos"] - modality_transforms: - VIDEO: - instance_key: train_video_transform - pass_type: BY_REFERENCE - TEXT: - instance_key: text_transform - pass_type: 
BY_REFERENCE - num_samples: 566_748 - -train_dataset: - component_key: dataset - variant_key: web_dataset - config: - builders: - - instance_key: train_video_builder - pass_type: BY_REFERENCE - shardshuffle: 100 - repeat: true - resample: true - shuffle_buffer: 10_000 - -val_dataset: - component_key: dataset - variant_key: web_dataset - config: - builders: - - instance_key: train_video_builder - pass_type: BY_REFERENCE - shardshuffle: 1000 - repeat: true - resample: true - shuffle_buffer: 10_000 - -train_dataloader: - component_key: data_loader - variant_key: web_loader - config: - num_workers: 8 - pin_memory: true - drop_last: true - dataloader_tag: "train" - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_size: ${settings.training.local_train_micro_batch_size} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -val_dataloader: - component_key: data_loader - variant_key: web_loader - config: - num_workers: 8 - pin_memory: true - drop_last: false - dataloader_tag: "val" - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - batch_size: ${settings.training.local_train_micro_batch_size} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: - - instance_key: val_dataloader - pass_type: BY_REFERENCE - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: fsdp - config: - checkpoint_path: ${settings.paths.checkpointing_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - -captioning_loss: - component_key: loss - variant_key: cross_entropy_loss - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${model.config.prediction_key} - tag: captioning_loss - weight: 2.0 - -contrastive_loss: - component_key: loss - variant_key: clip_loss - config: - prediction_key1: ${model.config.vision_cls_prediction_key} - prediction_key2: ${model.config.text_cls_prediction_key} - logit_scale_key: ${model.config.logit_scale_prediction_key} - tag: contrastive_loss - weight: 1.0 - -loss_fn: - - instance_key: captioning_loss - pass_type: BY_REFERENCE - - instance_key: contrastive_loss - pass_type: BY_REFERENCE - -wrapped_model: - component_key: model - variant_key: fsdp_wrapped - config: - model: - instance_key: model - pass_type: BY_REFERENCE - sync_module_states: true - mixed_precision_settings: FP_16 - sharding_strategy: HYBRID_SHARD - block_names: [TransformerBlock, PerceiverTransformerBlock] - -model: - component_key: model - variant_key: coca - config: - prediction_key: logits - vision_embd_prediction_key: vision_embeddings - text_embd_prediction_key: text_embeddings - vision_cls_prediction_key: vision_cls - text_cls_prediction_key: text_cls - individual_datasets: false - modality_keys: ${collate_fn.config.sample_keys} - logit_scale_prediction_key: logit_scale - vision_encoder_config: - sample_key: videos - prediction_key: vision_embeddings - img_size: 256 # 288 in the original coca - n_classes: Null # Disable vision transformer head - n_layer: 12 - attention_config: - attention_engine_type: default_attention - n_head: 12 - n_embd: 768 - dropout: 0.0 - patch_size: 16 # 18 in the original coca - patch_stride: 16 # 18 in the original coca - n_img_channels: 3 - 
add_cls_token: False - bias: True - num_video_frames: 16 - n_latents: 64 - text_decoder_config: - sample_key: ${settings.referencing_keys.sample_key} - prediction_key: ${model.config.prediction_key} - block_size: 1024 - vocab_size: 50304 # 64k in the original coca - n_layer_text: 12 - n_layer_multimodal_text: 12 - attention_config: - attention_engine_type: default_attention - n_head: 12 - ffn_hidden: 2048 - n_embd: 768 - dropout: 0.0 - bias: true - activation: fused_swiglu - epsilon: 1e-5 - n_pool_head: 16 - n_queries: 256 - bias_attn_pool: False - epsilon_attn_pool: 1e-5 - weight_init: - mean: 0.0 - std: 0.02 - -scheduler: - component_key: scheduler - variant_key: cosine_annealing_with_warmup_lr # COCA uses linear decay - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - num_warmup_steps: 2_000 - num_training_steps: 500_000 - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 8e-4 - betas: [0.9, 0.999] - eps: 1e-8 - weight_decay: 0.01 - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: fsdp - config: - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - max_norm: 1.0 - -batch_progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - local_rank: ${settings.cuda_env.global_rank} - global_num_seen_steps: ${settings.training.global_num_seen_steps} - train_dataloader: - instance_key: train_dataloader - pass_type: BY_REFERENCE - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: wandb - config: - local_rank: ${settings.cuda_env.global_rank} - project: modalities - mode: OFFLINE - experiment_id: ${settings.experiment_id} - directory: "." 
- config_file_path: ${settings.config_file_path} From 1664744f09ec452dfd52ec23c5cd5e32d46ea370 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Tue, 17 Sep 2024 16:46:06 +0200 Subject: [PATCH 129/161] fix: misc fixes after merging main into feat/coca --- src/modalities/__main__.py | 1 + src/modalities/dataloader/dataset.py | 1 + .../logging_broker/subscriber_impl/results_subscriber.py | 2 +- src/modalities/models/coca/coca_model.py | 4 ++-- src/modalities/running_env/fsdp/fsdp_auto_wrapper.py | 6 +----- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index c96e0dd19..ef93cdbe9 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -1,5 +1,6 @@ #!/usr/bin/env python +import logging import os import shutil from datetime import datetime diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 63a32ff72..aedb97bff 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -11,6 +11,7 @@ import jq import numpy as np import torch +import torchaudio import webdataset as wds from pydantic import BaseModel, Field from timm.data import create_transform diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index 9ca24b2f9..c7c5160ae 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -105,7 +105,7 @@ def consume_message(self, message: Message[EvaluationResultBatch]): wandb.log(data=throughput_metrics, step=eval_result.num_train_steps_done) - num_samples = eval_result.train_step_id + 1 + num_samples = eval_result.num_train_steps_done group_content = [f"Train [{num_samples}]:"] losses = [f"{k}: {v}" for k, v in losses.items()] diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index dc7200dca..8d858cc18 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -138,8 +138,6 @@ def __init__( n_queries: Optional[int], bias_attn_pool: bool, epsilon_attn_pool: float, - modality_encoder_config: VisionTransformerConfig | AudioTransformerConfig | AVConfig, - text_decoder_config: TextDecoderConfig, ) -> None: """ Initializes the CocaModel object. @@ -254,6 +252,7 @@ def __init__( self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) # apply special scaled init to the residual projections, per GPT-2 paper + ''' for pn, p in self.named_parameters(): if pn.endswith("c_proj.weight"): torch.nn.init.normal_( @@ -262,6 +261,7 @@ def __init__( std=weight_init.std / math.sqrt(2 * (text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text)), ) + ''' def _init_modality(self, encoder_class, encoder_config, n_queries): encoder = encoder_class(**dict(encoder_config)) diff --git a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py index 5c5bbcd19..22d463f3a 100644 --- a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py +++ b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py @@ -28,12 +28,8 @@ def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: List[str]) for cls_block_name in block_names: # TODO FullyShardedDataParallelPlugin from Accelerate uses string matching to find the correct # block class. 
In the long-term we should implmement this ourselves in a robuster fashion. - try: - block_type = FullyShardedDataParallelPlugin.get_module_class_from_name(model, cls_block_name) - except AttributeError: - from accelerate.utils.dataclasses import get_module_class_from_name + block_type = get_module_class_from_name(model, cls_block_name) - block_type = get_module_class_from_name(model, cls_block_name) if block_type is None: raise ValueError(f"Could not find block with name {cls_block_name} in model") fsdp_block_types.append(block_type) From f31be51793c97d1010382297e90a2a5e23818804 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Tue, 17 Sep 2024 17:10:01 +0200 Subject: [PATCH 130/161] chore: remove comment and unused file --- src/modalities/trainer.py | 1 - start.sh | 3 --- 2 files changed, 4 deletions(-) delete mode 100644 start.sh diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 6ae9c034e..e81a3aa6e 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -257,7 +257,6 @@ def train( synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time # TODO: insert reducer from outside so Trainer is independent of FSDP # add the loss and gradient norm for the LAST batch - # cumulated_losses[1] = batch_loss.item() reduced_losses = Reducer.reduce( tensor=cumulated_losses, diff --git a/start.sh b/start.sh deleted file mode 100644 index 236ff7e68..000000000 --- a/start.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 1 --nproc_per_node 1 --rdzv-endpoint=0.0.0.0:29502 src/modalities/__main__.py run --config_file_path config_files/training/config_example_coca_webdataset.yaml \ No newline at end of file From 91934fd1ab3ac8a3c4f3addd9174da5362895fe7 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Tue, 17 Sep 2024 17:10:54 +0200 Subject: [PATCH 131/161] refactor: out_put -> output --- tests/models/vision_transformer/test_vision_transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/vision_transformer/test_vision_transformer.py b/tests/models/vision_transformer/test_vision_transformer.py index 41a1afb2a..be72a6de1 100644 --- a/tests/models/vision_transformer/test_vision_transformer.py +++ b/tests/models/vision_transformer/test_vision_transformer.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize( - "input,sample_key,n_classes,num_video_frames,add_cls_token,out_put", + "input,sample_key,n_classes,num_video_frames,add_cls_token,output", [ (torch.randn(1, 3, 224, 224), "images", 1000, 1, True, (1, 1000)), (torch.randn(1, 3, 224, 224), "images", None, 1, True, (1, 197, 768)), @@ -21,7 +21,7 @@ (torch.randn(1, 16, 3, 224, 224), "videos", 1000, 16, False, (1, 1000)), ], ) -def test_vision_transformer(input, sample_key, n_classes, num_video_frames, add_cls_token, out_put): +def test_vision_transformer(input, sample_key, n_classes, num_video_frames, add_cls_token, output): # Create model config_file_path = _ROOT_DIR / Path("tests/models/vision_transformer/vision_transformer_config.yaml") config_dict = load_app_config_dict(config_file_path=config_file_path) @@ -48,7 +48,7 @@ def test_vision_transformer(input, sample_key, n_classes, num_video_frames, add_ # Test outputs assert "logits" in out - assert out["logits"].shape == out_put + assert out["logits"].shape == output @pytest.mark.parametrize( From cdbb9568a58b4237a60450e4f86f3510df7756fe Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Tue, 17 Sep 2024 17:15:59 +0200 Subject: [PATCH 132/161] fix: reset cumulated 
losses using function --- src/modalities/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index e81a3aa6e..1128a6b95 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -172,6 +172,7 @@ def train( None """ model.train() + cumulated_losses = self._reset_tracked_losses(len(loss_fun)) # throughput & MFU thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() @@ -180,8 +181,6 @@ def train( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - cumulated_losses = torch.zeros(len(loss_fun) + 1 + 1).to(device) - # batch loop batch: DatasetBatch # TODO: why do we need a barrier here? From 9bcc475566f42609a9291f8622410e0d52f4196f Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Fri, 20 Sep 2024 10:56:15 +0200 Subject: [PATCH 133/161] fix: scaled weight initialization for residual layers of coca --- .../training/config_coca_img_aud_vid_dataset.yaml | 6 +++--- src/modalities/models/coca/coca_model.py | 14 -------------- .../model_initialization/parameter_name_filters.py | 10 ++++++++-- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/config_files/training/config_coca_img_aud_vid_dataset.yaml b/config_files/training/config_coca_img_aud_vid_dataset.yaml index ffa74df21..76076923b 100644 --- a/config_files/training/config_coca_img_aud_vid_dataset.yaml +++ b/config_files/training/config_coca_img_aud_vid_dataset.yaml @@ -403,7 +403,7 @@ model: weight_init_type: scaled mean: 0.0 std: 0.02 - num_layers: ${model_raw.config.text_decoder_config.n_layer_text} + num_layers: 4 # text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text model_raw: component_key: model @@ -476,8 +476,8 @@ model_raw: prediction_key: ${model_raw.config.prediction_key} block_size: 512 vocab_size: 50304 # 64k in the original coca - n_layer_text: 2 - n_layer_multimodal_text: 2 + n_layer_text: 2 # update model_initializer num_layers + n_layer_multimodal_text: 2 # update model_initializer num_layers attention_config: attention_engine_type: default_attention n_head: 12 diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 8d858cc18..14bf5d0ab 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -1,5 +1,3 @@ -import math -from functools import partial from typing import Annotated, Dict, List, Optional, Tuple import numpy as np @@ -251,18 +249,6 @@ def __init__( # Logit scale for contrastive loss self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - # apply special scaled init to the residual projections, per GPT-2 paper - ''' - for pn, p in self.named_parameters(): - if pn.endswith("c_proj.weight"): - torch.nn.init.normal_( - p, - mean=weight_init.mean, - std=weight_init.std - / math.sqrt(2 * (text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text)), - ) - ''' - def _init_modality(self, encoder_class, encoder_config, n_queries): encoder = encoder_class(**dict(encoder_config)) queries = nn.Parameter(torch.randn(n_queries + 1, encoder_config.n_embd)) diff --git a/src/modalities/nn/model_initialization/parameter_name_filters.py b/src/modalities/nn/model_initialization/parameter_name_filters.py index 4f24c5aa3..eca76ea5e 100644 --- a/src/modalities/nn/model_initialization/parameter_name_filters.py +++ b/src/modalities/nn/model_initialization/parameter_name_filters.py @@ -68,9 +68,15 @@ class RegexFilter(BaseModel): 
SupportWeightInitModels.COCA: { # we reject all bias and weight parameters belonging to norms WeightInitTypes.PLAIN: RegexFilter( - weights=[r"^(?!.*norm)(?!.*ln_).*\.weight$"], biases=[r"^(?!.*norm)(?!.*ln_).*\.bias$"] + weights=[r"^(?!.*norm)(?!.*ln).*\.weight$"], biases=[r"^(?!.*norm)(?!.*ln).*\.bias$"] + ), + # scaled init for residual layers: + # https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf (pp 4) + WeightInitTypes.SCALED: RegexFilter( + weights=[ + r"transformer\.h\.\d+\.attn\.c_proj\.weight", + ] ), - WeightInitTypes.SCALED: RegexFilter(weights=[], biases=[]), WeightInitTypes.SCALED_EMBED: RegexFilter(weights=[], biases=[]), }, } From 25130cdb85ee93d37a6ca6265358669d3b668459 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Fri, 20 Sep 2024 14:35:18 +0200 Subject: [PATCH 134/161] refactor: replace CosineAnnealingWithWarmupLR with OneCycleLR --- .../training/config_coca_img_aud_vid_dataset.yaml | 11 ++++++++--- src/modalities/config/config.py | 7 ------- src/modalities/registry/components.py | 8 -------- 3 files changed, 8 insertions(+), 18 deletions(-) diff --git a/config_files/training/config_coca_img_aud_vid_dataset.yaml b/config_files/training/config_coca_img_aud_vid_dataset.yaml index 76076923b..4674a39ab 100644 --- a/config_files/training/config_coca_img_aud_vid_dataset.yaml +++ b/config_files/training/config_coca_img_aud_vid_dataset.yaml @@ -494,13 +494,18 @@ model_raw: scheduler: component_key: scheduler - variant_key: cosine_annealing_with_warmup_lr # COCA uses linear decay + variant_key: onecycle_lr config: optimizer: instance_key: optimizer pass_type: BY_REFERENCE - num_warmup_steps: 2_000 - num_training_steps: 500_000 + max_lr: 8e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.02 + anneal_strategy: linear # COCA uses linear decay + last_epoch: ${settings.training_progress.last_step} optimizer: component_key: optimizer diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 560d58e1e..df1197066 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -188,13 +188,6 @@ class CosineAnnealingLRSchedulerConfig(BaseModel): verbose: bool = False -class CosineAnnealingWithWarmupLRSchedulerConfig(BaseModel): - optimizer: PydanticOptimizerIFType - num_warmup_steps: Annotated[int, Field(strict=True, gt=0)] - num_training_steps: Annotated[int, Field(strict=True, gt=0)] - last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1 - - class CheckpointedOptimizerConfig(BaseModel): checkpoint_loading: PydanticCheckpointLoadingIFType checkpoint_path: Path diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 856de5d9b..5c598e627 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn -import transformers from pydantic import BaseModel from torch.utils.data import BatchSampler, DistributedSampler @@ -25,7 +24,6 @@ CheckpointSavingConfig, ConstantLRSchedulerConfig, CosineAnnealingLRSchedulerConfig, - CosineAnnealingWithWarmupLRSchedulerConfig, DistributedSamplerConfig, DummyLRSchedulerConfig, DummyProgressSubscriberConfig, @@ -182,12 +180,6 @@ class ComponentEntity: ComponentEntity( "scheduler", "cosine_annealing_lr", torch.optim.lr_scheduler.CosineAnnealingLR, CosineAnnealingLRSchedulerConfig ), - ComponentEntity( - "scheduler", - "cosine_annealing_with_warmup_lr", - 
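# Illustrative sketch (not part of the patch series): how a scaled-init regex filter like the
# one above can be applied, following the GPT-2 recipe of std / sqrt(2 * num_layers) for the
# residual projections. The helper name and the module layout are hypothetical examples.
import math
import re

import torch
import torch.nn as nn


def init_scaled_residual_projections(model: nn.Module, mean: float, std: float, num_layers: int) -> None:
    # num_layers would correspond to n_layer_text + n_layer_multimodal_text in the YAML above
    pattern = re.compile(r"transformer\.h\.\d+\.attn\.c_proj\.weight")
    scaled_std = std / math.sqrt(2 * num_layers)
    for name, param in model.named_parameters():
        if pattern.fullmatch(name):
            torch.nn.init.normal_(param, mean=mean, std=scaled_std)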
transformers.get_linear_schedule_with_warmup, - CosineAnnealingWithWarmupLRSchedulerConfig, - ), # tokenizers ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig), ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig), From a97f448fdfbce67ee39ba93b72783c37ec609773 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Fri, 20 Sep 2024 16:46:23 +0200 Subject: [PATCH 135/161] refactor: set weight decay groups for coca --- .../training/config_coca_img_aud_vid_dataset.yaml | 2 +- src/modalities/__main__.py | 1 - src/modalities/config/component_factory.py | 1 - .../audio_transformer/audio_transformer_model.py | 6 +++--- src/modalities/models/coca/coca_model.py | 12 +++++++++++- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/config_files/training/config_coca_img_aud_vid_dataset.yaml b/config_files/training/config_coca_img_aud_vid_dataset.yaml index 4674a39ab..d372e5051 100644 --- a/config_files/training/config_coca_img_aud_vid_dataset.yaml +++ b/config_files/training/config_coca_img_aud_vid_dataset.yaml @@ -515,7 +515,7 @@ optimizer: betas: [0.9, 0.999] eps: 1e-8 weight_decay: 1e-1 - weight_decay_groups_excluded: [] + weight_decay_groups_excluded: [embedding, norm, parameter] wrapped_model: instance_key: wrapped_model pass_type: BY_REFERENCE diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index ef93cdbe9..fd139fefd 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -9,7 +9,6 @@ import click import click_pathlib -import torch.distributed as dist from pydantic import BaseModel, FilePath from modalities.batch import EvaluationResultBatch diff --git a/src/modalities/config/component_factory.py b/src/modalities/config/component_factory.py index 73e92417a..94ed71e69 100644 --- a/src/modalities/config/component_factory.py +++ b/src/modalities/config/component_factory.py @@ -1,4 +1,3 @@ -import os from typing import Any, Dict, List, Type, TypeVar, Union from pydantic import BaseModel diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index 2244e4f55..61e99bdad 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -92,8 +92,8 @@ def __init__( act_fn=nn.SiLU, dropout=ffmodule_dropout, ) - self.mhsa_ln = nn.LayerNorm(n_embd) - self.mhsa = MultiHeadAttention( + self.ln_mhsa = nn.LayerNorm(n_embd) + self.attn = MultiHeadAttention( attention_config=attention_config, attention_type=AttentionType.NON_CAUSAL_SELF_ATTENTION, n_embd=n_embd, @@ -125,7 +125,7 @@ def forward( ) -> torch.Tensor: x = self.ln1(x) # x.shape: B, T, D x = x + 0.5 * self.entry_ffmodule(x) - x = x + self.mhsa(self.mhsa_ln(x), mask=mask) + x = x + self.attn(self.ln_mhsa(x), mask=mask) x = x + self.convmodule(x) x = self.ln2(x) x = x + 0.5 * self.exit_ffmodule(x) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 14bf5d0ab..26ce17922 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -97,6 +97,7 @@ class CoCaConfig(BaseModel): n_queries: Optional[Annotated[int, Field(ge=1)]] bias_attn_pool: bool epsilon_attn_pool: Annotated[float, Field(ge=0.0)] + seed: Optional[int] = None class CoCa(NNModel): @@ -136,6 +137,7 @@ def __init__( n_queries: Optional[int], bias_attn_pool: bool, 
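# Illustrative sketch (not part of the patch series): the scheduler settings in the YAML above
# map directly onto torch.optim.lr_scheduler.OneCycleLR. The optimizer, parameters and step
# count below are placeholders for illustration only.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=8e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-1)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=8e-4,
    total_steps=10_000,       # placeholder for settings.training_target.num_target_steps
    pct_start=0.02,           # ~2% of the schedule used for warmup
    anneal_strategy="linear",  # linear decay, as noted in the config comment
    div_factor=10,            # initial_lr = max_lr / div_factor
    final_div_factor=1,       # min_lr = initial_lr / final_div_factor
)
for _ in range(3):
    optimizer.step()
    scheduler.step()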
epsilon_attn_pool: float, + seed: int = None, ) -> None: """ Initializes the CocaModel object. @@ -153,11 +155,19 @@ def __init__( epsilon_attn_pool (float): The epsilon value for attention pooling. vision_encoder_config (VisionTransformerConfig): The configuration for the vision encoder. text_decoder_config (TextDecoderConfig): The configuration for the text decoder. + seed (int, optional): The random seed. Defaults to None. Returns: None """ - super().__init__() + weight_decay_groups = { + "linear": ["attention", "\.attn", "\.cross_attn", "\.post_subsampler", "_ffmodule", "mlp"], + "conv": ["embedding_fn\.conv", "project", "\.subsampler", "pointwise_conv", "depthwise_conv"], + "embedding": ["wte", "wpe", "positional_embedding", "time_embd"], + "norm": ["norm", "\.ln_", "\.ln", "\.bn", "exit_ln"], + "parameter": ["_queries", "logit_scale", "\.latents", "cls_token"], + } + super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.prediction_key = prediction_key self.text_embd_prediction_key = text_embd_prediction_key From 3b89853d017caa1b1edc2f65d1405e2197483587 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 23 Sep 2024 10:38:04 +0200 Subject: [PATCH 136/161] fix: update path for coca tokenizer --- config_files/training/config_coca_img_aud_vid_dataset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config_files/training/config_coca_img_aud_vid_dataset.yaml b/config_files/training/config_coca_img_aud_vid_dataset.yaml index d372e5051..59b6c0dd9 100644 --- a/config_files/training/config_coca_img_aud_vid_dataset.yaml +++ b/config_files/training/config_coca_img_aud_vid_dataset.yaml @@ -56,7 +56,7 @@ tokenizer: component_key: tokenizer variant_key: pretrained_hf_tokenizer config: - pretrained_model_name_or_path: tokenizer/ + pretrained_model_name_or_path: openai/clip-vit-base-patch32 padding: true max_length: ${settings.step_profile.sequence_length} From 8f7a114285ea639dd3d7d292ca7a13ad65339bca Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Mon, 23 Sep 2024 14:00:45 +0000 Subject: [PATCH 137/161] chore: use built-in types --- src/modalities/__main__.py | 10 ++-- src/modalities/batch.py | 18 +++--- .../checkpointing/checkpoint_saving.py | 5 +- .../checkpoint_saving_instruction.py | 5 +- .../checkpoint_saving_strategies.py | 18 +++--- .../fsdp/fsdp_checkpoint_loading.py | 5 +- .../fsdp/fsdp_checkpoint_saving.py | 3 +- src/modalities/config/component_factory.py | 26 ++++---- src/modalities/config/config.py | 32 +++++----- src/modalities/config/instantiation_models.py | 18 +++--- src/modalities/config/utils.py | 4 +- .../dataloader/create_packed_data.py | 16 ++--- src/modalities/dataloader/dataloader.py | 7 ++- src/modalities/dataloader/dataset.py | 60 +++++++++---------- src/modalities/dataloader/dataset_factory.py | 8 +-- .../dataloader/large_file_lines_reader.py | 8 +-- src/modalities/evaluator.py | 16 +++-- src/modalities/gym.py | 8 +-- .../logging_broker/message_broker.py | 3 +- src/modalities/logging_broker/subscriber.py | 4 +- .../subscriber_impl/progress_subscriber.py | 14 ++--- .../subscriber_impl/results_subscriber.py | 8 +-- .../subscriber_impl/subscriber_factory.py | 6 +- .../audio_transformer_model.py | 6 +- src/modalities/models/coca/coca_model.py | 22 +++---- src/modalities/models/coca/collator.py | 19 +++--- .../models/coca/multi_modal_decoder.py | 5 +- src/modalities/models/coca/text_decoder.py | 4 +- src/modalities/models/gpt2/collator.py | 9 ++- src/modalities/models/gpt2/gpt2_model.py | 40 ++++++------- 
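# Illustrative sketch (not part of the patch series): one way regex-based weight-decay groups,
# like the dict defined above, can be turned into optimizer parameter groups while honoring
# weight_decay_groups_excluded from the YAML. The helper name is hypothetical; only the
# grouping idea is taken from the diff.
import re

import torch.nn as nn


def build_param_groups(
    model: nn.Module, weight_decay: float, groups: dict[str, list[str]], excluded: list[str]
) -> list[dict]:
    decayed, not_decayed = [], []
    for name, param in model.named_parameters():
        group = next((g for g, patterns in groups.items() if any(re.search(p, name) for p in patterns)), None)
        (not_decayed if group in excluded else decayed).append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": not_decayed, "weight_decay": 0.0},
    ]

# e.g. excluded = ["embedding", "norm", "parameter"], matching the YAML change above.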
.../models/gpt2/pretrained_gpt_model.py | 4 +- .../models/huggingface/huggingface_model.py | 12 ++-- .../models/huggingface_adapters/hf_adapter.py | 22 +++---- src/modalities/models/model.py | 12 ++-- src/modalities/models/model_factory.py | 9 ++- src/modalities/models/utils.py | 5 +- .../vision_transformer_model.py | 16 ++--- src/modalities/nn/attention.py | 4 +- .../composed_initialization.py | 8 +-- .../initialization_routines.py | 18 +++--- .../parameter_name_filters.py | 6 +- src/modalities/optimizers/lr_schedulers.py | 5 +- .../optimizers/optimizer_factory.py | 25 ++++---- src/modalities/registry/registry.py | 10 ++-- .../running_env/fsdp/fsdp_auto_wrapper.py | 6 +- .../tokenization/tokenizer_wrapper.py | 34 +++++------ src/modalities/trainer.py | 8 +-- .../training/activation_checkpointing.py | 5 +- src/modalities/util.py | 5 +- src/modalities/utils/mfu.py | 6 +- .../pytorch/test_torch_checkpoint_loading.py | 4 +- .../test_checkpoint_strategies.py | 4 +- .../test_fsdp_to_disc_checkpointing.py | 17 +++--- tests/config/components.py | 3 +- tests/config/configs.py | 4 +- tests/conftest.py | 3 +- tests/dataloader/dummy_sequential_dataset.py | 4 +- tests/dataloader/test_dataloader.py | 6 +- tests/end2end_tests/test_fsdp_warmstart.py | 12 ++-- tests/test_initialization.py | 8 +-- tests/test_optimizer_factory.py | 3 +- tests/utils/test_mfu.py | 4 +- tutorials/library_usage/README.md | 2 +- tutorials/library_usage/main.py | 3 +- 64 files changed, 338 insertions(+), 366 deletions(-) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index fd139fefd..3ad8469bf 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -5,7 +5,7 @@ import shutil from datetime import datetime from pathlib import Path -from typing import List, Tuple, Type +from typing import Type import click import click_pathlib @@ -198,7 +198,7 @@ def entry_point_pack_encoded_data(config_path: FilePath): @data.command(name="merge_packed_data") @click.argument("src_paths", type=click.types.Path(exists=True, path_type=Path), nargs=-1, required=True) @click.argument("target_path", type=click.types.Path(file_okay=False, dir_okay=False, path_type=Path)) -def entry_point_merge_packed_data(src_paths: List[Path], target_path: Path): +def entry_point_merge_packed_data(src_paths: list[Path], target_path: Path): """Utility for merging different pbin-files into one. This is especially useful, if different datasets were at different points in time or if one encoding takes so long, that the overall process was done in chunks. @@ -207,7 +207,7 @@ def entry_point_merge_packed_data(src_paths: List[Path], target_path: Path): Specify an arbitrary amount of pbin-files and/or directory containing such as input. Args: - src_paths (List[Path]): List of paths to the pbin-files or directories containing such. + src_paths (list[Path]): List of paths to the pbin-files or directories containing such. target_path (Path): The path to the merged pbin-file, that will be created. """ input_files = [] @@ -364,7 +364,7 @@ def get_logging_publishers( results_subscriber: MessageSubscriberIF[EvaluationResultBatch], global_rank: int, local_rank: int, - ) -> Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: + ) -> tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: """Returns the logging publishers for the training. These publishers are used to pass the evaluation results and the progress updates to the message broker. 
@@ -377,7 +377,7 @@ def get_logging_publishers( local_rank (int): The local rank of the current process on the current node Returns: - Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation + tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation result publisher and the progress publisher """ message_broker = MessageBroker() diff --git a/src/modalities/batch.py b/src/modalities/batch.py index 746cbdc2b..cd32ed807 100644 --- a/src/modalities/batch.py +++ b/src/modalities/batch.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Dict, Optional +from typing import Optional import torch @@ -32,8 +32,8 @@ class Batch(ABC): class DatasetBatch(Batch, TorchDeviceMixin): """A batch of samples and its targets. Used to batch train a model.""" - samples: Dict[str, torch.Tensor] - targets: Dict[str, torch.Tensor] + samples: dict[str, torch.Tensor] + targets: dict[str, torch.Tensor] batch_dim: int = 0 def to(self, device: torch.device): @@ -58,8 +58,8 @@ def __len__(self) -> int: class InferenceResultBatch(Batch, TorchDeviceMixin): """Stores targets and predictions of an entire batch.""" - targets: Dict[str, torch.Tensor] - predictions: Dict[str, torch.Tensor] + targets: dict[str, torch.Tensor] + predictions: dict[str, torch.Tensor] batch_dim: int = 0 def to_cpu(self): @@ -106,12 +106,12 @@ class EvaluationResultBatch(Batch): dataloader_tag: str num_train_steps_done: int - losses: Dict[str, ResultItem] = field(default_factory=dict) - metrics: Dict[str, ResultItem] = field(default_factory=dict) - throughput_metrics: Dict[str, ResultItem] = field(default_factory=dict) + losses: dict[str, ResultItem] = field(default_factory=dict) + metrics: dict[str, ResultItem] = field(default_factory=dict) + throughput_metrics: dict[str, ResultItem] = field(default_factory=dict) def __str__(self) -> str: - def _round_result_item_dict(result_item_dict: Dict[str, ResultItem]) -> Dict[str, ResultItem]: + def _round_result_item_dict(result_item_dict: dict[str, ResultItem]) -> dict[str, ResultItem]: rounded_result_item_dict = {} for k, item in result_item_dict.items(): if item.decimal_places is not None: diff --git a/src/modalities/checkpointing/checkpoint_saving.py b/src/modalities/checkpointing/checkpoint_saving.py index 986e71f74..2a47648e5 100644 --- a/src/modalities/checkpointing/checkpoint_saving.py +++ b/src/modalities/checkpointing/checkpoint_saving.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Dict import torch.nn as nn from torch.optim import Optimizer @@ -43,7 +42,7 @@ def __init__( def save_checkpoint( self, training_progress: TrainingProgress, - evaluation_result: Dict[str, EvaluationResultBatch], + evaluation_result: dict[str, EvaluationResultBatch], model: nn.Module, optimizer: Optimizer, early_stoppping_criterion_fulfilled: bool = False, @@ -53,7 +52,7 @@ def save_checkpoint( Args: training_progress (TrainingProgress): The training progress. - evaluation_result (Dict[str, EvaluationResultBatch]): The evaluation result. + evaluation_result (dict[str, EvaluationResultBatch]): The evaluation result. model (nn.Module): The model to be saved. optimizer (Optimizer): The optimizer to be saved. 
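# Illustrative sketch (not part of the patch series): the pattern applied throughout this
# commit, replacing typing.Dict/List/Tuple/Union with built-in generics (PEP 585) and the
# `X | Y` union syntax (PEP 604). The function below is a made-up example.
def summarize(losses: dict[str, float], tags: list[str] | None = None) -> tuple[float, int]:
    selected = {k: v for k, v in losses.items() if tags is None or k in tags}
    return sum(selected.values()), len(selected)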
early_stoppping_criterion_fulfilled (bool, optional): diff --git a/src/modalities/checkpointing/checkpoint_saving_instruction.py b/src/modalities/checkpointing/checkpoint_saving_instruction.py index 1bd424704..c9a80b709 100644 --- a/src/modalities/checkpointing/checkpoint_saving_instruction.py +++ b/src/modalities/checkpointing/checkpoint_saving_instruction.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from typing import List from modalities.training.training_progress import TrainingProgress @@ -11,8 +10,8 @@ class CheckpointingInstruction: Attributes: save_current (bool): Indicates whether to save the current checkpoint. - checkpoints_to_delete (List[TrainingProgress]): List of checkpoint IDs to delete. + checkpoints_to_delete (list[TrainingProgress]): List of checkpoint IDs to delete. """ save_current: bool = False - checkpoints_to_delete: List[TrainingProgress] = field(default_factory=list) + checkpoints_to_delete: list[TrainingProgress] = field(default_factory=list) diff --git a/src/modalities/checkpointing/checkpoint_saving_strategies.py b/src/modalities/checkpointing/checkpoint_saving_strategies.py index e4902600c..50c72bee9 100644 --- a/src/modalities/checkpointing/checkpoint_saving_strategies.py +++ b/src/modalities/checkpointing/checkpoint_saving_strategies.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Optional from modalities.batch import EvaluationResultBatch from modalities.checkpointing.checkpoint_saving_instruction import CheckpointingInstruction @@ -13,7 +13,7 @@ class CheckpointSavingStrategyIF(ABC): def get_checkpoint_instruction( self, training_progress: TrainingProgress, - evaluation_result: Optional[Dict[str, EvaluationResultBatch]] = None, + evaluation_result: Optional[dict[str, EvaluationResultBatch]] = None, early_stoppping_criterion_fulfilled: bool = False, ) -> CheckpointingInstruction: """ @@ -21,7 +21,7 @@ def get_checkpoint_instruction( Parameters: training_progress (TrainingProgress): The training progress. - evaluation_result (Dict[str, EvaluationResultBatch] | None, optional): + evaluation_result (dict[str, EvaluationResultBatch] | None, optional): The evaluation result. Defaults to None. early_stoppping_criterion_fulfilled (bool, optional): Whether the early stopping criterion is fulfilled. Defaults to False. @@ -45,13 +45,13 @@ def __init__(self, k: int = -1): Set to a positive integer to save the specified number of checkpointsStrategy for saving the k most recent checkpoints only. """ - self.saved_step_checkpoints: List[TrainingProgress] = [] + self.saved_step_checkpoints: list[TrainingProgress] = [] self.k = k def get_checkpoint_instruction( self, training_progress: TrainingProgress, - evaluation_result: Dict[str, EvaluationResultBatch] | None = None, + evaluation_result: dict[str, EvaluationResultBatch] | None = None, early_stoppping_criterion_fulfilled: bool = False, ) -> CheckpointingInstruction: """ @@ -59,7 +59,7 @@ def get_checkpoint_instruction( Args: training_progress (TrainingProgress): The training progress. - evaluation_result (Dict[str, EvaluationResultBatch] | None, optional): + evaluation_result (dict[str, EvaluationResultBatch] | None, optional): The evaluation result. Defaults to None. early_stoppping_criterion_fulfilled (bool, optional): Whether the early stopping criterion is fulfilled. Defaults to False. @@ -67,7 +67,7 @@ def get_checkpoint_instruction( Returns: CheckpointingInstruction: The generated checkpointing instruction. 
""" - checkpoints_to_delete: List[TrainingProgress] = [] + checkpoints_to_delete: list[TrainingProgress] = [] save_current = True if self.k > 0: @@ -100,7 +100,7 @@ def __init__(self, k: int): def get_checkpoint_instruction( self, training_progress: TrainingProgress, - evaluation_result: Dict[str, EvaluationResultBatch] | None = None, + evaluation_result: dict[str, EvaluationResultBatch] | None = None, early_stoppping_criterion_fulfilled: bool = False, ) -> CheckpointingInstruction: """ @@ -108,7 +108,7 @@ def get_checkpoint_instruction( Args: training_progress (TrainingProgress): The training progress. - evaluation_result (Dict[str, EvaluationResultBatch] | None, optional): + evaluation_result (dict[str, EvaluationResultBatch] | None, optional): The evaluation result. Defaults to None. early_stoppping_criterion_fulfilled (bool, optional): Whether the early stopping criterion is fulfilled. Defaults to False. diff --git a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py index dc3b9de0c..556ba5b88 100644 --- a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py +++ b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import torch import torch.nn as nn @@ -17,7 +16,7 @@ class FSDPCheckpointLoading(CheckpointLoadingIF): def __init__( self, global_rank: int, - block_names: List[str], + block_names: list[str], mixed_precision_settings: MixedPrecisionSettings, sharding_strategy: ShardingStrategy, ): @@ -26,7 +25,7 @@ def __init__( Args: global_rank (int): The global rank of the process. - block_names (List[str]): The names of the blocks. + block_names (list[str]): The names of the blocks. mixed_precision_settings (MixedPrecisionSettings): The settings for mixed precision. sharding_strategy (ShardingStrategy): The sharding strategy. diff --git a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py index 684847f08..0c5757a9b 100644 --- a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py +++ b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py @@ -1,6 +1,5 @@ from enum import Enum from pathlib import Path -from typing import List import torch import torch.distributed as dist @@ -124,7 +123,7 @@ def _save_checkpoint(self, model: FSDP, optimizer: Optimizer, training_progress: # leading to wrong throughput measurements. dist.barrier() - def _get_paths_to_delete(self, training_progress: TrainingProgress) -> List[Path]: + def _get_paths_to_delete(self, training_progress: TrainingProgress) -> list[Path]: return [ self._get_checkpointing_path( experiment_id=self.experiment_id, diff --git a/src/modalities/config/component_factory.py b/src/modalities/config/component_factory.py index 94ed71e69..c8c79604d 100644 --- a/src/modalities/config/component_factory.py +++ b/src/modalities/config/component_factory.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Type, TypeVar, Union +from typing import Any, Type, TypeVar from pydantic import BaseModel @@ -19,12 +19,12 @@ def __init__(self, registry: Registry) -> None: """ self.registry = registry - def build_components(self, config_dict: Dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild: + def build_components(self, config_dict: dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild: """Builds the components from a config dictionary. 
All components specified in `components_model_type` are built from the config dictionary in a recursive manner. Args: - config_dict (Dict): Dictionary with the configuration of the components. + config_dict (dict): dictionary with the configuration of the components. components_model_type (Type[BaseModelChild]): Base model type defining the components to be build. Returns: @@ -35,7 +35,7 @@ def build_components(self, config_dict: Dict, components_model_type: Type[BaseMo components = components_model_type(**component_dict) return components - def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[str, Any]: + def _build_config(self, config_dict: dict, component_names: list[str]) -> dict[str, Any]: component_dict_filtered = {name: config_dict[name] for name in component_names} components, _ = self._build_component( current_component_config=component_dict_filtered, @@ -47,10 +47,10 @@ def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[s def _build_component( self, - current_component_config: Union[Dict, List, Any], - component_config: Union[Dict, List, Any], - top_level_components: Dict[str, Any], - traversal_path: List, + current_component_config: dict | list | Any, + component_config: dict | list | Any, + top_level_components: dict[str, Any], + traversal_path: list, ) -> Any: # build sub components first via recursion if isinstance(current_component_config, dict): @@ -130,16 +130,16 @@ def _build_component( return current_component_config, top_level_components @staticmethod - def _is_component_config(config_dict: Dict) -> bool: + def _is_component_config(config_dict: dict) -> bool: # TODO instead of field checks, we should introduce an enum for the config type. return "component_key" in config_dict.keys() @staticmethod - def _is_reference_config(config_dict: Dict) -> bool: + def _is_reference_config(config_dict: dict) -> bool: # TODO instead of field checks, we should introduce an enum for the config type. 
return {"instance_key", "pass_type"} == config_dict.keys() - def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: Dict) -> BaseModel: + def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel: component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key) self._assert_valid_config_keys( component_key=component_key, @@ -151,7 +151,7 @@ def _instantiate_component_config(self, component_key: str, variant_key: str, co return comp_config def _assert_valid_config_keys( - self, component_key: str, variant_key: str, config_dict: Dict, component_config_type: Type[BaseModelChild] + self, component_key: str, variant_key: str, config_dict: dict, component_config_type: Type[BaseModelChild] ) -> None: required_keys = [] optional_keys = [] @@ -178,7 +178,7 @@ def _instantiate_component(self, component_key: str, variant_key: str, component return component @staticmethod - def _base_model_to_dict(base_model: BaseModel) -> Dict: + def _base_model_to_dict(base_model: BaseModel) -> dict: # converts top level structure of base_model into dictionary while maintaining substructure output = {} for name, _ in base_model.model_fields.items(): diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index df1197066..7d36c97f9 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -1,7 +1,7 @@ import os from functools import partial from pathlib import Path -from typing import Annotated, Dict, List, Literal, Optional, Tuple +from typing import Annotated, Literal, Optional import torch from omegaconf import OmegaConf @@ -82,7 +82,7 @@ def parse_device(cls, device) -> PydanticPytorchDeviceType: class FSDPCheckpointLoadingConfig(BaseModel): global_rank: Annotated[int, Field(strict=True, ge=0)] - block_names: List[str] + block_names: list[str] mixed_precision_settings: MixedPrecisionSettings sharding_strategy: ShardingStrategy @@ -117,19 +117,19 @@ class CheckpointSavingConfig(BaseModel): class AdamOptimizerConfig(BaseModel): lr: float wrapped_model: PydanticPytorchModuleType - betas: Tuple[float, float] + betas: tuple[float, float] eps: float weight_decay: float - weight_decay_groups_excluded: List[str] + weight_decay_groups_excluded: list[str] class AdamWOptimizerConfig(BaseModel): lr: float wrapped_model: PydanticPytorchModuleType - betas: Tuple[float, float] + betas: tuple[float, float] eps: float weight_decay: float - weight_decay_groups_excluded: List[str] + weight_decay_groups_excluded: list[str] class DummyLRSchedulerConfig(BaseModel): @@ -146,17 +146,17 @@ class StepLRSchedulerConfig(BaseModel): class OneCycleLRSchedulerConfig(BaseModel): optimizer: PydanticOptimizerIFType - max_lr: Annotated[float, Field(strict=True, gt=0.0)] | List[Annotated[float, Field(strict=True, gt=0.0)]] + max_lr: Annotated[float, Field(strict=True, gt=0.0)] | list[Annotated[float, Field(strict=True, gt=0.0)]] total_steps: Optional[Annotated[int, Field(strict=True, gt=0)]] = None epochs: Optional[Annotated[int, Field(strict=True, gt=0)]] = None steps_per_epoch: Optional[Annotated[int, Field(strict=True, gt=0)]] = None pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)] anneal_strategy: str cycle_momentum: bool = True - base_momentum: Annotated[float, Field(strict=True, gt=0)] | List[ + base_momentum: Annotated[float, Field(strict=True, gt=0)] | list[ Annotated[float, Field(strict=True, gt=0.0)] ] = 0.85 - max_momentum: Annotated[float, 
Field(strict=True, gt=0.0)] | List[ + max_momentum: Annotated[float, Field(strict=True, gt=0.0)] | list[ Annotated[float, Field(strict=True, gt=0.0)] ] = 0.95 div_factor: Annotated[float, Field(strict=True, gt=0.0)] @@ -206,7 +206,7 @@ class FSDPWrappedModelConfig(BaseModel): sync_module_states: bool mixed_precision_settings: MixedPrecisionSettings sharding_strategy: ShardingStrategy - block_names: List[str] + block_names: list[str] @field_validator("mixed_precision_settings", mode="before") def parse_mixed_precision_setting_by_name(cls, name): @@ -236,7 +236,7 @@ class WeightInitializedModelConfig(BaseModel): class ActivationCheckpointedModelConfig(BaseModel): model: PydanticFSDPModuleType - activation_checkpointing_modules: Optional[List[str]] = Field(default_factory=list) + activation_checkpointing_modules: Optional[list[str]] = Field(default_factory=list) class PreTrainedHFTokenizerConfig(BaseModel): @@ -244,7 +244,7 @@ class PreTrainedHFTokenizerConfig(BaseModel): max_length: Optional[Annotated[int, Field(strict=True, ge=0)]] = None truncation: bool = False padding: bool | str = False - special_tokens: Optional[Dict[str, str]] = None + special_tokens: Optional[dict[str, str]] = None class PreTrainedSPTokenizerConfig(BaseModel): @@ -329,14 +329,14 @@ class DummyProgressSubscriberConfig(BaseModel): class SimpleProgressSubscriberConfig(BaseModel): train_dataloader: PydanticLLMDataLoaderIFType - eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) world_size: int global_num_seen_samples: int local_rank: int class RichProgressSubscriberConfig(BaseModel): - eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) train_dataloader_tag: str num_seen_steps: Annotated[int, Field(strict=True, ge=0)] num_target_steps: Annotated[int, Field(strict=True, gt=0)] @@ -361,7 +361,7 @@ class RichResultSubscriberConfig(BaseModel): global_rank: int -def load_app_config_dict(config_file_path: Path) -> Dict: +def load_app_config_dict(config_file_path: Path) -> dict: """Load the application configuration from the given YAML file. The function defines custom resolvers for the OmegaConf library to resolve environment variables and Modalities-specific variables. @@ -370,7 +370,7 @@ def load_app_config_dict(config_file_path: Path) -> Dict: config_file_path (Path): YAML config file. Returns: - Dict: Dictionary representation of the config file. + dict: Dictionary representation of the config file. 
""" def cuda_env_resolver_fun(var_name: str) -> int: diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 67b255c30..91a30c259 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Annotated, Any, Dict, List, Optional, Union +from typing import Annotated, Any, Optional from pydantic import BaseModel, ConfigDict, Field, FilePath, field_validator, model_validator, root_validator @@ -69,7 +69,7 @@ class Config: extra = "allow" @root_validator(pre=True) - def _validate_all_paths(cls, values: Dict[str, Any]) -> Dict[str, Any]: + def _validate_all_paths(cls, values: dict[str, Any]) -> dict[str, Any]: for field_name, value in values.items(): if isinstance(value, str): # If a value is a string, convert it to Path values[field_name] = Path(value) @@ -83,7 +83,7 @@ class WarmstartCheckpointPaths(BaseModel): experiment_id: str config_file_path: FilePath - referencing_keys: Dict[str, str] + referencing_keys: dict[str, str] cuda_env: CudaEnvSettings paths: Paths intervals: Intervals @@ -168,10 +168,10 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel wrapped_model: PydanticPytorchModuleType optimizer: PydanticOptimizerIFType scheduler: PydanticLRSchedulerIFType - loss_fn: Union[PydanticLossIFType, List[PydanticLossIFType]] + loss_fn: PydanticLossIFType | list[PydanticLossIFType] train_dataset: PydanticDatasetIFType train_dataloader: PydanticLLMDataLoaderIFType - eval_dataloaders: List[PydanticLLMDataLoaderIFType] + eval_dataloaders: list[PydanticLLMDataLoaderIFType] progress_subscriber: PydanticMessageSubscriberIFType evaluation_subscriber: PydanticMessageSubscriberIFType checkpoint_saving: PydanticCheckpointSavingIFType @@ -212,7 +212,7 @@ class TextGenerationSettings(BaseModel): model_path: FilePath sequence_length: int device: PydanticPytorchDeviceType - referencing_keys: Dict[str, str] + referencing_keys: dict[str, str] # avoid warning about protected namespace 'model_', see # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces @@ -246,10 +246,10 @@ def __init__( self.training_progress = training_progress def get_report(self) -> str: - def _get_formatted_dict_str(d: Dict[str, Any]) -> str: + def _get_formatted_dict_str(d: dict[str, Any]) -> str: return "\n\t".join([f"{k}: {v}" for k, v in d.items()]) - def _get_formatted_list_str(lst: List[str]) -> str: + def _get_formatted_list_str(lst: list[str]) -> str: return "\n\t".join(lst) training_target_str = _get_formatted_dict_str(dict(self.training_target)) @@ -273,7 +273,7 @@ def _get_formatted_list_str(lst: List[str]) -> str: ) return report - def _get_issue_warnings(self) -> List[str]: + def _get_issue_warnings(self) -> list[str]: issue_warnings = [] num_tokens = ( self.step_profile.local_train_micro_batch_size diff --git a/src/modalities/config/utils.py b/src/modalities/config/utils.py index fe47cfaf3..a1d414fe2 100644 --- a/src/modalities/config/utils.py +++ b/src/modalities/config/utils.py @@ -1,10 +1,10 @@ -from typing import Any, Dict +from typing import Any import torch from pydantic import BaseModel -def convert_base_model_config_to_dict(config: BaseModel) -> Dict[Any, Any]: +def convert_base_model_config_to_dict(config: BaseModel) -> dict[Any, Any]: """ "Converts non-recursively a Pydantic BaseModel to a dictionary.""" return {key: getattr(config, key) for key in 
config.model_dump().keys()} diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py index 33775fcd7..c695cc163 100644 --- a/src/modalities/dataloader/create_packed_data.py +++ b/src/modalities/dataloader/create_packed_data.py @@ -6,7 +6,7 @@ import warnings from io import BufferedWriter from pathlib import Path -from typing import Callable, Iterator, List, Optional, Tuple +from typing import Callable, Iterator, Optional import jq import numpy as np @@ -197,8 +197,8 @@ def _writer_thread(self, dst_path: Path) -> Callable: def writer(): # writes a batch received from the processed_samples_queue to the destination file def _write_batch( - batch: List[Tuple[int, bytes]], prev_line_id: int, curr_offset: int, index_list: List, f: BufferedWriter - ) -> Tuple[int, int]: + batch: list[tuple[int, bytes]], prev_line_id: int, curr_offset: int, index_list: list, f: BufferedWriter + ) -> tuple[int, int]: # write the tokens for each document for line_id, tokens_as_bytes in batch: if prev_line_id + 1 != line_id: @@ -293,7 +293,7 @@ def _process_thread(self, process_id: int): f"Raised the following error: {exception=}" ) - def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: List[Tuple[int, int]]): + def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: list[tuple[int, int]]): # Update the length of the data section in the pre-allocated header of the destination file. # The data segment length is sum of the starting position and the length of the last document. length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] @@ -356,18 +356,18 @@ def __init__(self, data_path: Path): pkl_encoded_index = f.read() # contains the start offset and length of each segment # as byte positions in the data section - self.index_base: List[Tuple[int, int]] = pickle.loads(pkl_encoded_index) + self.index_base: list[tuple[int, int]] = pickle.loads(pkl_encoded_index) # initialize memmapped data section self.data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) -def join_embedded_stream_data(stream_data: List[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): +def join_embedded_stream_data(stream_data: list[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): """ Joins the embedded stream data into a single file. Args: - stream_data (List[EmbeddedStreamData]): A list of EmbeddedStreamData objects representing the stream data. + stream_data (list[EmbeddedStreamData]): A list of EmbeddedStreamData objects representing the stream data. target_file (Path): The target file to write the joined data to. chunk_size (int, optional): The size of each data chunk. Defaults to 2048. @@ -391,7 +391,7 @@ def join_embedded_stream_data(stream_data: List[EmbeddedStreamData], target_file num_entries = sum(len(d.index_base) for d in stream_data) - def index_stream_generator() -> Iterator[Tuple[int, int]]: + def index_stream_generator() -> Iterator[tuple[int, int]]: # generates a stream of index offsets and segment lengths. 
curr_offset = 0 for embedded_stream_data in stream_data: diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index f142db6a9..ef2cedc39 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Union +from typing import Iterable, Optional import webdataset as wd from torch.utils.data import Dataset, DistributedSampler, Sampler @@ -13,13 +13,14 @@ class DataLoaderIF: class LLMDataLoader(DataLoader[T_co], DataLoaderIF): """LLMDataLoader is a custom DataLoader class that extends the PyTorch DataLoader class.""" + def __init__( self, dataloader_tag: str, batch_sampler: ResumableBatchSampler, dataset: Dataset[T_co], batch_size: Optional[int] = 1, - sampler: Union[Sampler, Iterable, None] = None, + sampler: Sampler | Iterable | None = None, num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, pin_memory: bool = False, @@ -41,7 +42,7 @@ def __init__( batch_sampler (ResumableBatchSampler): The batch sampler used for sampling batches. dataset (Dataset[T_co]): The dataset to load the data from. batch_size (Optional[int], optional): The number of samples per batch. Defaults to 1. - sampler (Union[Sampler, Iterable, None], optional): The sampler used for sampling data. Defaults to None. + sampler (Sampler | Iterable | None, optional): The sampler used for sampling data. Defaults to None. num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 0. collate_fn (Optional[_collate_fn_t], optional): The function used to collate the data samples. Defaults to None. diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index aedb97bff..e157e7fcb 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -5,7 +5,7 @@ import re from enum import Enum from pathlib import Path -from typing import Annotated, Dict, List, Optional, Tuple, Union +from typing import Annotated, Optional import decord import jq @@ -73,13 +73,13 @@ class DummySampleConfig(BaseModel): Attributes: sample_key (str): The key of the sample. - sample_shape (Tuple[int, ...]): The shape of the sample. + sample_shape (tuple[int, ...]): The shape of the sample. sample_type (DummySampleDataType): The type of the sample. """ sample_key: str - sample_shape: Tuple[int, ...] + sample_shape: tuple[int, ...] sample_type: DummySampleDataType @@ -89,24 +89,24 @@ class DummyDatasetConfig(BaseModel): Attributes: num_samples (int): The number of samples in the dataset. - sample_definition (List[DummySampleConfig]): The list of sample definitions in the dataset. + sample_definition (list[DummySampleConfig]): The list of sample definitions in the dataset. """ num_samples: int - sample_definition: List[DummySampleConfig] + sample_definition: list[DummySampleConfig] class DummyDataset(Dataset): """DummyDataset class.""" - def __init__(self, num_samples: int, sample_definition: Tuple[DummySampleConfig]): + def __init__(self, num_samples: int, sample_definition: tuple[DummySampleConfig]): """ Initializes a DummyDataset object with the given number of samples and sample definition. When calling the __getitem__ method, the dataset will return a random sample based on the sample definition. Args: num_samples (int): The number of samples in the dataset. - sample_definition (Tuple[DummySampleConfig]): A list of tuples defining the dataset output. 
+ sample_definition (tuple[DummySampleConfig]): A list of tuples defining the dataset output. Each touple contains the sample key, shape and data type. Returns: @@ -127,7 +127,7 @@ def __len__(self) -> int: """ return self.num_samples - def __getitem__(self, idx: int) -> Dict: + def __getitem__(self, idx: int) -> dict: """ Retrieves an item from the dataset at the specified index. @@ -142,7 +142,7 @@ def __getitem__(self, idx: int) -> Dict: """ return self._create_random_sample() - def _create_random_sample(self) -> Dict: + def _create_random_sample(self) -> dict: # creates a random sample based on the sample definition sample = dict() for s in self.sample_definition: @@ -259,7 +259,7 @@ def __init__(self, raw_data_path: Path, sample_key: str): ) self._index = self._generate_packing_index() - def _generate_packing_index(self) -> List[Tuple[int, int]]: + def _generate_packing_index(self) -> list[tuple[int, int]]: # Generates the packing index for the dataset. # The index is list of tuples, where each tuple contains the offset and length in bytes. @@ -329,7 +329,7 @@ def __init__(self, raw_data_path: Path, sample_key: str, block_size: int): self.block_size = block_size super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) - def _generate_packing_index(self) -> List[Tuple[int, int]]: + def _generate_packing_index(self) -> list[tuple[int, int]]: # Generates the packing index for the dataset. # A list of tuples representing the index, where each tuple contains the offset and length in bytes. @@ -360,7 +360,7 @@ def __init__(self, raw_data_path: Path, sample_key: str, block_size: int): self.block_size = block_size super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) - def _generate_packing_index(self) -> List[Tuple[int, int]]: + def _generate_packing_index(self) -> list[tuple[int, int]]: index = [] curr_offset = self.HEADER_SIZE_IN_BYTES curr_len = 0 @@ -410,22 +410,22 @@ class Transform: class ImageTransformConfig(TransformConfig): - input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224 + input_size: int | tuple[int, int] | tuple[int, int, int] = 224 is_training: bool = False no_aug: bool = False train_crop_mode: Optional[str] = None - scale: Optional[Tuple[float, float]] = None - ratio: Optional[Tuple[float, float]] = None + scale: Optional[tuple[float, float]] = None + ratio: Optional[tuple[float, float]] = None hflip: float = 0.5 vflip: float = 0.0 - color_jitter: Union[float, Tuple[float, ...]] = 0.4 + color_jitter: float | tuple[float, ...] = 0.4 color_jitter_prob: Optional[float] = None grayscale_prob: float = 0.0 gaussian_blur_prob: float = 0.0 auto_augment: Optional[str] = None interpolation: str = "bilinear" - mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN - std: Tuple[float, ...] = IMAGENET_DEFAULT_STD + mean: tuple[float, ...] = IMAGENET_DEFAULT_MEAN + std: tuple[float, ...] 
= IMAGENET_DEFAULT_STD re_prob: float = 0.0 re_mode: str = "const" re_count: int = 1 @@ -538,7 +538,7 @@ def __call__(self, video): class VideoTransformConfig(TransformConfig): - input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224 + input_size: int | tuple[int, int] | tuple[int, int, int] = 224 is_training: bool = False num_frames: int = 16 @@ -546,7 +546,7 @@ class VideoTransformConfig(TransformConfig): class VideoTransform(Transform): def __init__( self, - input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224, + input_size: int | tuple[int, int] | tuple[int, int, int] = 224, is_training: bool = False, num_frames: int = 16, ): @@ -652,9 +652,9 @@ def __iter__(self): class MultimodalWebDatasetBuilderConfig(BaseModel): - urls: Union[List[str], str] - modality_key_mapping: Dict[ModalityEnum, Tuple[str, str]] - modality_transforms: Dict[ModalityEnum, PydanticTransformIFType] + urls: list[str] | str + modality_key_mapping: dict[ModalityEnum, tuple[str, str]] + modality_transforms: dict[ModalityEnum, PydanticTransformIFType] is_audio_video: Optional[bool] = False num_samples: Annotated[int, Field(ge=1)] @@ -663,9 +663,9 @@ class MultimodalWebDatasetBuilderConfig(BaseModel): class MultimodalWebDatasetBuilder: def __init__( self, - urls: Union[List[str], str], - modality_key_mapping: Dict[str, Tuple[str, str]], - modality_transforms: Dict[str, Transform], + urls: list[str] | str, + modality_key_mapping: dict[str, tuple[str, str]], + modality_transforms: dict[str, Transform], is_audio_video: bool, num_samples: int, ): @@ -813,9 +813,9 @@ def dummy_nodesplitter(src, group=None): class MultimodalWebDatasetConfig(BaseModel): - builders: List[PydanticMultimodalWebDatasetBuilderIFType] + builders: list[PydanticMultimodalWebDatasetBuilderIFType] batch_size: Optional[int] = None - mixing_ratios: Optional[List[float]] = None + mixing_ratios: Optional[list[float]] = None shardshuffle: int = 100 repeat: bool = False resample: bool = True @@ -826,9 +826,9 @@ class MultimodalWebDatasetConfig(BaseModel): class MultimodalWebDataset(wds.DataPipeline, wds.compat.FluidInterface): def __init__( self, - builders: List[MultimodalWebDatasetBuilder], + builders: list[MultimodalWebDatasetBuilder], batch_size: int = None, - mixing_ratios: Optional[List[float]] = None, + mixing_ratios: Optional[list[float]] = None, shardshuffle: int = 100, repeat: bool = False, resample: bool = True, diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py index 7ce204c23..d5580b8c8 100644 --- a/src/modalities/dataloader/dataset_factory.py +++ b/src/modalities/dataloader/dataset_factory.py @@ -1,6 +1,6 @@ import pickle from pathlib import Path -from typing import List, Optional, Tuple +from typing import Optional from transformers import PreTrainedTokenizer @@ -17,13 +17,13 @@ class DatasetFactory: """DatasetFactory for building the different dataset types.""" @staticmethod - def get_dummy_dataset(num_samples: int, sample_definition: Tuple[DummySampleConfig]) -> DummyDataset: + def get_dummy_dataset(num_samples: int, sample_definition: tuple[DummySampleConfig]) -> DummyDataset: """ Returns a DummyDataset object. Args: num_samples (int): The number of samples the dataset should generate. - sample_definition (Tuple[DummySampleConfig]): A list of tuples defining the dataset output. + sample_definition (tuple[DummySampleConfig]): A list of tuples defining the dataset output. Each tuple contains the sample key, shape and data type. 
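# Illustrative sketch (not part of the patch series): the ImageTransformConfig fields above
# mirror the keyword arguments of timm.data.create_transform, so a transform can be built by
# unpacking the validated config (assuming a timm version that accepts all of these kwargs).
# The eval-mode settings below are placeholders.
from timm.data import create_transform

from modalities.dataloader.dataset import ImageTransformConfig

eval_config = ImageTransformConfig(input_size=224, is_training=False, interpolation="bicubic")
transform = create_transform(**eval_config.model_dump())  # torchvision-style Compose
# transform(pil_image) -> torch.Tensor of shape (3, 224, 224)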
Returns: @@ -64,7 +64,7 @@ def get_mem_map_dataset( return dataset @staticmethod - def get_raw_index(raw_index_path: Path) -> List[Tuple[int, int]]: + def get_raw_index(raw_index_path: Path) -> list[tuple[int, int]]: with raw_index_path.open("rb") as f: index = pickle.load(f) return index diff --git a/src/modalities/dataloader/large_file_lines_reader.py b/src/modalities/dataloader/large_file_lines_reader.py index 220f95bbf..3d896dcd3 100644 --- a/src/modalities/dataloader/large_file_lines_reader.py +++ b/src/modalities/dataloader/large_file_lines_reader.py @@ -1,7 +1,7 @@ import pickle from abc import ABC, abstractmethod from pathlib import Path -from typing import List, Optional +from typing import Optional class BaseReader(ABC): @@ -10,7 +10,7 @@ def __len__(self) -> int: raise NotImplementedError @abstractmethod - def __getitem__(self, key: int | slice) -> str | List[str]: + def __getitem__(self, key: int | slice) -> str | list[str]: raise NotImplementedError @@ -72,7 +72,7 @@ def __len__(self) -> int: """ return len(self.index) - def __getitem__(self, key: int | slice) -> str | List[str]: + def __getitem__(self, key: int | slice) -> str | list[str]: """ Retrieves an item from the LargeFileLinesReader. @@ -80,7 +80,7 @@ def __getitem__(self, key: int | slice) -> str | List[str]: key (int | slice): The index or slice used to retrieve the item(s). Returns: - str | List[str]: The item(s) retrieved from the LargeFileLinesReader. + str | list[str]: The item(s) retrieved from the LargeFileLinesReader. Raises: IndexError: If the key is out of range. diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 42f4e710e..ba69b1399 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -1,5 +1,3 @@ -from typing import Dict, List - import torch import torch.distributed as dist import torch.nn as nn @@ -36,7 +34,7 @@ def evaluate_batch( self, batch: DatasetBatch, model: nn.Module, - loss_fun: List[Loss], + loss_fun: list[Loss], ) -> torch.Tensor: """Evaluate a single batch by forwarding it through the model and calculating the loss. @@ -72,22 +70,22 @@ def evaluate_batch( def evaluate( self, model: nn.Module, - data_loaders: List[LLMDataLoader], - loss_fun: List[Loss], + data_loaders: list[LLMDataLoader], + loss_fun: list[Loss], num_train_steps_done: int, - ) -> Dict[str, EvaluationResultBatch]: + ) -> dict[str, EvaluationResultBatch]: """Evaluate the model on a set of datasets. 
Args: model (nn.Module): The model to evaluate - data_loaders (List[LLMDataLoader]): List of dataloaders to evaluate the model on + data_loaders (list[LLMDataLoader]): List of dataloaders to evaluate the model on loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss num_train_steps_done (int): The number of training steps done so far for logging purposes Returns: - Dict[str, EvaluationResultBatch]: A dictionary containing the evaluation results for each dataloader + dict[str, EvaluationResultBatch]: A dictionary containing the evaluation results for each dataloader """ - result_dict: Dict[str, EvaluationResultBatch] = {} + result_dict: dict[str, EvaluationResultBatch] = {} model.eval() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/src/modalities/gym.py b/src/modalities/gym.py index 8f69082dc..270e1dae1 100644 --- a/src/modalities/gym.py +++ b/src/modalities/gym.py @@ -1,6 +1,6 @@ from datetime import datetime from functools import partial -from typing import Callable, List +from typing import Callable import torch.nn as nn from torch.optim import Optimizer @@ -41,7 +41,7 @@ def run( checkpointing_interval_in_steps: int, evaluation_interval_in_steps: int, train_data_loader: LLMDataLoader, - evaluation_data_loaders: List[LLMDataLoader], + evaluation_data_loaders: list[LLMDataLoader], checkpoint_saving: CheckpointSaving, ): """Runs the model training, including evaluation and checkpointing. @@ -54,7 +54,7 @@ def run( checkpointing_interval_in_steps (int): Interval in steps to save checkpoints. evaluation_interval_in_steps (int): Interval in steps to perform evaluation. train_data_loader (LLMDataLoader): Data loader with the training data. - evaluation_data_loaders (List[LLMDataLoader]): List of data loaders with the evaluation data. + evaluation_data_loaders (list[LLMDataLoader]): List of data loaders with the evaluation data. checkpoint_saving (CheckpointSaving): Routine for saving checkpoints. 
""" evaluation_callback: Callable[[int], None] = partial( @@ -109,7 +109,7 @@ def _run_evaluation( self, model: nn.Module, num_train_steps_done: int, - evaluation_data_loaders: List[LLMDataLoader], + evaluation_data_loaders: list[LLMDataLoader], evaluation_interval_in_steps: int, ): if num_train_steps_done % evaluation_interval_in_steps == 0: diff --git a/src/modalities/logging_broker/message_broker.py b/src/modalities/logging_broker/message_broker.py index 7b38e58ff..d81f86b71 100644 --- a/src/modalities/logging_broker/message_broker.py +++ b/src/modalities/logging_broker/message_broker.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Dict, List from modalities.logging_broker.messages import Message, MessageTypes from modalities.logging_broker.subscriber import MessageSubscriberIF @@ -22,7 +21,7 @@ class MessageBroker(MessageBrokerIF): """The MessageBroker sends notifications to its subscribers.""" def __init__(self) -> None: - self.subscriptions: Dict[MessageTypes, List[MessageSubscriberIF]] = defaultdict(list) + self.subscriptions: dict[MessageTypes, list[MessageSubscriberIF]] = defaultdict(list) def add_subscriber(self, subscription: MessageTypes, subscriber: MessageSubscriberIF): """Adds a single subscriber.""" diff --git a/src/modalities/logging_broker/subscriber.py b/src/modalities/logging_broker/subscriber.py index 9d62c17a5..5bdc885ae 100644 --- a/src/modalities/logging_broker/subscriber.py +++ b/src/modalities/logging_broker/subscriber.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, TypeVar +from typing import Any, Generic, TypeVar from modalities.logging_broker.messages import Message @@ -14,5 +14,5 @@ def consume_message(self, message: Message[T]): raise NotImplementedError @abstractmethod - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): raise NotImplementedError diff --git a/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py b/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py index e996f481f..9a991fe0c 100644 --- a/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Any from rich.console import Group from rich.live import Live @@ -14,15 +14,15 @@ class DummyProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): def consume_message(self, message: Message[ProgressUpdate]): pass - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): pass class SimpleProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): def __init__( self, - train_split_num_samples: Dict[str, int], - eval_splits_num_samples: Dict[str, int], + train_split_num_samples: dict[str, int], + eval_splits_num_samples: dict[str, int], ) -> None: self.train_split_num_samples = train_split_num_samples self.eval_splits_num_samples = eval_splits_num_samples @@ -61,8 +61,8 @@ class RichProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): def __init__( self, - train_split_num_steps: Dict[str, Tuple[int, int]], - eval_splits_num_steps: Dict[str, int], + train_split_num_steps: dict[str, tuple[int, int]], + eval_splits_num_steps: dict[str, int], ) -> None: # train split progress bar self.train_splits_progress = Progress( @@ -132,5 +132,5 @@ def consume_message(self, message: Message[ProgressUpdate]): 
completed=batch_progress.num_steps_done, ) - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): raise NotImplementedError diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index c7c5160ae..e44054913 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict +from typing import Any import rich import wandb @@ -18,7 +18,7 @@ def consume_message(self, message: Message[EvaluationResultBatch]): """Consumes a message from a message broker.""" pass - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): pass @@ -50,7 +50,7 @@ def consume_message(self, message: Message[EvaluationResultBatch]): if losses or metrics: rich.print(Panel(Group(*group_content))) - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): raise NotImplementedError @@ -75,7 +75,7 @@ def __init__( self.run.log_artifact(config_file_path, name=f"config_{wandb.run.id}", type="config") - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): for k, v in mesasge_dict.items(): self.run.config[k] = v diff --git a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py index 322df0c68..cbab71898 100644 --- a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py +++ b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import List, Optional +from typing import Optional from modalities.config.config import WandbMode from modalities.dataloader.dataloader import LLMDataLoader @@ -19,7 +19,7 @@ class ProgressSubscriberFactory: @staticmethod def get_rich_progress_subscriber( - eval_dataloaders: List[LLMDataLoader], + eval_dataloaders: list[LLMDataLoader], train_dataloader_tag: str, num_seen_steps: int, num_target_steps: int, @@ -42,7 +42,7 @@ def get_rich_progress_subscriber( @staticmethod def get_simple_progress_subscriber( train_dataloader: LLMDataLoader, - eval_dataloaders: List[LLMDataLoader], + eval_dataloaders: list[LLMDataLoader], world_size: int, global_num_seen_samples: int, local_rank: int, diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index 61e99bdad..5852aa01d 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -1,4 +1,4 @@ -from typing import Annotated, Dict +from typing import Annotated import torch from pydantic import BaseModel, Field @@ -194,8 +194,8 @@ def __init__( def forward( self, - inputs: Dict[str, tuple[torch.Tensor, torch.Tensor]], - ) -> Dict[str, tuple[torch.Tensor, torch.Tensor]]: + inputs: dict[str, tuple[torch.Tensor, torch.Tensor]], + ) -> dict[str, tuple[torch.Tensor, torch.Tensor]]: x = inputs[self.sample_key] # x.shape: B, T, D attn_key_mask = self._get_attn_key_mask(inputs["audio_len"]) # x.shape: B, T, D diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 26ce17922..d6ca7debf 100644 --- 
a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -1,4 +1,4 @@ -from typing import Annotated, Dict, List, Optional, Tuple +from typing import Annotated, Optional import numpy as np import torch @@ -86,7 +86,7 @@ class CoCaConfig(BaseModel): image_text_cls_prediction_key: Optional[str] = None video_cls_prediction_key: Optional[str] = None video_text_cls_prediction_key: Optional[str] = None - modality_keys: List[str] + modality_keys: list[str] individual_datasets: Optional[bool] = False is_audio_video: Optional[bool] = False audio_encoder_config: Optional[AudioTransformerConfig] = None @@ -126,7 +126,7 @@ def __init__( image_text_cls_prediction_key: Optional[str], video_cls_prediction_key: Optional[str], video_text_cls_prediction_key: Optional[str], - modality_keys: List[str], + modality_keys: list[str], individual_datasets: Optional[bool], is_audio_video: Optional[bool], audio_encoder_config: Optional[AudioTransformerConfig], @@ -271,7 +271,7 @@ def _init_modality(self, encoder_class, encoder_config, n_queries): ) return encoder, queries, attn_pool - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the CoCa model. @@ -350,15 +350,15 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: ) return output - def _forward_encode_image(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_encode_image(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: """ Encodes the input image using the vision encoder. Args: - inputs (dict[str, torch.Tensor]): Dictionary containing vision inputs. + inputs (dict[str, torch.Tensor]): dictionary containing vision inputs. Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple containing encoded vision embeddings and classification token. + tuple[torch.Tensor, torch.Tensor]: Tuple containing encoded vision embeddings and classification token. """ image_embd = self.image_encoder(inputs)[self.image_embd_prediction_key] queries = repeat(self.image_queries, "n d -> b n d", b=image_embd.shape[0]) @@ -366,21 +366,21 @@ def _forward_encode_image(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch. 
image_embd, image_cls_token = image_embd[:, :-1, :], F.normalize(image_embd[:, -1, :], dim=-1) return image_embd, image_cls_token - def _forward_encode_video(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_encode_video(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: video_embd = self.video_encoder(inputs)[self.video_embd_prediction_key] queries = repeat(self.video_queries, "n d -> b n d", b=video_embd.shape[0]) video_embd = self.video_attn_pool(queries, context=video_embd) video_embd, video_cls_token = video_embd[:, :-1, :], F.normalize(video_embd[:, -1, :], dim=-1) return video_embd, video_cls_token - def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_encode_audio(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: audio_embd = self.audio_encoder(inputs)[self.audio_embd_prediction_key] queries = repeat(self.audio_queries, "n d -> b n d", b=audio_embd.shape[0]) audio_embd = self.audio_attn_pool(queries, context=audio_embd) audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1, :], dim=-1) return audio_embd, audio_cls_token - def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_encode_text(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: """ Encodes the input text using the text decoder. @@ -388,7 +388,7 @@ def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.T inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing the encoded text tensor + tuple[torch.Tensor, torch.Tensor]: A tuple containing the encoded text tensor and the classification token tensor. """ text_embd = self.text_decoder(inputs)[self.text_embd_prediction_key] diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index 42b4fcdde..d77b7f481 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -1,5 +1,4 @@ from dataclasses import field -from typing import Dict, List import torch from pydantic import BaseModel @@ -13,14 +12,14 @@ class CoCaCollateFnConfig(BaseModel): Configuration class for CoCaCollateFn. Args: - sample_keys (List[str]): List of samples keys. - target_keys (List[str]): List of target keys. + sample_keys (list[str]): List of samples keys. + target_keys (list[str]): List of target keys. text_sample_key (str): Key for the text samples. text_target_key (str): Key for the text targets. """ - sample_keys: List[str] - target_keys: List[str] + sample_keys: list[str] + target_keys: list[str] text_sample_key: str text_target_key: str @@ -28,13 +27,13 @@ class CoCaCollateFnConfig(BaseModel): class CoCaCollatorFn(CollateFnIF): """Collator function for CoCa model.""" - def __init__(self, sample_keys: List[str], target_keys: List[str], text_sample_key: str, text_target_key: str): + def __init__(self, sample_keys: list[str], target_keys: list[str], text_sample_key: str, text_target_key: str): """ Initializes the CoCaCollatorFn object. Args: - sample_keys (List[str]): List of samples keys. - target_keys (List[str]): List of target keys. + sample_keys (list[str]): List of samples keys. + target_keys (list[str]): List of target keys. text_sample_key (str): Key for the text samples. text_target_key (str): Key for the text targets. 
@@ -58,12 +57,12 @@ def __init__(self, sample_keys: List[str], target_keys: List[str], text_sample_k self.text_sample_key = text_sample_key self.text_target_key = text_target_key - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: """ Process a batch of data. Args: - batch (List[Dict[str, torch.Tensor]]): A list of dictionaries containing tensors + batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing tensors representing the batch data. Returns: diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index dbf5409b9..981ed5398 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -1,5 +1,4 @@ from functools import partial -from typing import Dict import torch from torch import nn @@ -102,7 +101,7 @@ def forward(self, x: torch.Tensor, context: list[torch.Tensor] | torch.Tensor | if not self.with_context or self.add_extra_mlp: x = x + self.mlp(self.ln_2(x)) if self.with_context: - if isinstance(context, Dict): + if isinstance(context, dict): x = self.ln_3(x) x = x + self.cross_attn(x, context=context["audio"]) + self.cross_attn2(x, context=context["video"]) x = x + self.mlp_2(self.ln_4(x)) @@ -184,7 +183,7 @@ def __init__( ) self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size, bias=False) - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the MultiModalTextDecoder module. diff --git a/src/modalities/models/coca/text_decoder.py b/src/modalities/models/coca/text_decoder.py index c21ef6871..e6b15c7ff 100644 --- a/src/modalities/models/coca/text_decoder.py +++ b/src/modalities/models/coca/text_decoder.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch from torch import nn @@ -78,7 +76,7 @@ def __init__( ) ) - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the TextDecoder module. diff --git a/src/modalities/models/gpt2/collator.py b/src/modalities/models/gpt2/collator.py index 7211dd255..4e0256cb5 100644 --- a/src/modalities/models/gpt2/collator.py +++ b/src/modalities/models/gpt2/collator.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Dict, List import torch @@ -10,12 +9,12 @@ class CollateFnIF(ABC): """CollateFnIF class to define a collate function interface.""" @abstractmethod - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: """ Process a batch of data. Args: - batch (List[Dict[str, torch.Tensor]]): A list of dictionaries containing tensors. + batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing tensors. Returns: DatasetBatch: The processed batch of data. @@ -40,12 +39,12 @@ def __init__(self, sample_key: str, target_key: str): self.sample_key = sample_key self.target_key = target_key - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: """ Process a batch of data. Args: - batch (List[Dict[str, torch.Tensor]]): A list of dictionaries containing tensors. 
+ batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing tensors. Returns: DatasetBatch: A processed batch of data where sample and target sequences are created. diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index ab43dfd1a..d388e93cb 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -1,7 +1,7 @@ import math from copy import deepcopy from enum import Enum -from typing import Annotated, Dict, List, Tuple +from typing import Annotated import torch import torch.nn as nn @@ -42,7 +42,7 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Perform forward pass for transforming queries/keys/values. @@ -52,7 +52,7 @@ def forward( v (torch.Tensor): The value tensor. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the output tensors. + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the output tensors. """ pass @@ -65,7 +65,7 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass of the IdentityTransform which does not apply any transform. @@ -75,7 +75,7 @@ def forward( v (torch.Tensor): The value tensor. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The tensors q, k, and v. + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The tensors q, k, and v. """ return q, k, v @@ -160,7 +160,7 @@ def apply_rotary_pos_emb(self, x, cos, sin): def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass of the RotaryTransform module. @@ -170,7 +170,7 @@ def forward( v (torch.Tensor): Value tensor. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing the modified query tensor, key tensor, and value tensor. """ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k) @@ -213,7 +213,7 @@ class AttentionConfig(BaseModel): Configuration class for attention mechanism. Attributes: - qkv_transforms (List[QueryKeyValueTransformConfig]): List of configurations for query-key-value transforms. + qkv_transforms (list[QueryKeyValueTransformConfig]): List of configurations for query-key-value transforms. """ class QueryKeyValueTransformConfig(BaseModel): @@ -222,7 +222,7 @@ class QueryKeyValueTransformConfig(BaseModel): Attributes: type_hint (QueryKeyValueTransformType): The type hint for the transform. - config (Union[RotaryTransformConfig, IdentityTransformConfig]): The configuration for the transform. + config (RotaryTransformConfig | IdentityTransformConfig): The configuration for the transform. 
""" class IdentityTransformConfig(BaseModel): @@ -262,7 +262,7 @@ def parse_sharding_strategy_by_name(cls, name): type_hint: QueryKeyValueTransformType config: RotaryTransformConfig | IdentityTransformConfig - qkv_transforms: List[QueryKeyValueTransformConfig] + qkv_transforms: list[QueryKeyValueTransformConfig] class GPT2LLMConfig(BaseModel): @@ -422,7 +422,7 @@ def __init__( for transform_config in attention_config.qkv_transforms ) - def projection(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Applies projections to the input tensor to get queries, keys, and values. @@ -430,7 +430,7 @@ def projection(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch x (torch.Tensor): The input tensor. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the query, key, and value tensors. + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the query, key, and value tensors. """ # calculate query, key, values for all heads in batch and move head forward to be the batch dim return self.q_attn(x), self.k_attn(x), self.v_attn(x) @@ -438,7 +438,7 @@ def projection(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch @staticmethod def execute_qkv_transforms( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qkv_transforms: nn.ModuleList, n_head_q: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Applies a series of transformations to the query, key, and value tensors. @@ -450,7 +450,7 @@ def execute_qkv_transforms( n_head_q (int): The number of heads for the query tensors. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the transformed query, key, and value tensors. """ batch_size, sequence_length, embedding_dim = q.size() @@ -826,16 +826,16 @@ def __init__( # not 100% sure what this is, so far seems to be harmless. TODO investigate self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying - def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward_impl(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass implementation of the GPT2LLM module. Args: - inputs (Dict[str, torch.Tensor]): A dictionary containing input tensors. + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - sample_key (str): Key for the input tensor containing token ids. Returns: - Dict[str, torch.Tensor]: A dictionary containing output tensors. + dict[str, torch.Tensor]: A dictionary containing output tensors. - prediction_key (str): Key for the output tensor containing logits. """ input_ids = inputs[self.sample_key] @@ -861,16 +861,16 @@ def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso logits = self.lm_head(x) return {self.prediction_key: logits} - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the GPT2LLM module. Args: - inputs (Dict[str, torch.Tensor]): A dictionary containing input tensors. + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - sample_key (str): Key for the input tensor containing token ids. 
Returns: - Dict[str, torch.Tensor]: A dictionary containing output tensors. + dict[str, torch.Tensor]: A dictionary containing output tensors. - prediction_key (str): Key for the output tensor containing logits. """ return self.forward_impl(inputs) diff --git a/src/modalities/models/gpt2/pretrained_gpt_model.py b/src/modalities/models/gpt2/pretrained_gpt_model.py index 7251229a8..ea896624b 100644 --- a/src/modalities/models/gpt2/pretrained_gpt_model.py +++ b/src/modalities/models/gpt2/pretrained_gpt_model.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch from transformers import PreTrainedModel @@ -38,7 +36,7 @@ def forward(self, tensor): """ model_input = {"input_ids": tensor} - model_forward_output: Dict[str, torch.Tensor] = self.model.forward(model_input) + model_forward_output: dict[str, torch.Tensor] = self.model.forward(model_input) return model_forward_output[self.config.config.prediction_key] diff --git a/src/modalities/models/huggingface/huggingface_model.py b/src/modalities/models/huggingface/huggingface_model.py index 2d1bfa30c..9d9a74142 100644 --- a/src/modalities/models/huggingface/huggingface_model.py +++ b/src/modalities/models/huggingface/huggingface_model.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from pydantic import BaseModel, ConfigDict @@ -102,26 +102,26 @@ def __init__( model_name, local_files_only=False, *model_args, **kwargs ) - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the model. Args: - inputs (Dict[str, torch.Tensor]): A dictionary containing input tensors. + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. Returns: - Dict[str, torch.Tensor]: A dictionary containing output tensors. + dict[str, torch.Tensor]: A dictionary containing output tensors. """ output = self.huggingface_model.forward(inputs[self.sample_key]) return {self.prediction_key: output[self.huggingface_prediction_subscription_key]} @property - def fsdp_block_names(self) -> List[str]: + def fsdp_block_names(self) -> list[str]: """ Returns a list of FSDP block names. Returns: - List[str]: A list of FSDP block names. + list[str]: A list of FSDP block names. """ return self.huggingface_model._no_split_modules diff --git a/src/modalities/models/huggingface_adapters/hf_adapter.py b/src/modalities/models/huggingface_adapters/hf_adapter.py index a111e1f8a..09c751836 100644 --- a/src/modalities/models/huggingface_adapters/hf_adapter.py +++ b/src/modalities/models/huggingface_adapters/hf_adapter.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass from pathlib import PosixPath -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional import torch from transformers import PretrainedConfig, PreTrainedModel @@ -49,8 +49,8 @@ def to_json_string(self, use_diff: bool = True) -> str: return json.dumps(json_dict) def _convert_posixpath_to_str( - self, data_to_be_formatted: Union[Dict[str, Any], List[Any], PosixPath, Any] - ) -> Union[Dict[str, Any], List[Any], PosixPath, Any]: + self, data_to_be_formatted: dict[str, Any] | list[Any] | PosixPath | Any + ) -> dict[str, Any] | list[Any] | PosixPath | Any: # Recursively converts any PosixPath objects within a nested data structure to strings. 
if isinstance(data_to_be_formatted, dict): @@ -108,13 +108,13 @@ def forward( output_hidden_states (bool, optional): Whether to output hidden states. Defaults to False. Returns: - Union[ModalitiesModelOutput, torch.Tensor]: The output of the forward pass. + ModalitiesModelOutput | torch.Tensor: The output of the forward pass. """ # These parameters are required by HuggingFace. We do not use them and hence don't implement them. if output_attentions or output_hidden_states: raise NotImplementedError model_input = {"input_ids": input_ids, "attention_mask": attention_mask} - model_forward_output: Dict[str, torch.Tensor] = self.model.forward(model_input) + model_forward_output: dict[str, torch.Tensor] = self.model.forward(model_input) if return_dict: return ModalitiesModelOutput(**model_forward_output) else: @@ -122,7 +122,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor = None, **kwargs - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Prepares the inputs for generation. @@ -132,7 +132,7 @@ def prepare_inputs_for_generation( **kwargs: Additional keyword arguments. Returns: - Dict[str, Any]: A dictionary containing the prepared inputs for generation. + dict[str, Any]: A dictionary containing the prepared inputs for generation. Note: Implement in subclasses of :class:`~transformers.PreTrainedModel` @@ -151,10 +151,10 @@ class ModalitiesModelOutput(ModelOutput): Args: logits (torch.FloatTensor, optional): The logits output of the model. Defaults to None. - hidden_states (Tuple[torch.FloatTensor], optional): The hidden states output of the model. Defaults to None. - attentions (Tuple[torch.FloatTensor], optional): The attentions output of the model. Defaults to None. + hidden_states (tuple[torch.FloatTensor], optional): The hidden states output of the model. Defaults to None. + attentions (tuple[torch.FloatTensor], optional): The attentions output of the model. Defaults to None. """ logits: Optional[torch.FloatTensor] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index 0bbaccc34..fd7703e55 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -1,13 +1,13 @@ from abc import abstractmethod from enum import Enum -from typing import Dict, List, Optional +from typing import Optional import torch import torch.nn as nn from modalities.batch import DatasetBatch, InferenceResultBatch -WeightDecayGroups = Dict[str, List[str]] +WeightDecayGroups = dict[str, list[str]] class ActivationType(str, Enum): @@ -50,19 +50,19 @@ def weight_decay_groups(self) -> WeightDecayGroups: return self._weight_decay_groups @abstractmethod - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the model. Args: - inputs (Dict[str, torch.Tensor]): A dictionary containing input tensors. + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. Returns: - Dict[str, torch.Tensor]: A dictionary containing output tensors. + dict[str, torch.Tensor]: A dictionary containing output tensors. 
""" raise NotImplementedError - def get_parameters(self) -> Dict[str, torch.Tensor]: + def get_parameters(self) -> dict[str, torch.Tensor]: """ Returns a dictionary of the model's parameters. diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 96ea93a56..0a02fd8ae 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import torch import torch.distributed as dist @@ -46,7 +45,7 @@ def get_checkpointed_model( def get_fsdp_wrapped_model( model: nn.Module, sync_module_states: bool, - block_names: List[str], + block_names: list[str], mixed_precision_settings: MixedPrecisionSettings, sharding_strategy: ShardingStrategy, ) -> FSDP: @@ -56,7 +55,7 @@ def get_fsdp_wrapped_model( Args: model (nn.Module): The original model to be wrapped. sync_module_states (bool): Whether to synchronize module states across ranks. - block_names (List[str]): List of block names. + block_names (list[str]): List of block names. mixed_precision_settings (MixedPrecisionSettings): Mixed precision settings. sharding_strategy (ShardingStrategy): Sharding strategy. @@ -108,12 +107,12 @@ def get_weight_initalized_model(model: nn.Module, model_initializer: ModelInitia return model @staticmethod - def get_activation_checkpointed_model(model: FSDP, activation_checkpointing_modules: List[str]) -> FSDP: + def get_activation_checkpointed_model(model: FSDP, activation_checkpointing_modules: list[str]) -> FSDP: """Apply activation checkpointing to the given model (in-place operation). Args: model (FSDP): The FSDP-wrapped model to apply activation checkpointing to. - activation_checkpointing_modules (List[str]): List of module names to apply activation checkpointing to. + activation_checkpointing_modules (list[str]): List of module names to apply activation checkpointing to. Raises: ValueError: Activation checkpointing can only be applied to FSDP-wrapped models! diff --git a/src/modalities/models/utils.py b/src/modalities/models/utils.py index debf0cab0..77c75b6aa 100644 --- a/src/modalities/models/utils.py +++ b/src/modalities/models/utils.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Dict from pydantic import BaseModel @@ -22,12 +21,12 @@ class ModelTypeEnum(Enum): CHECKPOINTED_MODEL = "checkpointed_model" -def get_model_from_config(config: Dict, model_type: ModelTypeEnum): +def get_model_from_config(config: dict, model_type: ModelTypeEnum): """ Retrieves a model from the given configuration based on the specified model type. Args: - config (Dict): The configuration dictionary. + config (dict): The configuration dictionary. model_type (ModelTypeEnum): The type of the model to retrieve. Returns: diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py index 616523a6a..ab09db8c3 100644 --- a/src/modalities/models/vision_transformer/vision_transformer_model.py +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -1,5 +1,5 @@ from math import floor -from typing import Annotated, Dict, Optional, Tuple, Union +from typing import Annotated, Optional import torch from einops.layers.torch import Rearrange @@ -18,7 +18,7 @@ class VisionTransformerConfig(BaseModel): Args: sample_key (str): The key for the input sample. prediction_key (str): The key for the model prediction. - img_size (Union[Tuple[int, int], int], optional): The size of the input image. 
Defaults to 224. + img_size (tuple[int, int] | int optional): The size of the input image. Defaults to 224. n_classes (int, optional): The number of output classes. Defaults to 1000. n_layer (int): The number of layers in the model. Defaults to 12. attention_config (AttentionConfig, optional): The configuration for the attention mechanism. Defaults to None. @@ -34,7 +34,7 @@ class VisionTransformerConfig(BaseModel): sample_key: str prediction_key: str - img_size: Annotated[Union[Tuple[int, int], int], Field(ge=1)] = 224 + img_size: Annotated[tuple[int, int] | int, Field(ge=1)] = 224 n_classes: Optional[Annotated[int, Field(ge=1)]] = 1000 n_layer: Annotated[int, Field(ge=1)] = 12 attention_config: AttentionConfig = None @@ -247,7 +247,7 @@ def __init__( self, sample_key: str, prediction_key: str, - img_size: Union[Tuple[int, int], int] = 224, + img_size: tuple[int, int] | int = 224, n_classes: int = 1000, n_layer: int = 12, attention_config: AttentionConfig = None, @@ -269,7 +269,7 @@ def __init__( Args: sample_key (str): The key for the samples. prediction_key (str): The key for the predictions. - img_size (Union[Tuple[int, int], int], optional): The size of the input image. Defaults to 224. + img_size (tuple[int, int] | int, optional): The size of the input image. Defaults to 224. n_classes (int, optional): The number of classes. Defaults to 1000. n_layer (int, optional): The number of layers. Defaults to 12. attention_config (AttentionConfig, optional): The attention configuration. Defaults to None. @@ -373,7 +373,7 @@ def forward_videos(self, x: torch.Tensor) -> torch.Tensor: latents = block(x, latents) return latents - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # TODO video adapt + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: # TODO video adapt """ Forward pass of the VisionTransformer module. @@ -398,12 +398,12 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {self.prediction_key: x} @staticmethod - def _calculate_block_size(img_size: Tuple[int, int], patch_size: int, patch_stride: int, add_cls_token: bool): + def _calculate_block_size(img_size: tuple[int, int], patch_size: int, patch_stride: int, add_cls_token: bool): """ Calculates the block size. Args: - img_size (Tuple[int, int]): The size of the input image. + img_size (tuple[int, int]): The size of the input image. patch_size (int): The size of each patch. patch_stride (int): The stride of each patch. add_cls_token (bool): Flag indicating whether to add a classification token. 
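
Note for readers of this patch series: the recurring change above replaces typing-module generics (List, Dict, Tuple, Union) with builtin generics and PEP 604 unions, e.g. `img_size: tuple[int, int] | int`, which requires Python 3.10+ at runtime. The snippet below is a minimal, illustrative sketch of the same annotation style on a standalone pydantic model; `ToyVisionConfig` and its fields are hypothetical and not part of the repository.

from typing import Optional

from pydantic import BaseModel, Field


class ToyVisionConfig(BaseModel):
    # Before: img_size: Union[Tuple[int, int], int] = 224
    # After (builtin generics and | unions, as used throughout this patch):
    img_size: tuple[int, int] | int = 224
    n_classes: Optional[int] = 1000  # Optional[...] is left unchanged by the patch
    block_names: list[str] = Field(default_factory=list)  # list[...] instead of List[...]
    metrics: dict[str, float] = Field(default_factory=dict)  # dict[...] instead of Dict[...]


if __name__ == "__main__":
    # Validates against the union annotation exactly as the patched configs do.
    print(ToyVisionConfig(img_size=(224, 224)).model_dump())
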
diff --git a/src/modalities/nn/attention.py b/src/modalities/nn/attention.py index ceecff542..35ac8f45b 100644 --- a/src/modalities/nn/attention.py +++ b/src/modalities/nn/attention.py @@ -1,6 +1,6 @@ import math from enum import Enum -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -95,7 +95,7 @@ def _flash_without_mask(self, query: Tensor, key: Tensor, value: Tensor) -> Tens is_causal=self.is_causal, ) - def _forward_input_projection(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _forward_input_projection(self, x: Tensor, context: Tensor) -> tuple[Tensor, Tensor, Tensor]: B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd) _, Tc, Cc = context.shape # batch size, context length, context embedding dimensionality # Note that the context length (Tc), sequence length (T) and embedding dimensionalities (C and Cc) diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index 3b02d0bd6..4a4d9f330 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional import torch.nn as nn from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -15,7 +15,7 @@ class ModelInitializerWrapperConfig(BaseModel): - model_initializers: List[PydanticModelInitializationIFType] + model_initializers: list[PydanticModelInitializationIFType] # avoid warning about protected namespace 'model_', see # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces @@ -78,7 +78,7 @@ def _check_values(self): class ModelInitializerWrapper(ModelInitializationIF): - def __init__(self, model_initializers: List[ModelInitializationIF]): + def __init__(self, model_initializers: list[ModelInitializationIF]): self.model_initializers = model_initializers def initialize_in_place(self, model: nn.Module): @@ -88,7 +88,7 @@ def initialize_in_place(self, model: nn.Module): class ComposedInitializationRoutines: @staticmethod - def get_model_initializer_wrapper(model_initializers: List[ModelInitializationIF]) -> ModelInitializationIF: + def get_model_initializer_wrapper(model_initializers: list[ModelInitializationIF]) -> ModelInitializationIF: initializer_wrapper = ModelInitializerWrapper(model_initializers) return initializer_wrapper diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index aa14ebf84..743297ece 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -1,6 +1,6 @@ import math import re -from typing import Annotated, List, Optional +from typing import Annotated, Optional import torch.nn as nn from pydantic import BaseModel, Field, model_validator @@ -12,7 +12,7 @@ class PlainInitializationConfig(BaseModel): mean: float std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto" - parameter_name_regexes: List[str] # here we filter for the parameter names, e.g., "c_proj.weight" + parameter_name_regexes: list[str] # here we filter for the parameter names, e.g., "c_proj.weight" hidden_dim: Optional[int] = None @model_validator(mode="after") @@ -30,12 +30,12 @@ class ScaledInitializationConfig(BaseModel): 
mean: float std: Annotated[float, Field(strict=True, ge=0.0)] num_layers: Annotated[int, Field(strict=True, gt=0)] - parameter_name_regexes: List[str] # here we filter for the parameter names, e.g., "c_proj.weight" + parameter_name_regexes: list[str] # here we filter for the parameter names, e.g., "c_proj.weight" class ScaledEmbedInitializationConfig(BaseModel): mean: float - parameter_name_regexes: List[str] # here we filter for the parameter names, e.g., "c_proj.weight" + parameter_name_regexes: list[str] # here we filter for the parameter names, e.g., "c_proj.weight" class NamedParameterwiseNormalInitialization(ModelInitializationIF): @@ -59,7 +59,7 @@ def initialize_in_place(self, model: nn.Module): class InitializationRoutines: @staticmethod def get_plain_initialization( - mean: float, std: float | str, parameter_name_regexes: List[str], hidden_dim: Optional[int] = None + mean: float, std: float | str, parameter_name_regexes: list[str], hidden_dim: Optional[int] = None ) -> NamedParameterwiseNormalInitialization: """Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. @@ -86,7 +86,7 @@ def get_plain_initialization( @staticmethod def get_scaled_initialization( - mean: float, std: float, num_layers: int, parameter_name_regexes: List[str] + mean: float, std: float, num_layers: int, parameter_name_regexes: list[str] ) -> ModelInitializationIF: """Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903 @@ -94,7 +94,7 @@ def get_scaled_initialization( mean (float): Mean of the normal distribution std (float): Standard deviation of the normal distribution used to initialize the other weights num_layers (int): Number of layers in the model which we use to downscale std with - parameter_name_regexes (List[str]): List of parameter name regexes to which the initialization + parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization should be applied Returns: @@ -109,13 +109,13 @@ def get_scaled_initialization( return initialization @staticmethod - def get_scaled_embed_initialization(mean: float, parameter_name_regexes: List[str]) -> ModelInitializationIF: + def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[str]) -> ModelInitializationIF: """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903 We fix the standard deviation to sqrt(0.4). Args: mean (float): Mean of the normal distribution - parameter_name_regexes (List[str], optional): List of parameter name regexes to which the initialization + parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization should be applied Defaults to None. 
Returns: diff --git a/src/modalities/nn/model_initialization/parameter_name_filters.py b/src/modalities/nn/model_initialization/parameter_name_filters.py index eca76ea5e..df569094e 100644 --- a/src/modalities/nn/model_initialization/parameter_name_filters.py +++ b/src/modalities/nn/model_initialization/parameter_name_filters.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional +from typing import Optional from pydantic import BaseModel, Field @@ -16,8 +16,8 @@ class SupportWeightInitModels(Enum): class RegexFilter(BaseModel): - weights: List[str] - biases: Optional[List[str]] = Field(default_factory=list) + weights: list[str] + biases: Optional[list[str]] = Field(default_factory=list) NAMED_PARAMETER_INIT_GROUPS = { diff --git a/src/modalities/optimizers/lr_schedulers.py b/src/modalities/optimizers/lr_schedulers.py index bf53fc18d..5e0e6f5b2 100644 --- a/src/modalities/optimizers/lr_schedulers.py +++ b/src/modalities/optimizers/lr_schedulers.py @@ -1,5 +1,4 @@ import warnings -from typing import List from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -9,7 +8,7 @@ class DummyLRScheduler(LRScheduler): def __init__(self, optimizer: Optimizer, last_epoch=-1, verbose=False): super().__init__(optimizer, last_epoch, verbose) - def get_lr(self) -> List[float]: + def get_lr(self) -> list[float]: if not self._get_lr_called_within_step: # type error expected due to internal pytorch implementation warnings.warn( "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning @@ -17,5 +16,5 @@ def get_lr(self) -> List[float]: return [group["lr"] for group in self.optimizer.param_groups] - def _get_closed_form_lr(self) -> List[float]: + def _get_closed_form_lr(self) -> list[float]: return self.base_lrs diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py index 371c9c480..fdff5d0f9 100644 --- a/src/modalities/optimizers/optimizer_factory.py +++ b/src/modalities/optimizers/optimizer_factory.py @@ -1,6 +1,5 @@ import re from pathlib import Path -from typing import Dict, List, Tuple import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -11,16 +10,16 @@ from modalities.models.model import NNModel from modalities.util import get_local_number_of_trainable_parameters, print_rank_0 -OptimizerGroups = List[Dict[str, List[nn.Parameter] | float]] +OptimizerGroups = list[dict[str, list[nn.Parameter] | float]] class OptimizerFactory: def get_adam( lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, - weight_decay_groups_excluded: List[str], + weight_decay_groups_excluded: list[str], wrapped_model: nn.Module, ) -> Optimizer: optimizer_groups = get_optimizer_groups(wrapped_model, weight_decay, weight_decay_groups_excluded) @@ -29,10 +28,10 @@ def get_adam( def get_adam_w( lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, - weight_decay_groups_excluded: List[str], + weight_decay_groups_excluded: list[str], wrapped_model: nn.Module, ) -> Optimizer: optimizer_groups = get_optimizer_groups(wrapped_model, weight_decay, weight_decay_groups_excluded) @@ -49,7 +48,7 @@ def get_checkpointed_optimizer( return wrapped_optimizer -def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_excluded: List[str]) -> OptimizerGroups: +def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_excluded: 
list[str]) -> OptimizerGroups: """ divide model parameters into optimizer groups (with or without weight decay) @@ -73,7 +72,7 @@ def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_e return optimizer_groups -def _assert_existence_of_weight_decay_groups_excluded(model: FSDP, weight_decay_groups_excluded: List[str]) -> None: +def _assert_existence_of_weight_decay_groups_excluded(model: FSDP, weight_decay_groups_excluded: list[str]) -> None: """ checks the existence of all groups that are to be excluded from weight decay @@ -93,8 +92,8 @@ def _assert_existence_of_weight_decay_groups_excluded(model: FSDP, weight_decay_ def _create_optimizer_groups( - model: FSDP, weight_decay: float, weight_decay_groups_excluded: List[str] -) -> Tuple[OptimizerGroups, List[str]]: + model: FSDP, weight_decay: float, weight_decay_groups_excluded: list[str] +) -> tuple[OptimizerGroups, list[str]]: """ create optimizer groups of parameters with different weight decays that are to be used in Adam or AdamW """ @@ -118,8 +117,8 @@ def _create_optimizer_groups( def _filter_params_for_weight_decay_group( - params: Dict[str, List[nn.Parameter]], regex_expressions: List[str] -) -> List[nn.Parameter]: + params: dict[str, list[nn.Parameter]], regex_expressions: list[str] +) -> list[nn.Parameter]: """ filter parameters by their name. a parameter is kept if and only if it contains at least one of the regex expressions. @@ -139,7 +138,7 @@ def _print_params(params) -> None: print_rank_0(f"{i + 1} {name}") -def _print_optimizer_groups_overview(optimizer_groups: OptimizerGroups, optimizer_groups_names: List[str]) -> None: +def _print_optimizer_groups_overview(optimizer_groups: OptimizerGroups, optimizer_groups_names: list[str]) -> None: """ for each optimizer group, the following is printed: - the number of modules diff --git a/src/modalities/registry/registry.py b/src/modalities/registry/registry.py index aebd8cea0..020107001 100644 --- a/src/modalities/registry/registry.py +++ b/src/modalities/registry/registry.py @@ -1,25 +1,25 @@ from dataclasses import asdict -from typing import Dict, List, Optional, Tuple, Type +from typing import Optional, Type from pydantic import BaseModel from modalities.registry.components import ComponentEntity -Entity = Tuple[Type, Type[BaseModel]] +Entity = tuple[Type, Type[BaseModel]] class Registry: """Registry class to store the components and their config classes.""" - def __init__(self, components: Optional[List[ComponentEntity]] = None) -> None: + def __init__(self, components: Optional[list[ComponentEntity]] = None) -> None: """Initializes the Registry class with an optional list of components. Args: - components (List[ComponentEntity], optional): List of components to + components (list[ComponentEntity], optional): List of components to intialize the registry with . Defaults to None. 
""" # maps component_key -> variant_key -> entity = (component, config) - self._registry_dict: Dict[str, Dict[str, Entity]] = {} + self._registry_dict: dict[str, dict[str, Entity]] = {} if components is not None: for component in components: self.add_entity(**asdict(component)) diff --git a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py index 22d463f3a..97ad69764 100644 --- a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py +++ b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py @@ -1,7 +1,7 @@ import functools import logging from abc import ABC, abstractmethod -from typing import Callable, List +from typing import Callable import torch.nn as nn from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy @@ -17,13 +17,13 @@ def get_auto_wrap_policy(self) -> Callable: class FSDPTransformerAutoWrapPolicyFactory(FSDPAutoWrapFactoryIF): - def __init__(self, model: nn.Module, block_names: List[str]) -> None: + def __init__(self, model: nn.Module, block_names: list[str]) -> None: # TODO it's problematic that we store the model in-memory here. Might get too large in RAM... self.model = model self.block_names = block_names @staticmethod - def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: List[str]) -> List[nn.Module]: + def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: list[str]) -> list[nn.Module]: fsdp_block_types = [] for cls_block_name in block_names: # TODO FullyShardedDataParallelPlugin from Accelerate uses string matching to find the correct diff --git a/src/modalities/tokenization/tokenizer_wrapper.py b/src/modalities/tokenization/tokenizer_wrapper.py index 479c79548..e9e778fc0 100644 --- a/src/modalities/tokenization/tokenizer_wrapper.py +++ b/src/modalities/tokenization/tokenizer_wrapper.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Dict, List, Optional +from typing import Optional import sentencepiece as spm from transformers import AutoTokenizer @@ -8,7 +8,7 @@ class TokenizerWrapper(ABC): """Abstract interface for tokenizers.""" - def tokenize(self, text: str) -> List[int]: + def tokenize(self, text: str) -> list[int]: """Tokenizes a text into a list of token IDs. Args: @@ -18,15 +18,15 @@ def tokenize(self, text: str) -> List[int]: NotImplementedError: Must be implemented by a subclass. Returns: - List[int]: List of token IDs. + list[int]: List of token IDs. """ raise NotImplementedError - def decode(self, input_ids: List[int]) -> str: + def decode(self, input_ids: list[int]) -> str: """Decodes a list of token IDs into the original text. Args: - input_ids (List[int]): List of token IDs. + input_ids (list[int]): List of token IDs. Raises: NotImplementedError: Must be implemented by a subclass. @@ -72,7 +72,7 @@ def __init__( truncation: Optional[bool] = False, padding: Optional[bool | str] = False, max_length: Optional[int] = None, - special_tokens: Optional[Dict[str, str]] = None, + special_tokens: Optional[dict[str, str]] = None, ) -> None: """Initializes the PreTrainedHFTokenizer. @@ -81,7 +81,7 @@ def __init__( truncation (bool, optional): Flag whether to apply truncation. Defaults to False. padding (bool | str, optional): Defines the padding strategy. Defaults to False. max_length (int, optional): Maximum length of the tokenization output. Defaults to None. 
- special_tokens (Dict[str, str], optional): Added token keys should be in the list + special_tokens (dict[str, str], optional): Added token keys should be in the list of predefined special attributes: [bos_token, eos_token, unk_token, sep_token, pad_token, cls_token, mask_token, additional_special_tokens]. Example: {"pad_token": "[PAD]"} @@ -113,22 +113,22 @@ def vocab_size(self) -> int: return self.tokenizer.vocab_size @property - def special_tokens(self) -> Dict[str, str | List[str]]: + def special_tokens(self) -> dict[str, str | list[str]]: """Returns the special tokens of the tokenizer. Returns: - Dict[str, str | List[str]]: Special tokens dictionary. + dict[str, str | list[str]]: Special tokens dictionary. """ return self.tokenizer.special_tokens_map - def tokenize(self, text: str) -> List[int]: + def tokenize(self, text: str) -> list[int]: """Tokenizes a text into a list of token IDs. Args: text (str): Text to be tokenized. Returns: - List[int]: List of token IDs. + list[int]: List of token IDs. """ tokens = self.tokenizer.__call__( text, @@ -138,11 +138,11 @@ def tokenize(self, text: str) -> List[int]: )["input_ids"] return tokens - def decode(self, token_ids: List[int]) -> str: + def decode(self, token_ids: list[int]) -> str: """Decodes a list of token IDs into the original text. Args: - input_ids (List[int]): List of token IDs. + input_ids (list[int]): List of token IDs. Returns: str: Decoded text. @@ -180,23 +180,23 @@ def __init__(self, tokenizer_model_file: str): self.tokenizer = spm.SentencePieceProcessor() self.tokenizer.Load(tokenizer_model_file) - def tokenize(self, text: str) -> List[int]: + def tokenize(self, text: str) -> list[int]: """Tokenizes a text into a list of token IDs. Args: text (str): Text to be tokenized. Returns: - List[int]: List of token IDs. + list[int]: List of token IDs. """ tokens = self.tokenizer.encode(text) return tokens - def decode(self, token_ids: List[int]) -> str: + def decode(self, token_ids: list[int]) -> str: """Decodes a list of token IDs into the original text. Args: - input_ids (List[int]): List of token IDs. + input_ids (list[int]): List of token IDs. Returns: str: Decoded text. diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 1128a6b95..1bf187218 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Callable, List, Optional, Tuple +from typing import Callable, Optional import torch import torch.distributed as dist @@ -87,9 +87,9 @@ def _train_batch( model: FSDP, optimizer: Optimizer, scheduler: LRScheduler, - loss_fun: List[Loss], + loss_fun: list[Loss], micro_batch_id: int, - ) -> Tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]: """ Conducts a training step on batch of data. @@ -102,7 +102,7 @@ def _train_batch( micro_batch_id (int): The ID of the micro batch. Returns: - Tuple[bool, int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + tuple[bool, int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple containing the following: - step_performed (bool): Indicates whether a training step was performed. - num_train_steps_done (int): The number of training steps done. 
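
The tokenizer_wrapper.py hunks above only modernize the tokenize/decode signatures (list[int] in, str out) without touching behavior. For reference, here is a minimal stand-in that mirrors that interface without loading a HuggingFace or SentencePiece model; `ToyWhitespaceTokenizer` is a hypothetical example class, not part of modalities.

from abc import ABC, abstractmethod


class TokenizerIF(ABC):
    """Mirrors the tokenize/decode contract shown in the diff above."""

    @abstractmethod
    def tokenize(self, text: str) -> list[int]:
        raise NotImplementedError

    @abstractmethod
    def decode(self, token_ids: list[int]) -> str:
        raise NotImplementedError


class ToyWhitespaceTokenizer(TokenizerIF):
    def __init__(self) -> None:
        # Vocabulary is built lazily from the texts seen so far.
        self._token_to_id: dict[str, int] = {}
        self._id_to_token: dict[int, str] = {}

    def tokenize(self, text: str) -> list[int]:
        ids: list[int] = []
        for token in text.split():
            if token not in self._token_to_id:
                new_id = len(self._token_to_id)
                self._token_to_id[token] = new_id
                self._id_to_token[new_id] = token
            ids.append(self._token_to_id[token])
        return ids

    def decode(self, token_ids: list[int]) -> str:
        return " ".join(self._id_to_token[i] for i in token_ids)


if __name__ == "__main__":
    tok = ToyWhitespaceTokenizer()
    ids = tok.tokenize("hello world hello")
    print(ids)              # [0, 1, 0]
    print(tok.decode(ids))  # "hello world hello"
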
diff --git a/src/modalities/training/activation_checkpointing.py b/src/modalities/training/activation_checkpointing.py index 4da526874..ee7c71527 100644 --- a/src/modalities/training/activation_checkpointing.py +++ b/src/modalities/training/activation_checkpointing.py @@ -1,5 +1,4 @@ from functools import partial -from typing import List import torch from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -13,12 +12,12 @@ def is_module_to_apply_activation_checkpointing( - submodule: torch.nn.Module, activation_checkpointing_modules: List[type] + submodule: torch.nn.Module, activation_checkpointing_modules: list[type] ) -> bool: return isinstance(submodule, tuple(activation_checkpointing_modules)) -def apply_activation_checkpointing_inplace(model: torch.nn.Module, activation_checkpointing_modules: List[str]): +def apply_activation_checkpointing_inplace(model: torch.nn.Module, activation_checkpointing_modules: list[str]): activation_checkpointing_module_types = [ get_module_class_from_name(model, m) for m in activation_checkpointing_modules ] diff --git a/src/modalities/util.py b/src/modalities/util.py index 9a9b7c3b1..bda55b873 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -5,7 +5,7 @@ from enum import Enum from pathlib import Path from types import TracebackType -from typing import Callable, Dict, Generic, Optional, Type, TypeVar +from typing import Callable, Generic, Optional, Type, TypeVar import torch import torch.distributed as dist @@ -151,7 +151,7 @@ def __repr__(self) -> str: class Aggregator(Generic[T]): def __init__(self): - self.key_to_value: Dict[T, torch.Tensor] = {} + self.key_to_value: dict[T, torch.Tensor] = {} def add_value(self, key: T, value: torch.Tensor): if key not in self.key_to_value: @@ -202,6 +202,7 @@ def flatten_dict(d, parent_key="", sep="_"): items.append((new_key, v)) return dict(items) + def get_module_class_from_name(module: torch.nn.Module, name: str) -> Type[torch.nn.Module] | None: """From Accelerate source code (https://github.com/huggingface/accelerate/blob/1f7a79b428749f45187ec69485f2c966fe21926e/src/accelerate/utils/dataclasses.py#L1902) diff --git a/src/modalities/utils/mfu.py b/src/modalities/utils/mfu.py index d33e165e5..48d9a604e 100644 --- a/src/modalities/utils/mfu.py +++ b/src/modalities/utils/mfu.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple +from typing import Optional import torch from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -76,7 +76,7 @@ def get_theoretical_gpu_peak_performance(model: FSDP, world_size: int) -> Option return None -def get_theoretical_flops_per_token(model: FSDP) -> Tuple[Optional[int], Optional[int]]: +def get_theoretical_flops_per_token(model: FSDP) -> tuple[Optional[int], Optional[int]]: """ Calculates the theoretical number of floating point operations (FLOPs) per token for a given model. compute theoretical_flops_per_token = 6*N + 12*L*T*H @@ -86,7 +86,7 @@ def get_theoretical_flops_per_token(model: FSDP) -> Tuple[Optional[int], Optiona model (FSDP): The model for which to calculate the FLOPs per token. Returns: - Tuple[(int, optional), (int, optional)]: A tuple containing the theoretical FLOPs per token + tuple[(int, optional), (int, optional)]: A tuple containing the theoretical FLOPs per token and the sequence length. - Theoretical FLOPs per token: The estimated number of FLOPs required to process each token in the model. - Sequence length: The length of the input sequence. Needed to convert samples to tokens in compute_mfu. 
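
The mfu.py docstring above keeps the estimate theoretical_flops_per_token = 6*N + 12*L*T*H. The following is a self-contained sketch of that arithmetic only; the GPT-2-small-like numbers are assumptions chosen for illustration, not values taken from this repository's configs.

def theoretical_flops_per_token(n_params: int, n_layers: int, seq_len: int, hidden_dim: int) -> int:
    # 6*N: forward + backward matmul FLOPs per parameter.
    # 12*L*T*H: attention-score FLOPs that grow with sequence length.
    return 6 * n_params + 12 * n_layers * seq_len * hidden_dim


if __name__ == "__main__":
    flops = theoretical_flops_per_token(
        n_params=124_000_000,  # ~GPT-2 small (assumed)
        n_layers=12,
        seq_len=1024,
        hidden_dim=768,
    )
    print(f"{flops:.3e} FLOPs per token")  # ~8.6e+08 with these assumed numbers
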
diff --git a/tests/checkpointing/pytorch/test_torch_checkpoint_loading.py b/tests/checkpointing/pytorch/test_torch_checkpoint_loading.py index 6c6615118..080bcfc92 100644 --- a/tests/checkpointing/pytorch/test_torch_checkpoint_loading.py +++ b/tests/checkpointing/pytorch/test_torch_checkpoint_loading.py @@ -1,5 +1,3 @@ -from typing import Dict - import pytest import torch import torch.nn as nn @@ -13,7 +11,7 @@ def __init__(self): super().__init__() self._weights = nn.Linear(2, 3) - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: output = self._weights(**inputs) return {"output": output} diff --git a/tests/checkpointing/test_checkpoint_strategies.py b/tests/checkpointing/test_checkpoint_strategies.py index 9aef57dfc..9fd34580b 100644 --- a/tests/checkpointing/test_checkpoint_strategies.py +++ b/tests/checkpointing/test_checkpoint_strategies.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from modalities.checkpointing.checkpoint_saving_strategies import SaveKMostRecentCheckpointsStrategy @@ -25,7 +23,7 @@ ], ) def test_checkpoint_strategy_k( - k: int, saved_instances: List[TrainingProgress], checkpoints_to_delete: List[int], save_current: bool + k: int, saved_instances: list[TrainingProgress], checkpoints_to_delete: list[int], save_current: bool ) -> None: training_progress = TrainingProgress( num_seen_steps_current_run=10, num_seen_tokens_current_run=10, num_target_steps=20, num_target_tokens=40 diff --git a/tests/checkpointing/test_fsdp_to_disc_checkpointing.py b/tests/checkpointing/test_fsdp_to_disc_checkpointing.py index 9a56d2741..cdcbecb70 100644 --- a/tests/checkpointing/test_fsdp_to_disc_checkpointing.py +++ b/tests/checkpointing/test_fsdp_to_disc_checkpointing.py @@ -2,7 +2,6 @@ import tempfile from copy import deepcopy from pathlib import Path -from typing import Dict import pytest import torch @@ -42,7 +41,7 @@ reason="This e2e test requires 2 GPUs and a torchrun distributed environment.", ) class TestFSDPToDiscCheckpointing: - def get_gpt2_model_from_config(self, gpt2_model_config_dict: Dict) -> GPT2LLM: + def get_gpt2_model_from_config(self, gpt2_model_config_dict: dict) -> GPT2LLM: class GPT2InstantationModel(BaseModel): model: PydanticPytorchModuleType @@ -57,7 +56,7 @@ class GPT2InstantationModel(BaseModel): return model @pytest.fixture(scope="function") - def gpt2_model_config_dict(self) -> Dict: + def gpt2_model_config_dict(self) -> dict: config_file_path = working_dir / "gpt2_config.yaml" config_dict = load_app_config_dict(config_file_path=config_file_path) return config_dict @@ -108,7 +107,7 @@ def _clone_parameters(fsdp_wrapped_model): return [p.clone() for p in fsdp_wrapped_model.parameters() if p.requires_grad and p.numel() > 0] @staticmethod - def _generate_batch(gpt2_model_config: Dict): + def _generate_batch(gpt2_model_config: dict): # prepare input and targets data = torch.randint( 0, # lowest token_id @@ -122,10 +121,10 @@ def _generate_batch(gpt2_model_config: Dict): @staticmethod def _forward_backward_pass( - gpt2_model_config: Dict, + gpt2_model_config: dict, model: FSDP, optimizer: Optimizer, - batch_input_ids_dict: Dict, + batch_input_ids_dict: dict, batch_target_ids: torch.Tensor, ): ce_loss = CrossEntropyLoss() @@ -148,7 +147,7 @@ def _forward_backward_pass( @staticmethod def _assert_equality_optimizer_param_group( - optimizer_1_state_dict: Dict, optimizer_2_state_dict: Dict, must_be_equal: bool + optimizer_1_state_dict: dict, 
optimizer_2_state_dict: dict, must_be_equal: bool ): if must_be_equal: assert ( @@ -161,7 +160,7 @@ def _assert_equality_optimizer_param_group( @staticmethod def _assert_equality_optimizer_state( - optimizer_1_state_dict: Dict, optimizer_2_state_dict: Dict, must_be_equal: bool + optimizer_1_state_dict: dict, optimizer_2_state_dict: dict, must_be_equal: bool ): optimizer_1_state = optimizer_1_state_dict["state"] optimizer_2_state = optimizer_2_state_dict["state"] @@ -195,7 +194,7 @@ def test_save_checkpoint_after_backward_pass( optimizer: Optimizer, temporary_checkpoint_folder_path: Path, gpt2_model_2: GPT2LLM, - gpt2_model_config_dict: Dict, + gpt2_model_config_dict: dict, ): experiment_id = "0" num_train_steps_done = 1 diff --git a/tests/config/components.py b/tests/config/components.py index 67c9e9a3c..f67c55318 100644 --- a/tests/config/components.py +++ b/tests/config/components.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import List class Component_V_W_X_IF: @@ -30,7 +29,7 @@ def __init__(self, val_x: str, single_dependency: Component_V_W_X_IF) -> None: class ComponentY: - def __init__(self, val_y: str, multi_dependency: List[Component_V_W_X_IF]) -> None: + def __init__(self, val_y: str, multi_dependency: list[Component_V_W_X_IF]) -> None: self.val_y = val_y self.multi_dependency = multi_dependency diff --git a/tests/config/configs.py b/tests/config/configs.py index 569c4d11a..2ed597ebc 100644 --- a/tests/config/configs.py +++ b/tests/config/configs.py @@ -1,4 +1,4 @@ -from typing import Annotated, List +from typing import Annotated from pydantic import BaseModel @@ -23,7 +23,7 @@ class CompXConfig(BaseModel): class CompYConfig(BaseModel): val_y: str - multi_dependency: List[PydanticComponent_V_W_X_IF_Type] + multi_dependency: list[PydanticComponent_V_W_X_IF_Type] class CompZConfig(BaseModel): diff --git a/tests/conftest.py b/tests/conftest.py index 1aa87f871..be2fa5713 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ import os import pickle from pathlib import Path -from typing import Dict from unittest.mock import MagicMock import pytest @@ -51,7 +50,7 @@ def dummy_config_path() -> Path: @pytest.fixture -def dummy_config(monkeypatch, dummy_config_path) -> Dict: +def dummy_config(monkeypatch, dummy_config_path) -> dict: monkeypatch.setenv("RANK", "0") monkeypatch.setenv("LOCAL_RANK", "0") monkeypatch.setenv("WORLD_SIZE", "1") diff --git a/tests/dataloader/dummy_sequential_dataset.py b/tests/dataloader/dummy_sequential_dataset.py index 8eb412a42..0d3a8dee8 100644 --- a/tests/dataloader/dummy_sequential_dataset.py +++ b/tests/dataloader/dummy_sequential_dataset.py @@ -1,5 +1,3 @@ -from typing import Dict - from pydantic import BaseModel from torch.utils.data.dataset import Dataset as TorchdataSet @@ -11,7 +9,7 @@ def __init__(self, num_samples: int): def __len__(self) -> int: return len(self.samples) - def __getitem__(self, idx: int) -> Dict: + def __getitem__(self, idx: int) -> dict: return self.samples[idx] diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 3b0ef7be6..65139ce6a 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from pathlib import Path -from typing import Any, Dict, List +from typing import Any import numpy as np import pytest @@ -44,7 +44,7 @@ def test_resumable_dataloader(): assert (flat_samples == original_samples).all() -def test_dataloader_from_config(dummy_config: Dict): +def 
test_dataloader_from_config(dummy_config: dict): start_index = 2 dummy_config["train_dataloader"]["config"]["skip_num_batches"] = start_index @@ -248,7 +248,7 @@ class DataloaderTestModel(BaseModel): fixed_num_batches: int class IdentityCollateFn(CollateFnIF): - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> list[dict[str, torch.Tensor]]: return batch root_dir = Path(__file__).parents[0] diff --git a/tests/end2end_tests/test_fsdp_warmstart.py b/tests/end2end_tests/test_fsdp_warmstart.py index f41233ffc..3261eb4b4 100644 --- a/tests/end2end_tests/test_fsdp_warmstart.py +++ b/tests/end2end_tests/test_fsdp_warmstart.py @@ -2,7 +2,7 @@ import os import tempfile from pathlib import Path -from typing import Any, Dict, List +from typing import Any import pytest import torch @@ -31,13 +31,13 @@ class SaveAllResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): def __init__(self): - self.message_list: List[Message[EvaluationResultBatch]] = [] + self.message_list: list[Message[EvaluationResultBatch]] = [] def consume_message(self, message: Message[EvaluationResultBatch]): """Consumes a message from a message broker.""" self.message_list.append(message) - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): pass @@ -55,7 +55,7 @@ class TrainDataloaderInstantiationModel(BaseModel): ) class TestWarmstart: @staticmethod - def get_loss_scores(messages: List[Message[EvaluationResultBatch]], loss_key: str) -> List[float]: + def get_loss_scores(messages: list[Message[EvaluationResultBatch]], loss_key: str) -> list[float]: return [message.payload.losses[loss_key].value.item() for message in messages] def test_warm_start(self): @@ -130,7 +130,7 @@ def test_warm_start(self): # we collect the loss values from rank 0 and store them in the temporary experiment folder if dist.get_rank() == 0: - messages_0: List[Message[EvaluationResultBatch]] = components_0.evaluation_subscriber.message_list + messages_0: list[Message[EvaluationResultBatch]] = components_0.evaluation_subscriber.message_list loss_scores_0 = TestWarmstart.get_loss_scores(messages_0, "train loss avg") with open(loss_values_experiment_0_path, "w") as f: json.dump(loss_scores_0, f) @@ -156,7 +156,7 @@ def test_warm_start(self): # we collect the loss values from rank 0 for the warmstart model # and store them in the temporary experiment folder if dist.get_rank() == 0: - messages_1: List[Message[EvaluationResultBatch]] = components_1.evaluation_subscriber.message_list + messages_1: list[Message[EvaluationResultBatch]] = components_1.evaluation_subscriber.message_list loss_scores_1 = TestWarmstart.get_loss_scores(messages_1, "train loss avg") with open(loss_values_experiment_1_path, "w") as f: json.dump(loss_scores_1, f) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 9edbb3c0f..a169ade87 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -2,7 +2,7 @@ import os import re from pathlib import Path -from typing import Dict, Optional +from typing import Optional import pytest import torch @@ -27,7 +27,7 @@ # $(which pytest) path/to/test_initialization.py -def get_model_from_config(model_config_dict: Dict) -> GPT2LLM | CoCa: +def get_model_from_config(model_config_dict: dict) -> GPT2LLM | CoCa: """get gpt2 or coca model from config_dict""" class InstantationModel(BaseModel): @@ -44,7 +44,7 @@ class InstantationModel(BaseModel): 
return model -def _replace_config_dict(_config_dict: Dict, _initialization_type: str, _std: str) -> Dict: +def _replace_config_dict(_config_dict: dict, _initialization_type: str, _std: str) -> dict: """dynamically replace initialization_type, std and dependent fields in config_dict""" _config_dict["model"]["config"]["model_initializer"]["config"]["weight_init_type"] = _initialization_type # replace _config_dict["model"]["config"]["model_initializer"]["config"]["std"] = _std # replace @@ -120,7 +120,7 @@ def _load_model(model_name: str, initialization: str = "plain", std: float | str } -def get_group_params(model: FSDP, model_name: str) -> Dict[str, Optional[torch.Tensor]]: +def get_group_params(model: FSDP, model_name: str) -> dict[str, Optional[torch.Tensor]]: """ divide all model parameters into initialization groups """ diff --git a/tests/test_optimizer_factory.py b/tests/test_optimizer_factory.py index 840003d34..4f273ad01 100644 --- a/tests/test_optimizer_factory.py +++ b/tests/test_optimizer_factory.py @@ -1,6 +1,5 @@ import os from pathlib import Path -from typing import Dict import pytest import torch @@ -26,7 +25,7 @@ # $(which pytest) path/to/test_optimizer_factory.py -def get_gpt2_model_from_config(gpt2_model_config_dict: Dict) -> GPT2LLM: +def get_gpt2_model_from_config(gpt2_model_config_dict: dict) -> GPT2LLM: class GPT2InstantationModel(BaseModel): model: PydanticPytorchModuleType diff --git a/tests/utils/test_mfu.py b/tests/utils/test_mfu.py index e0672004a..ea1d424a5 100644 --- a/tests/utils/test_mfu.py +++ b/tests/utils/test_mfu.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, Optional +from typing import Optional import pytest import torch @@ -26,7 +26,7 @@ # $(which pytest) path/to/test_mfu.py -def get_model_from_config(model_config_dict: Dict) -> GPT2LLM: +def get_model_from_config(model_config_dict: dict) -> GPT2LLM: """get gpt2 model from config_dict""" class InstantationModel(BaseModel): diff --git a/tutorials/library_usage/README.md b/tutorials/library_usage/README.md index e23abe32c..74729a284 100644 --- a/tutorials/library_usage/README.md +++ b/tutorials/library_usage/README.md @@ -28,7 +28,7 @@ class CustomGPT2LLMCollateFn(CollateFnIF): self.target_key = target_key self.custom_attribute = custom_attribute - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch]) samples = {self.sample_key: sample_tensor[:, :-1]} targets = {self.target_key: sample_tensor[:, 1:]} diff --git a/tutorials/library_usage/main.py b/tutorials/library_usage/main.py index 6e65e4e65..727a4db37 100644 --- a/tutorials/library_usage/main.py +++ b/tutorials/library_usage/main.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import torch from pydantic import BaseModel @@ -24,7 +23,7 @@ def __init__(self, sample_key: str, target_key: str, custom_attribute: str): self.target_key = target_key self.custom_attribute = custom_attribute - def __call__(self, batch: List[List[int]]) -> DatasetBatch: + def __call__(self, batch: list[list[int]]) -> DatasetBatch: sample_tensor = torch.tensor(batch) samples = {self.sample_key: sample_tensor[:, :-1]} targets = {self.target_key: sample_tensor[:, 1:]} From 3eb77a3ac0c76b1e182ee84a6f157bc96d75cbcc Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Tue, 24 Sep 2024 13:01:05 +0200 Subject: [PATCH 138/161] chore: refactor loss related items to 
match main --- src/modalities/config/config.py | 19 +++++++++ src/modalities/evaluator.py | 59 ++++++--------------------- src/modalities/loss_functions.py | 45 ++++---------------- src/modalities/registry/components.py | 14 +++---- src/modalities/trainer.py | 46 +++++++-------------- 5 files changed, 58 insertions(+), 125 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index df1197066..92e24ee17 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -62,6 +62,25 @@ class ReferenceConfig(BaseModel): pass_type: PassType +class CLMCrossEntropyLossConfig(BaseModel): + target_key: str + prediction_key: str + + +class NCELossConfig(BaseModel): + prediction_key1: str + prediction_key2: str + is_asymmetric: bool = True + temperature: float = 1.0 + tag: str = "NCELoss" + + +class ClipLossConfig(BaseModel): + logit_scale_key: str + prediction_keys: list[str] + local_loss: bool = True + tag: str = "ClipLoss" + # Checkpointing class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): k: PositiveInt diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 42f4e710e..0db60af72 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -1,14 +1,13 @@ -from typing import Dict, List +from typing import Callable, Dict, List import torch import torch.distributed as dist import torch.nn as nn -from modalities.batch import DatasetBatch, EvaluationResultBatch, ResultItem +from modalities.batch import DatasetBatch, EvaluationResultBatch, InferenceResultBatch, ResultItem from modalities.dataloader.dataloader import LLMDataLoader from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate from modalities.logging_broker.publisher import MessagePublisher -from modalities.loss_functions import Loss from modalities.models.model import model_predict_batch from modalities.running_env.fsdp.reducer import Reducer from modalities.trainer import ThroughputAggregationKeys @@ -36,7 +35,7 @@ def evaluate_batch( self, batch: DatasetBatch, model: nn.Module, - loss_fun: List[Loss], + loss_fun: Callable[[InferenceResultBatch], torch.Tensor], ) -> torch.Tensor: """Evaluate a single batch by forwarding it through the model and calculating the loss. @@ -50,30 +49,14 @@ def evaluate_batch( """ with torch.no_grad(): result_batch = model_predict_batch(model=model, batch=batch) + loss = loss_fun(result_batch) + return loss - total_loss = None - losses = [] - for lfn in loss_fun: - # Calculate loss - weighted_loss = lfn(result_batch) * lfn.weight - - # Add loss to total loss - if total_loss is None: - total_loss = weighted_loss - else: - total_loss += weighted_loss - - # Append individual losses (for logging) - losses.append(weighted_loss.clone().detach()) - - return total_loss, *losses - - @torch.no_grad() def evaluate( self, model: nn.Module, data_loaders: List[LLMDataLoader], - loss_fun: List[Loss], + loss_fun: Callable[[InferenceResultBatch], torch.Tensor], num_train_steps_done: int, ) -> Dict[str, EvaluationResultBatch]: """Evaluate the model on a set of datasets. 
@@ -93,7 +76,7 @@ def evaluate( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") for data_loader in data_loaders: - cumulated_loss = torch.zeros(len(loss_fun) + 1 + 1).to(device) # total loss, indidual losses, count + cumulated_loss = torch.zeros(3).to(device) Evaluator._publish_progress( progress_publisher=self.progress_publisher, @@ -101,20 +84,16 @@ def evaluate( dataloader_tag=data_loader.dataloader_tag, ) thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() - with TimeRecorder() as forward_backward_timer_recorder: for batch_id, batch in enumerate(data_loader): - batch_losses = self.evaluate_batch( + batch_loss = self.evaluate_batch( batch=batch, model=model, loss_fun=loss_fun, ) - # Accumulate losses - for i, batch_loss in enumerate(batch_losses): - cumulated_loss[i] += batch_loss.item() - cumulated_loss[-1] += 1 - + cumulated_loss[0] += batch_loss.item() # sum up batch loss + cumulated_loss[1] += 1 batch_length_tensor = torch.tensor(len(batch)).to(device) thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) @@ -124,11 +103,10 @@ def evaluate( dataloader_tag=data_loader.dataloader_tag, ) # TODO: insert reducer from outside so Evaluator is independent of FSDP - # Agreggate loss from all ranks total_loss = Reducer.reduce( tensor=cumulated_loss, operation=dist.ReduceOp.SUM, - post_processing_fun=lambda t: torch.cat([t[:-1] / t[-1], t[-1:] / dist.get_world_size()]), + post_processing_fun=lambda t: t[0] / t[1], ) forward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to(device) @@ -141,21 +119,8 @@ def evaluate( ) num_samples_per_second = synced_num_samples / synced_forward_backward_time - # Fill logging dict with total loss and the individual losses - loss_avg, loss_last_batch = ( - total_loss[0], - total_loss[-1], - ) - - losses = { - "total_loss average": ResultItem(loss_avg, decimal_places=2), - "total_loss last step": ResultItem(loss_last_batch, decimal_places=2), - } - for i, lfn in enumerate(loss_fun): - losses[lfn.tag] = ResultItem(total_loss[i + 1], decimal_places=2) - evaluation_result = EvaluationResultBatch( - losses=losses, + losses={loss_fun.tag: ResultItem(total_loss, decimal_places=2)}, # TODO: hardcoded metric key throughput_metrics={ "evaluation_num_samples_per_second": ResultItem(num_samples_per_second, decimal_places=1) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index ebc976117..d9d2d244e 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -4,15 +4,14 @@ import torch.distributed as dist import torch.nn.functional as F from pydantic import BaseModel -from torch.nn import CrossEntropyLoss as TorchCrossEntropyLoss +from torch.nn import CrossEntropyLoss from modalities.batch import InferenceResultBatch class Loss(ABC): - def __init__(self, tag: str, weight: float = 1.0): + def __init__(self, tag: str): self._tag = tag - self.weight = weight @property def tag(self) -> str: @@ -27,22 +26,13 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: raise NotImplementedError -class CrossEntropyLossConfig(BaseModel): - target_key: str - prediction_key: str - weight: float = 1 - tag: str = "CLMCrossEntropyLoss" - - -class CrossEntropyLoss(Loss): - def __init__(self, target_key: str, prediction_key: str, weight: float, tag: str = "CLMCrossEntropyLoss"): - super().__init__(tag, weight) +class CLMCrossEntropyLoss(Loss): + def __init__(self, target_key: str, prediction_key: str, tag: str = 
"CLMCrossEntropyLoss"): + super().__init__(tag) self.target_key = target_key self.prediction_key = prediction_key # Mean over the tokens in the local-batch (batch per rank) - self.loss_fun = TorchCrossEntropyLoss( - reduction="mean", - ) + self.loss_fun = CrossEntropyLoss(reduction="mean") def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: labels = forward_batch.get_targets(self.target_key) @@ -97,15 +87,6 @@ def nce_loss( return torch.mean(denominator - numerator) # calculated in log space -class NCELossConfig(BaseModel): - prediction_key1: str - prediction_key2: str - is_asymmetric: bool = True - temperature: float = 1.0 - weight: float = 1 - tag: str = "NCELoss" - - class NCELoss(Loss): def __init__( self, @@ -113,7 +94,6 @@ def __init__( prediction_key2: str, is_asymmetric: bool, temperature: float, - weight: float, tag: str = "NCELoss", ): """ @@ -126,7 +106,7 @@ def __init__( temperature (float, optional): temperature. Defaults to 1.0. tag (str, optional): Defaults to "NCELoss". """ - super().__init__(tag, weight) + super().__init__(tag) self.prediction_key1 = prediction_key1 self.prediction_key2 = prediction_key2 self.is_asymmetric = is_asymmetric @@ -152,20 +132,11 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: return loss -class ClipLossConfig(BaseModel): - logit_scale_key: str - prediction_keys: list[str] - weight: float = 1 - local_loss: bool = True - tag: str = "ClipLoss" - - class ClipLoss(Loss): def __init__( self, logit_scale_key: str, prediction_keys: list[str], - weight: float, local_loss: bool, tag: str = "ClipLoss", ): @@ -177,7 +148,7 @@ def __init__( prediction_keys (list[str]): Keys to access embeddings. tag (str, optional): Defaults to "ClipLoss". """ - super().__init__(tag, weight) + super().__init__(tag) self.logit_scale_key = logit_scale_key self.prediction_keys = prediction_keys self.local_loss = local_loss diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 5c598e627..8c48ceb1d 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -22,6 +22,8 @@ CheckpointedModelConfig, CheckpointedOptimizerConfig, CheckpointSavingConfig, + ClipLossConfig, + CLMCrossEntropyLossConfig, ConstantLRSchedulerConfig, CosineAnnealingLRSchedulerConfig, DistributedSamplerConfig, @@ -34,6 +36,7 @@ GPT2LLMCollateFnConfig, LLMDataLoaderConfig, MemMapDatasetConfig, + NCELossConfig, OneCycleLRSchedulerConfig, PackedMemMapDatasetContinuousConfig, PackedMemMapDatasetMegatronConfig, @@ -72,14 +75,7 @@ ProgressSubscriberFactory, ResultsSubscriberFactory, ) -from modalities.loss_functions import ( - ClipLoss, - ClipLossConfig, - CrossEntropyLoss, - CrossEntropyLossConfig, - NCELoss, - NCELossConfig, -) +from modalities.loss_functions import ClipLoss, CLMCrossEntropyLoss, NCELoss from modalities.models.coca.coca_model import CoCa, CoCaConfig from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig @@ -163,7 +159,7 @@ class ComponentEntity: ComposedModelInitializationConfig, ), # losses - ComponentEntity("loss", "cross_entropy_loss", CrossEntropyLoss, CrossEntropyLossConfig), + ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), ComponentEntity("loss", "nce_loss", NCELoss, NCELossConfig), ComponentEntity("loss", "clip_loss", ClipLoss, ClipLossConfig), # optmizers diff --git 
a/src/modalities/trainer.py b/src/modalities/trainer.py index 1128a6b95..e15e47c08 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Callable, List, Optional, Tuple +from typing import Callable, Optional, Tuple import torch import torch.distributed as dist @@ -87,7 +87,7 @@ def _train_batch( model: FSDP, optimizer: Optimizer, scheduler: LRScheduler, - loss_fun: List[Loss], + loss_fun: Loss, micro_batch_id: int, ) -> Tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]: """ @@ -111,23 +111,8 @@ def _train_batch( if a training step was performed otherwise return None. """ result_batch = model_predict_batch(model=model, batch=batch) - - total_loss = None - losses = [] - for lfn in loss_fun: - # Calculate loss - weighted_loss = lfn(result_batch) * lfn.weight - - # Add loss to total loss - if total_loss is None: - total_loss = weighted_loss - else: - total_loss += weighted_loss - - # Append individual losses (for logging) - losses.append(weighted_loss.clone().detach()) - - (total_loss / self.gradient_acc_steps).backward() + loss = loss_fun(result_batch) + (loss / self.gradient_acc_steps).backward() if (micro_batch_id + 1) % self.gradient_acc_steps == 0: gradient_norm_score = self.gradient_clipper.clip_gradients() @@ -142,7 +127,7 @@ def _train_batch( num_train_steps_done = Trainer._get_num_train_steps_done( micro_batch_id=micro_batch_id, gradient_acc_steps=self.gradient_acc_steps ) - return total_loss, *losses, step_performed, num_train_steps_done, gradient_norm_score + return step_performed, num_train_steps_done, loss, gradient_norm_score def train( self, @@ -172,7 +157,7 @@ def train( None """ model.train() - cumulated_losses = self._reset_tracked_losses(len(loss_fun)) + cumulated_losses = self._reset_tracked_losses() # throughput & MFU thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() @@ -207,9 +192,9 @@ def train( for _, (micro_batch_id, batch) in zip(range(num_batches_todo), enumerate(train_loader)): # Train single batch ( - *batch_losses, step_performed, num_train_steps_done, + batch_loss, gradient_norm_score, ) = self._train_batch( batch=batch, @@ -224,8 +209,7 @@ def train( training_progress.num_seen_tokens_current_run = self.global_num_tokens_per_train_step * num_train_steps_done # Save the batch loss - for i, batch_loss in enumerate(batch_losses): - cumulated_losses[i] += batch_loss.item() + cumulated_losses[0] += batch_loss.item() # This works, because we always drop the last batch in case it has less samples than the batch size cumulated_losses[-1] += 1 # number of local batches @@ -256,26 +240,24 @@ def train( synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time # TODO: insert reducer from outside so Trainer is independent of FSDP # add the loss and gradient norm for the LAST batch + cumulated_losses[1] = batch_loss.item() reduced_losses = Reducer.reduce( tensor=cumulated_losses, operation=dist.ReduceOp.SUM, # 1.) summed batch loss / (num batches * world size) # 2.) 
last batch loss / world size - post_processing_fun=lambda t: torch.cat([t[:-1] / t[-1], t[-1:] / dist.get_world_size()]), + post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]), ) train_loss_avg, train_loss_last_batch = ( reduced_losses[0], - reduced_losses[-1], + reduced_losses[1], ) - losses = { "train loss avg": ResultItem(train_loss_avg, decimal_places=2), "train loss last": ResultItem(train_loss_last_batch, decimal_places=2), } - for i, lfn in enumerate(loss_fun): - losses[lfn.tag] = ResultItem(reduced_losses[i + 1], decimal_places=2) consumed_tokens = torch.Tensor([training_progress.num_seen_tokens_total]) metrics = { @@ -310,7 +292,7 @@ def train( ) thoughput_aggregator.remove_keys() - cumulated_losses = self._reset_tracked_losses(len(loss_fun)) + cumulated_losses = self._reset_tracked_losses() if step_performed: evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) checkpointing_callback(training_progress=training_progress) @@ -318,11 +300,11 @@ def train( # via the dataloader. forward_backward_time_recorder.start() - def _reset_tracked_losses(self, num_loss_functions: int): + def _reset_tracked_losses(self): # Initializes and returns a tensor representing the cumulated loss and gradient norm. # The tensor is initialized with zeros and its device is set based on the availability of CUDA. - cumulated_loss_and_gradient_norm = torch.zeros(num_loss_functions + 1 + 1) + cumulated_loss_and_gradient_norm = torch.zeros(3) if torch.cuda.is_available(): cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to(torch.device("cuda")) else: From 3366bc95252892a702cba6708b8a9072b80db7aa Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Tue, 24 Sep 2024 12:37:41 +0000 Subject: [PATCH 139/161] docs: add docstrings and type hints for audio-related classes and functions --- src/modalities/dataloader/dataset.py | 93 +++++++++- .../audio_transformer_model.py | 171 +++++++++++++++++- 2 files changed, 251 insertions(+), 13 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index e157e7fcb..624937b4f 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -483,6 +483,21 @@ def __call__(self, text): class AudioTransformConfig(TransformConfig): + """ + Configuration class for the audio transformation module. + + This class defines various parameters that control the behavior of the AudioTransform. + These parameters include whether the module is in training mode, the number of mel-frequency bands, + lengths for frequency and time domain masking during training, and the target block size for audio encoding. + + Attributes: + is_training (bool): Whether the module is in training mode. Defaults to False. + n_mels (int): Number of mel-frequency bands. Defaults to 128. + freq_domain_mask_length (int): Length of frequency masking during training. Defaults to 30. + time_domain_mask_length (int): Length of time masking during training. Defaults to 100. + block_size_audio_encoder (int): The target block size for audio encoding. + """ + is_training: bool = False n_mels: int = 128 freq_domain_mask_length: int = 30 @@ -491,6 +506,13 @@ class AudioTransformConfig(TransformConfig): class AudioTransform(Transform): + """ + An audio transformation module that processes raw audio into mel-spectrogram features. 
+ + This module includes steps such as feature extraction, frequency and time domain masking during training, + padding to match a fixed block size, and returns the processed features along with their length. + """ + def __init__( self, block_size_audio_encoder: int, @@ -499,6 +521,19 @@ def __init__( freq_domain_mask_length: int = 30, time_domain_mask_length: int = 100, ): + """ + Initializes the AudioTransform class. + + Args: + block_size_audio_encoder (int): The target block size for audio encoding. + is_training (bool, optional): Whether the module is in training mode. Defaults to False. + n_mels (int, optional): Number of mel-frequency bands. Defaults to 128. + freq_domain_mask_length (int, optional): Length of frequency masking. Defaults to 30. + time_domain_mask_length (int, optional): Length of time masking. Defaults to 100. + + Returns: + tuple[torch.Tensor, int]: A tuple containing the processed audio features and their length. + """ self.block_size_audio_encoder = block_size_audio_encoder self.is_training = is_training self.n_mels = n_mels @@ -506,7 +541,17 @@ def __init__( self.time_domain_mask_length = time_domain_mask_length def __call__(self, raw_audio: tuple[torch.Tensor, int]) -> tuple[torch.Tensor, int]: - SUB_SAMPLING_FACTOR = 4 + """ + Processes the input raw audio into mel-spectrogram features. + + Args: + raw_audio (tuple[torch.Tensor, int]): A tuple containing the raw audio tensor and its sample rate. + + Returns: + tuple[torch.Tensor, int]: A tuple containing the processed audio features and their length. + """ + + SUB_SAMPLING_FACTOR = 4 # reduce the number of features (i.e., time frames) self.extract_features = torchaudio.transforms.MelSpectrogram(n_mels=self.n_mels) @@ -567,9 +612,25 @@ def __call__(self, video): return self.spatial_transform(video) -def decord_video(key, data): - """Based on the torch_video decoder in webdataset +def decord_video(key: str, data: bytes) -> None | tuple[torch.Tensor, Optional[torch.Tensor], int]: + """ + Based on the torch_video decoder in webdataset https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py#L394 + + Decode a video file using Decord and optionally extract audio. + + This function decodes a video file from the provided data. + It first checks if the file extension is one of the supported formats. + If an audio stream exists, it extracts the audio with a mean across channels (if there are multiple). + It then uses Decord to decode uniformly sampled frames from the video. + + Parameters: + key (str): The key or identifier for the video data. + data (bytes): The binary data of the video file. + + Returns: + tuple: A tuple containing the decoded video frames, audio tensor (if available), and audio sample rate. + If no audio stream exists, the audio tensor will be None. 
""" extension = re.sub(r".*[.]", "", key) if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): @@ -592,19 +653,34 @@ def decord_video(key, data): frame_ids = torch.linspace(0, len(vr) - 1, clip_num_frames, dtype=torch.int64) frames = vr.get_batch(frame_ids.tolist()) # T x H x W x C - return (frames, audio, audio_sample_rate) # audio can be None if no audio stream exists + return (frames, audio, audio_sample_rate) -def torch_audio(key, data): - """Based on the torch_audio decoder in webdataset +def torch_audio(key: str, data: bytes) -> None | tuple[torch.Tensor, int]: + """ + Based on the torch_audio decoder in webdataset https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py#L418 + + Decode an audio file using torchaudio. + + This function decodes an audio file from the provided data. + It first checks if the file extension is one of the supported formats. + If there are multiple channels in the audio file, it averages them to produce a mono audio tensor. + + Parameters: + key (str): The key or identifier for the audio data. + data (bytes): The binary data of the audio file. + + Returns: + tuple: A tuple containing the decoded audio tensor and its sample rate. If the file extension is not supported, + the function will return None. """ + extension = re.sub(r".*[.]", "", key) valid_extensions = "mp4 ogv mjpeg avi mov h264 mpg webm wmv flac mp3 sox wav m4a ogg wma".split() if extension not in valid_extensions: return None - # torchaudio.load returns (torch.Tensor, int) audio, sample_rate = torchaudio.load(data) if audio.shape[0] > 1: # more than one channel audio = torch.mean(audio, dim=0, keepdim=True) @@ -777,7 +853,8 @@ def _transform_video(self, sample): del sample[source_key] return sample - def _transform_audio(self, sample): + def _transform_audio(self, sample: dict): + # Apply audio transforms to the input sample. source_key, target_key = self.modality_key_mapping[ModalityEnum.AUDIO] transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] sample[target_key], sample["audio_len"] = transform(sample[source_key]) diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index 5852aa01d..435075845 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -9,6 +9,52 @@ class AudioTransformerConfig(BaseModel): + """ + Configuration for an audio transformer model using conformer blocks. + + This configuration class defines all necessary parameters to instantiate and configure an `AudioTransformer` model. + + Args: + sample_key (str): The key in the input dictionary that contains the audio samples. + prediction_key (str): The key under which the model's output will be stored in the output dictionary. + block_size (int): The size of each block for positional embeddings. Must be a positive integer. + n_mels (int): The number of mel-frequency bands used for input audio feature extraction. + Must be a positive integer. + n_embd (int): The embedding dimension used throughout the model. Must be a positive integer. + n_heads (int): The number of attention heads in the conformer blocks. Must be a positive integer. + n_conformer_blocks (int): The number of conformer blocks to include in the transformer model. + Must be a positive integer. + attention_config (AttentionConfig): Configuration object for attention mechanisms. 
+ pointwise_conv_kernel_size (int): Kernel size for the pointwise convolutional layers in conformer blocks. + Must be a positive integer. + depthwise_conv_kernel_size (int): Kernel size for the depthwise convolutional layers in conformer blocks. + Must be a positive integer. + ffmodule_dropout (float, optional): Dropout rate for feed-forward modules in conformer blocks. + Must be a float less than 1.0. Default is 0.1. + attn_dropout (float, optional): Dropout rate for attention mechanisms. Must be a float less than 1.0. + Default is 0.1. + convmodule_dropout (float, optional): Dropout rate for depthwise convolutional layers in conformer blocks. + Must be a float less than 1.0. Default is 0.1. + + Returns: + AudioTransformerConfig: A configuration object that can be used to instantiate an `AudioTransformer` model with\ + the specified parameters. + + Examples: + >>> audio_encoder_config = AudioTransformerConfig( + sample_key="audio", + prediction_key="audio_embeddings", + block_size=2_000, + n_mels=128, + n_embd=768, + n_heads=8, + n_conformer_blocks=2, + attention_config=AttentionConfig(attention_engine_type="default_attention"), + pointwise_conv_kernel_size=1, + depthwise_conv_kernel_size=31 + ) + """ + sample_key: str prediction_key: str block_size: Annotated[int, Field(ge=1)] @@ -25,13 +71,36 @@ class AudioTransformerConfig(BaseModel): class ConvolutionModule(nn.Module): + """ + A convolutional module designed to process sequences using a series of layers including LayerNorm, + pointwise convolutions, GLU activation, depthwise convolution, batch normalization, SiLU (Swish) activation, + and a final pointwise convolution. + """ + def __init__( self, n_embd: int, pointwise_conv_kernel_size: int, depthwise_conv_kernel_size: int, - dropout: int, + dropout: float, ): + """ + Initializes the ConvolutionModule class. + + Args: + n_embd (int): The number of embedding dimensions. Must be a positive integer. + pointwise_conv_kernel_size (int): The kernel size for both the first and second pointwise convolutions. + depthwise_conv_kernel_size (int): The kernel size for the depthwise convolution. + dropout (float): Dropout rate applied after each layer. Must be a float between 0 and 1. + + Examples: + >>> module = ConvolutionModule( + n_embd=768, + pointwise_conv_kernel_size=1, + depthwise_conv_kernel_size=31, + dropout=0.1 + ) + """ super().__init__() self.ln = nn.LayerNorm(n_embd) self.pointwise_conv_1 = nn.Conv1d( @@ -64,26 +133,54 @@ def forward( self, x: torch.Tensor, ) -> torch.Tensor: + """ + Forward pass through the convolutional module. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, D), where B is the batch size, + T is the number of time steps, and D is the embedding dimension. + + Returns: + torch.Tensor: Output tensor of shape (B, T, D). + """ x = self.ln(x) x = x.transpose(1, 2) x = self.glu(self.pointwise_conv_1(x)) x = self.swish(self.bn(self.depthwise_conv(x))) x = self.pointwise_conv_2(x) - return self.dropout(x.transpose(1, 2)) # shape: B, T, D + return self.dropout(x.transpose(1, 2)) class ConformerBlock(nn.Module): + """ + This block combines self-attention, feed-forward modules, and depthwise convolutional layers to provide + efficient processing of sequential data. + """ + def __init__( self, n_embd: int, n_heads: int, - attention_config, + attention_config: AttentionConfig, pointwise_conv_kernel_size: int, depthwise_conv_kernel_size: int, ffmodule_dropout: float, attn_dropout: float, convmodule_dropout: float, ) -> None: + """Initializes the ConformerBlock class. 
+ + Args: + n_embd (int): The number of expected features in the input. + n_heads (int): Number of parallel attention heads. + attention_config (AttentionConfig): Configuration for the attention mechanism, typically a dictionary or \ + class instance. + pointwise_conv_kernel_size (int): Kernel size of the depthwise convolutional layer. + depthwise_conv_kernel_size (int): The kernel size for the depthwise convolutional module. + ffmodule_dropout (float): Dropout rate for feed-forward modules. + attn_dropout (float): Dropout rate for attention mechanism. + convmodule_dropout (float): Dropout rate for the convolutional module. + """ super().__init__() self.ln1 = nn.LayerNorm(n_embd) @@ -123,7 +220,19 @@ def forward( x: torch.Tensor, mask: torch.Tensor, ) -> torch.Tensor: - x = self.ln1(x) # x.shape: B, T, D + """ + Forward pass through the conformer block. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, D), where B is the batch size, + T is the number of time steps, and D is the embedding dimension. + mask (torch.Tensor): Attention mask of shape (N, 1, L) or (N, L, L), where N is the batch size, + L is the sequence length. If not provided, no attention mask will be used. + + Returns: + torch.Tensor: Output tensor of shape (B, T, D). + """ + x = self.ln1(x) x = x + 0.5 * self.entry_ffmodule(x) x = x + self.attn(self.ln_mhsa(x), mask=mask) x = x + self.convmodule(x) @@ -133,6 +242,11 @@ def forward( class AudioTransformer(nn.Module): + """An audio transformer model using conformer blocks for processing audio data and generating predictions. + + This model includes convolutional layers, subsampling, positional embeddings, + and multiple conformer blocks for feature extraction and processing.""" + def __init__( self, *, @@ -150,6 +264,41 @@ def __init__( attn_dropout: float = 0.1, convmodule_dropout: float = 0.1, ): + """ + Initializes the AudioTransformer model. + + Args: + sample_key (str): The key in the input dictionary that contains the audio samples. + prediction_key (str): The key under which the model's output will be stored in the output dictionary. + block_size (int): The size of each block for positional embeddings. + n_mels (int): The number of mel-frequency bands used for input audio feature extraction. + n_embd (int): The embedding dimension used throughout the model. + n_heads (int): The number of attention heads in the conformer blocks. + n_conformer_blocks (int): The number of conformer blocks to include in the transformer model. + attention_config (AttentionConfig): Configuration object for attention mechanisms. + pointwise_conv_kernel_size (int): Kernel size for the pointwise convolutional layers in conformer blocks. + depthwise_conv_kernel_size (int): Kernel size for the depthwise convolutional layers in conformer blocks. + ffmodule_dropout (float): Dropout rate for feed-forward modules in conformer blocks. Default is 0.1. + attn_dropout (float): Dropout rate for attention mechanisms. Default is 0.1. + convmodule_dropout (float): Dropout rate for depthwise convolutional layers in conformer blocks. + Default is 0.1. 
+ + Examples: + >>> audio_encoder_config = { + "sample_key": "audio", + "prediction_key": "audio_embeddings", + "block_size": 2000, + "n_mels": 128, + "n_embd": 768, + "n_heads": 8, + "n_conformer_blocks": 2, + "attention_config": { + "attention_engine_type": "default_attention" + }, + "pointwise_conv_kernel_size": 1, + "depthwise_conv_kernel_size": 31 + } + """ super().__init__() self.sample_key = sample_key self.prediction_key = prediction_key @@ -196,6 +345,17 @@ def forward( self, inputs: dict[str, tuple[torch.Tensor, torch.Tensor]], ) -> dict[str, tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass of the AudioTransformer model. + + Args: + inputs (dict[str, tuple[torch.Tensor, torch.Tensor]]): A dictionary containing the input tensors. + It must include the key specified by `sample_key`. + + Returns: + dict[str, tuple[torch.Tensor, torch.Tensor]]: A dictionary with a single key specified by `prediction_key`,\ + containing the model's output. + """ x = inputs[self.sample_key] # x.shape: B, T, D attn_key_mask = self._get_attn_key_mask(inputs["audio_len"]) # x.shape: B, T, D @@ -212,7 +372,8 @@ def forward( def _get_attn_key_mask( self, lengths: torch.Tensor, - ): + ) -> torch.Tensor: + # Generates an attention key mask based on input sequence lengths. return ( torch.nn.utils.rnn.pad_sequence( [torch.ones(length, self.block_size) for length in lengths] From 61cfe51aa49eb87884949789d6580c4725b0b710 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Wed, 25 Sep 2024 08:42:45 +0200 Subject: [PATCH 140/161] test: update coca model test and add coca collator test --- tests/models/coca/coca_config_aud_vid.yaml | 68 +++++++++ tests/models/coca/coca_config_audio.yaml | 24 +-- tests/models/coca/coca_config_av.yaml | 59 -------- ...fig_vision.yaml => coca_config_image.yaml} | 20 ++- .../models/coca/coca_config_img_aud_vid.yaml | 87 +++++++++++ tests/models/coca/coca_config_video.yaml | 49 ++++++ tests/models/coca/test_coca.py | 121 +++++++++------ tests/models/coca/test_collator.py | 142 ++++++++++++++++++ .../vision_transformer_config2.yaml | 15 -- 9 files changed, 448 insertions(+), 137 deletions(-) create mode 100644 tests/models/coca/coca_config_aud_vid.yaml delete mode 100644 tests/models/coca/coca_config_av.yaml rename tests/models/coca/{coca_config_vision.yaml => coca_config_image.yaml} (72%) create mode 100644 tests/models/coca/coca_config_img_aud_vid.yaml create mode 100644 tests/models/coca/coca_config_video.yaml create mode 100644 tests/models/coca/test_collator.py delete mode 100644 tests/models/vision_transformer/vision_transformer_config2.yaml diff --git a/tests/models/coca/coca_config_aud_vid.yaml b/tests/models/coca/coca_config_aud_vid.yaml new file mode 100644 index 000000000..d6b14d305 --- /dev/null +++ b/tests/models/coca/coca_config_aud_vid.yaml @@ -0,0 +1,68 @@ +prediction_key: logits +audio_embd_prediction_key: audio_embeddings +video_embd_prediction_key: video_embeddings +text_embd_prediction_key: text_embeddings +audio_cls_prediction_key: audio_cls +audio_text_cls_prediction_key: audio_text_cls +video_cls_prediction_key: video_cls +video_text_cls_prediction_key: video_text_cls +text_cls_prediction_key: text_cls +modality_keys: + - images + - audio + - audio_len + - video + - input_ids +is_audio_video: true +individual_datasets: false +logit_scale_prediction_key: logit_scale +audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 500 + n_mels: 128 + n_embd: 768 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + 
attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 +video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 224 + n_classes: Null + n_layer: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: swiglu + epsilon: 1e-5 +n_pool_head: 12 +n_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 diff --git a/tests/models/coca/coca_config_audio.yaml b/tests/models/coca/coca_config_audio.yaml index 3ba12f113..0b28b8b5e 100644 --- a/tests/models/coca/coca_config_audio.yaml +++ b/tests/models/coca/coca_config_audio.yaml @@ -1,12 +1,18 @@ prediction_key: logits -modality_key: modality -modality_embd_prediction_key: modality_embeddings +audio_embd_prediction_key: audio_embeddings text_embd_prediction_key: text_embeddings -modality_cls_prediction_key: modality_cls +audio_cls_prediction_key: audio_cls text_cls_prediction_key: text_cls -modality_encoder_config: +modality_keys: + - audio + - audio_len + - input_ids +is_audio_video: false +individual_datasets: false +logit_scale_prediction_key: logit_scale +audio_encoder_config: sample_key: audio - prediction_key: modality_embeddings + prediction_key: audio_embeddings block_size: 500 n_mels: 128 n_embd: 768 @@ -30,13 +36,9 @@ text_decoder_config: n_embd: 768 dropout: 0.0 bias: true - activation: fused_swiglu + activation: swiglu epsilon: 1e-5 n_pool_head: 8 -n_vision_queries: Null -n_audio_queries: 256 +n_queries: 256 bias_attn_pool: False epsilon_attn_pool: 1e-5 -weight_init: - mean: 0.0 - std: 0.02 \ No newline at end of file diff --git a/tests/models/coca/coca_config_av.yaml b/tests/models/coca/coca_config_av.yaml deleted file mode 100644 index 466277a9e..000000000 --- a/tests/models/coca/coca_config_av.yaml +++ /dev/null @@ -1,59 +0,0 @@ -prediction_key: logits -modality_key: modality -modality_embd_prediction_key: modality_embeddings -text_embd_prediction_key: text_embeddings -modality_cls_prediction_key: modality_cls -text_cls_prediction_key: text_cls -modality_encoder_config: - vision_transformer_config: - sample_key: images - prediction_key: modality_embeddings - img_size: 224 - n_classes: Null # Disable vision transformer head - n_layer: 12 - attention_config: - attention_engine_type: default_attention - n_head: 12 - n_embd: 768 - dropout: 0.0 - patch_size: 16 - patch_stride: 16 - n_img_channels: 3 - add_cls_token: False - bias: True - audio_transformer_config: - sample_key: audio - prediction_key: modality_embeddings - block_size: 500 - n_mels: 128 - n_embd: 768 - n_heads: 4 - n_conformer_blocks: 3 - attention_config: - attention_engine_type: default_attention - pointwise_conv_kernel_size: 1 - depthwise_conv_kernel_size: 31 -text_decoder_config: - sample_key: input_ids - prediction_key: text_embeddings - block_size: 1_024 - vocab_size: 50_304 - n_layer_text: 6 - n_layer_multimodal_text: 6 - attention_config: - attention_engine_type: pytorch_flash_attention - n_head: 12 - ffn_hidden: 2_048 - n_embd: 768 - dropout: 0.0 - bias: true - activation: 
fused_swiglu - epsilon: 1e-5 -n_pool_head: 8 -n_vision_queries: 256 -n_audio_queries: 256 -bias_attn_pool: False -epsilon_attn_pool: 1e-5 -weight_init: - mean: 0.0 - std: 0.02 \ No newline at end of file diff --git a/tests/models/coca/coca_config_vision.yaml b/tests/models/coca/coca_config_image.yaml similarity index 72% rename from tests/models/coca/coca_config_vision.yaml rename to tests/models/coca/coca_config_image.yaml index 38b0864ec..21d6318f4 100644 --- a/tests/models/coca/coca_config_vision.yaml +++ b/tests/models/coca/coca_config_image.yaml @@ -1,12 +1,17 @@ prediction_key: logits -modality_key: modality -modality_embd_prediction_key: modality_embeddings +image_embd_prediction_key: image_embeddings text_embd_prediction_key: text_embeddings -modality_cls_prediction_key: modality_cls +image_cls_prediction_key: image_cls text_cls_prediction_key: text_cls -modality_encoder_config: +modality_keys: + - images + - input_ids +is_audio_video: false +individual_datasets: false +logit_scale_prediction_key: logit_scale +image_encoder_config: sample_key: images - prediction_key: modality_embeddings + prediction_key: image_embeddings img_size: 224 n_classes: Null # Disable vision transformer head n_layer: 6 @@ -37,7 +42,6 @@ text_decoder_config: activation: swiglu epsilon: 1e-5 n_pool_head: 8 -n_vision_queries: 256 -n_audio_queries: Null +n_queries: 256 bias_attn_pool: False -epsilon_attn_pool: 1e-5 \ No newline at end of file +epsilon_attn_pool: 1e-5 diff --git a/tests/models/coca/coca_config_img_aud_vid.yaml b/tests/models/coca/coca_config_img_aud_vid.yaml new file mode 100644 index 000000000..bcb2b9d51 --- /dev/null +++ b/tests/models/coca/coca_config_img_aud_vid.yaml @@ -0,0 +1,87 @@ +prediction_key: logits +audio_embd_prediction_key: audio_embeddings +image_embd_prediction_key: image_embeddings +video_embd_prediction_key: video_embeddings +text_embd_prediction_key: text_embeddings +image_cls_prediction_key: image_cls +image_text_cls_prediction_key: image_text_cls +audio_cls_prediction_key: audio_cls +audio_text_cls_prediction_key: audio_text_cls +video_cls_prediction_key: video_cls +video_text_cls_prediction_key: video_text_cls +text_cls_prediction_key: text_cls +modality_keys: + - images + - audio + - audio_len + - video + - input_ids +is_audio_video: false +individual_datasets: true +logit_scale_prediction_key: logit_scale +audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 500 + n_mels: 128 + n_embd: 768 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 +image_encoder_config: + sample_key: images + prediction_key: image_embeddings + img_size: 224 + n_classes: Null # Disable vision transformer head + n_layer: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True +video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 224 + n_classes: Null + n_layer: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + 
attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: swiglu + epsilon: 1e-5 +n_pool_head: 12 +n_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 diff --git a/tests/models/coca/coca_config_video.yaml b/tests/models/coca/coca_config_video.yaml new file mode 100644 index 000000000..aa2b45576 --- /dev/null +++ b/tests/models/coca/coca_config_video.yaml @@ -0,0 +1,49 @@ +prediction_key: logits +video_embd_prediction_key: video_embeddings +text_embd_prediction_key: text_embeddings +video_cls_prediction_key: video_cls +text_cls_prediction_key: text_cls +modality_keys: + - video + - input_ids +is_audio_video: false +individual_datasets: false +logit_scale_prediction_key: logit_scale +video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 224 + n_classes: Null # Disable vision transformer head + n_layer: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: swiglu + epsilon: 1e-5 +n_pool_head: 8 +n_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 diff --git a/tests/models/coca/test_coca.py b/tests/models/coca/test_coca.py index 9a85e6697..5d2d523b8 100644 --- a/tests/models/coca/test_coca.py +++ b/tests/models/coca/test_coca.py @@ -21,42 +21,85 @@ N_IMAGE_CLASSES = 1_000 IMG_SIZE = 224 N_IMG_CHANNELS = 3 +N_FRAMES = 16 # audio_transformer_config AUDIO_BLOCK_SIZE = 500 N_MELS = 128 SUB_SAMPLING_FACTOR = 4 +BATCH_SIZE = 2 + def dummy_image_sample(): - input_image = torch.randn(1, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) - input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (1, TEXT_DECODER_BLOCK_SIZE)) - VISION = torch.tensor([1]) + input_image = torch.randn(BATCH_SIZE, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE, TEXT_DECODER_BLOCK_SIZE)) return dict( images=input_image, input_ids=input_text, - modality=VISION, + ) + + +def dummy_video_sample(): + input_video = torch.randn(BATCH_SIZE, N_FRAMES, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE, TEXT_DECODER_BLOCK_SIZE)) + return dict( + video=input_video, + input_ids=input_text, ) def dummy_audio_sample(): - audio_features = torch.randn(1, AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) + audio_features = torch.randn(BATCH_SIZE, AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) audio_len = torch.tensor([N_IMAGE_CLASSES / SUB_SAMPLING_FACTOR]).type(torch.int16) - input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (1, TEXT_DECODER_BLOCK_SIZE)) - AUDIO = torch.tensor([0]) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE, TEXT_DECODER_BLOCK_SIZE)) return dict( audio=audio_features, - feats_len=audio_len, + audio_len=audio_len, + input_ids=input_text, + ) + + +def dummy_img_aud_vid_sample(): + # separate image, audio, and video datasets + input_image = torch.randn(BATCH_SIZE, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + audio_features = torch.randn(BATCH_SIZE, 
AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) + audio_len = torch.tensor([N_IMAGE_CLASSES / SUB_SAMPLING_FACTOR]).type(torch.int16) + input_video = torch.randn(BATCH_SIZE, N_FRAMES, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE * 3, TEXT_DECODER_BLOCK_SIZE)) + return dict( + images=input_image, + audio=audio_features, + audio_len=audio_len, + video=input_video, + input_ids=input_text, + ) + + +def dummy_aud_vid_sample(): + # single video dataset which contains audio + audio_features = torch.randn(BATCH_SIZE, AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) + audio_len = torch.tensor([N_IMAGE_CLASSES / SUB_SAMPLING_FACTOR]).type(torch.int16) + input_video = torch.randn(BATCH_SIZE, N_FRAMES, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE, TEXT_DECODER_BLOCK_SIZE)) + return dict( + audio=audio_features, + audio_len=audio_len, + video=input_video, input_ids=input_text, - modality=AUDIO, ) @pytest.mark.parametrize( "yaml,dummy_sample", [ - ("tests/models/coca/coca_config_vision.yaml", dummy_image_sample()), + ("tests/models/coca/coca_config_image.yaml", dummy_image_sample()), ("tests/models/coca/coca_config_audio.yaml", dummy_audio_sample()), + ("tests/models/coca/coca_config_video.yaml", dummy_video_sample()), + ("tests/models/coca/coca_config_img_aud_vid.yaml", dummy_img_aud_vid_sample()), + ("tests/models/coca/coca_config_aud_vid.yaml", dummy_aud_vid_sample()), ], ) def test_coca(yaml, dummy_sample): @@ -77,40 +120,30 @@ def test_coca(yaml, dummy_sample): optimizer.step() # Test outputs - assert "logits" in out - assert "modality_cls" in out - assert "text_cls" in out - assert out["logits"].shape == (1, TEXT_DECODER_BLOCK_SIZE, TEXT_DECODER_VOCAB_SIZE) - assert out["modality_cls"].shape == (1, N_EMBD) - assert out["text_cls"].shape == (1, N_EMBD) - - -def test_coca_audio_vision_together(): - # Create model - config_file_path = _ROOT_DIR / Path("coca/coca_config_av.yaml") - config_dict = load_app_config_dict(config_file_path=config_file_path) - coca_config = CoCaConfig.model_validate(config_dict) - model = CoCa(**dict(coca_config)) - - # Create optimizer - optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) - - audio_sample = dummy_audio_sample() - image_sample = dummy_image_sample() - - for dummy_samples in [audio_sample, image_sample]: - optimizer.zero_grad() - out = model(dummy_samples) - loss = out["logits"].sum() - loss.backward() - optimizer.step() - - assert "logits" in out - assert "modality_cls" in out - assert "text_cls" in out - assert out["logits"].shape == (1, TEXT_DECODER_BLOCK_SIZE, TEXT_DECODER_VOCAB_SIZE) - assert out["modality_cls"].shape == (1, N_EMBD) - assert out["text_cls"].shape == (1, N_EMBD) + text_output_batch_size = 0 + if coca_config.audio_encoder_config: + assert "audio_cls" in out + assert out["audio_cls"].shape == (BATCH_SIZE, N_EMBD) + if coca_config.individual_datasets: + assert out["audio_text_cls"].shape == (BATCH_SIZE, N_EMBD) + if not coca_config.is_audio_video: + text_output_batch_size += BATCH_SIZE + if coca_config.image_encoder_config: + assert "image_cls" in out + assert out["image_cls"].shape == (BATCH_SIZE, N_EMBD) + if coca_config.individual_datasets: + assert out["image_text_cls"].shape == (BATCH_SIZE, N_EMBD) + text_output_batch_size += BATCH_SIZE + if coca_config.video_encoder_config: + assert "video_cls" in out + assert out["video_cls"].shape == (BATCH_SIZE, N_EMBD) + if coca_config.individual_datasets: + 
assert out["video_text_cls"].shape == (BATCH_SIZE, N_EMBD) + text_output_batch_size += BATCH_SIZE + if not coca_config.individual_datasets: + assert out["text_cls"].shape == (BATCH_SIZE, N_EMBD) + assert out["logits"].shape == (text_output_batch_size, TEXT_DECODER_BLOCK_SIZE, TEXT_DECODER_VOCAB_SIZE) + assert "logit_scale" in out @pytest.mark.skip( diff --git a/tests/models/coca/test_collator.py b/tests/models/coca/test_collator.py new file mode 100644 index 000000000..af342c9d6 --- /dev/null +++ b/tests/models/coca/test_collator.py @@ -0,0 +1,142 @@ +import pytest +import torch + +from modalities.models.coca.collator import CoCaCollatorFn + +# shared config +N_EMBD = 768 + +# text_decoder_config +TEXT_DECODER_VOCAB_SIZE = 50_304 +TEXT_DECODER_BLOCK_SIZE = 100 + +# vision_transformer_config +N_IMAGE_CLASSES = 1_000 +IMG_SIZE = 224 +N_IMG_CHANNELS = 3 +N_FRAMES = 16 + +# audio_transformer_config +AUDIO_BLOCK_SIZE = 500 +N_MELS = 128 +SUB_SAMPLING_FACTOR = 4 + + +def dummy_image_sample(): + input_image = torch.randn(N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (TEXT_DECODER_BLOCK_SIZE,)) + attn_mask = torch.randint(0, 2, (TEXT_DECODER_BLOCK_SIZE,)) + return dict( + images=input_image, + input_ids=input_text, + attention_mask=attn_mask, + ) + + +def dummy_video_sample(): + input_video = torch.randn(N_FRAMES, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (TEXT_DECODER_BLOCK_SIZE,)) + attn_mask = torch.randint(0, 2, (TEXT_DECODER_BLOCK_SIZE,)) + return dict( + video=input_video, + input_ids=input_text, + attention_mask=attn_mask, + ) + + +def dummy_audio_sample(): + audio_features = torch.randn(AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) + audio_len = torch.tensor([N_IMAGE_CLASSES / SUB_SAMPLING_FACTOR]).type(torch.int16) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (TEXT_DECODER_BLOCK_SIZE,)) + attn_mask = torch.randint(0, 2, (TEXT_DECODER_BLOCK_SIZE,)) + return dict( + audio=audio_features, + audio_len=audio_len, + input_ids=input_text, + attention_mask=attn_mask, + ) + + +@pytest.mark.parametrize( + "modality_sequence", + [ + ("iiiii"), + ("aaaaa"), + ("vvvvv"), + ("iiaav"), + ("iaiav"), + ("iviaa"), + ("iaiavaivaiiiiaaaviaa"), + ], +) +def test_collator(modality_sequence): + sample_keys = ["input_ids"] + target_keys = [] + text_sample_key = "input_ids" + text_target_key = "target_ids" + + num_image = modality_sequence.count("i") + num_audio = modality_sequence.count("a") + num_video = modality_sequence.count("v") + + # sample_keys in the order: images, audio, video + if num_image: + sample_keys.append("images") + if num_audio: + sample_keys.append("audio") + sample_keys.append("audio_len") + if num_video: + sample_keys.append("video") + + # create samples + image_samples = [] + for idx in range(num_image): + image_samples.append(dummy_image_sample()) + audio_samples = [] + for idx in range(num_audio): + audio_samples.append(dummy_audio_sample()) + + video_samples = [] + for idx in range(num_video): + video_samples.append(dummy_video_sample()) + + modality_samples = {"images": image_samples, "audio": audio_samples, "video": video_samples} + + collate_fn = CoCaCollatorFn(sample_keys, target_keys, text_sample_key, text_target_key) + + batch = [] + image_idx = 0 + video_idx = 0 + audio_idx = 0 + # create the batch according to the specified modality sequence + for ch in modality_sequence: + if ch == "i": + batch.append(image_samples[image_idx]) + image_idx += 1 + if ch == "a": + 
batch.append(audio_samples[audio_idx]) + audio_idx += 1 + if ch == "v": + batch.append(video_samples[video_idx]) + video_idx += 1 + + dataset_batch = collate_fn(batch) + + batch_idx = 0 + + # regardless of the order of the modality sequence, + # the batch (esp. input_ids and target_ids) should be in the same order as sample_keys + # i.e. batch.samples['input_ids'] = [*image input_ids, *audio_input_ids, *video_input_ids] + for modality_key in sample_keys: + if modality_key in ["audio_len", "input_ids"]: + continue + if modality_key in dataset_batch.samples: + for modality_idx, gt_sample in enumerate(modality_samples[modality_key]): + assert torch.equal(gt_sample[modality_key], dataset_batch.samples[modality_key][modality_idx]) + assert torch.equal(gt_sample["input_ids"][:-1], dataset_batch.samples[text_sample_key][batch_idx]) + assert torch.equal(gt_sample["input_ids"][1:], dataset_batch.targets[text_target_key][batch_idx]) + assert torch.equal(gt_sample["attention_mask"][:-1], dataset_batch.samples["attention_mask"][batch_idx]) + assert torch.equal(gt_sample["attention_mask"][1:], dataset_batch.targets["attention_mask"][batch_idx]) + if modality_key == "audio": + assert torch.equal(gt_sample["audio_len"], dataset_batch.samples["audio_len"][modality_idx]) + batch_idx += 1 diff --git a/tests/models/vision_transformer/vision_transformer_config2.yaml b/tests/models/vision_transformer/vision_transformer_config2.yaml deleted file mode 100644 index 7951ec737..000000000 --- a/tests/models/vision_transformer/vision_transformer_config2.yaml +++ /dev/null @@ -1,15 +0,0 @@ -sample_key: videos -prediction_key: logits -img_size: 224 -n_classes: 1000 -n_layer: 6 -n_head: 8 -n_embd: 768 -dropout: 0.0 -patch_size: 16 -patch_stride: 16 -n_img_channels: 3 -add_cls_token: True -bias: True -num_video_frames: 16 -n_latents: 64 From 56ea087f7dd0e2327fefad4fc9acf589fedfb74b Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Wed, 25 Sep 2024 08:44:35 +0200 Subject: [PATCH 141/161] refactor: verify correctness of coca model config --- src/modalities/models/coca/coca_model.py | 27 ++++++++++++++++++------ src/modalities/models/coca/collator.py | 7 +----- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index d6ca7debf..dd8697481 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -75,8 +75,8 @@ class CoCaConfig(BaseModel): prediction_key: str = "logits" text_embd_prediction_key: str - text_cls_prediction_key: str logit_scale_prediction_key: str + text_cls_prediction_key: Optional[str] = None audio_embd_prediction_key: Optional[str] = None image_embd_prediction_key: Optional[str] = None video_embd_prediction_key: Optional[str] = None @@ -115,8 +115,8 @@ def __init__( self, prediction_key: str, text_embd_prediction_key: str, - text_cls_prediction_key: str, logit_scale_prediction_key: str, + text_cls_prediction_key: Optional[str], audio_embd_prediction_key: Optional[str], image_embd_prediction_key: Optional[str], video_embd_prediction_key: Optional[str], @@ -161,14 +161,27 @@ def __init__( None """ weight_decay_groups = { - "linear": ["attention", "\.attn", "\.cross_attn", "\.post_subsampler", "_ffmodule", "mlp"], - "conv": ["embedding_fn\.conv", "project", "\.subsampler", "pointwise_conv", "depthwise_conv"], - "embedding": ["wte", "wpe", "positional_embedding", "time_embd"], - "norm": ["norm", "\.ln_", "\.ln", "\.bn", "exit_ln"], - "parameter": ["_queries", 
"logit_scale", "\.latents", "cls_token"], + "linear": [r"attention", r"\.attn", r"\.cross_attn", r"\.post_subsampler", r"_ffmodule", r"mlp"], + "conv": [r"embedding_fn\.conv", r"project", r"\.subsampler", r"pointwise_conv", r"depthwise_conv"], + "embedding": [r"wte", r"wpe", r"positional_embedding", r"time_embd"], + "norm": [r"norm", r"\.ln_", r"\.ln", r"\.bn", r"exit_ln"], + "parameter": [r"_queries", r"logit_scale", r"\.latents", r"cls_token"], } super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) + if individual_datasets: + if ( + not audio_text_cls_prediction_key + and not image_text_cls_prediction_key + and not video_text_cls_prediction_key + ): + raise ValueError("All text_cls_prediction_keys cannot be None") + else: + if not text_cls_prediction_key: + raise ValueError("text_cls_prediction key cannot be None") + if not audio_encoder_config and not image_encoder_config and not video_encoder_config: + raise ValueError("Atleast one modality encoder config should be specified") + self.prediction_key = prediction_key self.text_embd_prediction_key = text_embd_prediction_key self.logit_scale_prediction_key = logit_scale_prediction_key diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index d77b7f481..c476c6831 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -107,12 +107,7 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: samples[self.text_sample_key] = torch.cat([text_samples[sample_key] for sample_key in text_samples]) samples["attention_mask"] = torch.cat([attention_masks[sample_key] for sample_key in attention_masks]) - ## TODO: this will not work when there is data from multiple datasets per batch - targets = { - target_key: torch.stack([self._prepare_sample(d[target_key]) for d in batch]) - for target_key in self.target_keys - } - + targets = {} # Create target for text input targets[self.text_target_key] = samples[self.text_sample_key][:, 1:].clone().detach() samples[self.text_sample_key] = samples[self.text_sample_key][:, :-1] From 8f5aea4462513d976907584ac865bce72818e553 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Tue, 24 Sep 2024 13:49:37 +0200 Subject: [PATCH 142/161] feat: multiple loss functions --- .../config_coca_img_aud_vid_dataset.yaml | 32 +++++----- src/modalities/config/config.py | 6 ++ src/modalities/loss_functions.py | 59 +++++++++++++++++++ src/modalities/registry/components.py | 4 +- src/modalities/trainer.py | 21 ++++++- 5 files changed, 106 insertions(+), 16 deletions(-) diff --git a/config_files/training/config_coca_img_aud_vid_dataset.yaml b/config_files/training/config_coca_img_aud_vid_dataset.yaml index 59b6c0dd9..6661ea57a 100644 --- a/config_files/training/config_coca_img_aud_vid_dataset.yaml +++ b/config_files/training/config_coca_img_aud_vid_dataset.yaml @@ -326,12 +326,10 @@ checkpoint_saving: captioning_loss: component_key: loss - variant_key: cross_entropy_loss + variant_key: clm_cross_entropy_loss config: target_key: ${settings.referencing_keys.target_key} prediction_key: ${model_raw.config.prediction_key} - tag: captioning_loss - weight: 2.0 contrastive_loss_audio: component_key: loss @@ -342,7 +340,6 @@ contrastive_loss_audio: - ${model_raw.config.audio_text_cls_prediction_key} logit_scale_key: ${model_raw.config.logit_scale_prediction_key} tag: contrastive_loss_audio - weight: 1.0 contrastive_loss_image: component_key: loss @@ -353,7 +350,6 @@ contrastive_loss_image: - 
${model_raw.config.image_text_cls_prediction_key} logit_scale_key: ${model_raw.config.logit_scale_prediction_key} tag: contrastive_loss_image - weight: 1.0 contrastive_loss_video: component_key: loss @@ -364,17 +360,25 @@ contrastive_loss_video: - ${model_raw.config.video_text_cls_prediction_key} logit_scale_key: ${model_raw.config.logit_scale_prediction_key} tag: contrastive_loss_image - weight: 1.0 loss_fn: - - instance_key: captioning_loss - pass_type: BY_REFERENCE - - instance_key: contrastive_loss_audio - pass_type: BY_REFERENCE - - instance_key: contrastive_loss_image - pass_type: BY_REFERENCE - - instance_key: contrastive_loss_video - pass_type: BY_REFERENCE + component_key: loss + variant_key: multiple_functions_loss + config: + losses: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss_audio + pass_type: BY_REFERENCE + - instance_key: contrastive_loss_image + pass_type: BY_REFERENCE + - instance_key: contrastive_loss_video + pass_type: BY_REFERENCE + corrsp_weights: + - 2.0 + - 1.0 + - 1.0 + - 1.0 wrapped_model: component_key: model diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 92e24ee17..24ce369ab 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -81,6 +81,12 @@ class ClipLossConfig(BaseModel): local_loss: bool = True tag: str = "ClipLoss" + +class MultipleFunctionsLossConfig(BaseModel): + losses: list + corrsp_weights: list + + # Checkpointing class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): k: PositiveInt diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index d9d2d244e..36d6c5ee2 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -26,6 +26,65 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: raise NotImplementedError +class MultipleFunctionsLoss(Loss): + """Loss objects of this type use more + than one loss function and weights corresponding + to the losses to compute total loss. + """ + + def __init__( + self, + losses: list[Loss], + corrsp_weights: list[float], + tag: str = "MultipleFunctionsLoss", + ) -> None: + """MultipleFunctionsLoss Constructor + + Args: + losses (list): Initialized losses. This list should contain more than one loss. + corrsp_weights (list): Weights to be multiplied to each loss while summing up. + + Returns: + None + """ + super().__init__(tag) + + if len(losses) <= 1: + raise ValueError("Number of losses used should be more than 1.") + + self.groups = [(loss_func, weight) for loss_func, weight in zip(losses, corrsp_weights, strict=True)] + + self.cumulated_individual_losses = None + # variable storing each loss, + # summed over local batches, + # separately. + + self.reset_cumulated_individual_losses() + + def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: + device = forward_batch.predictions[list(forward_batch.predictions.keys())[0]].device + total_loss = torch.tensor(0, dtype=torch.float, device=device) + for ind, (loss_func, weight) in enumerate(self.groups): + loss = loss_func(forward_batch) + self.cumulated_individual_losses[ind] += loss + total_loss += weight * loss + return total_loss + + def reset_cumulated_individual_losses( + self, + ) -> None: + """Initializes and resets the variable + accumulating each loss separately. + + Called first when the class is initialized, and then + after every logging step in trainer.py. 
+ """ + if torch.cuda.is_available(): + self.cumulated_individual_losses = torch.zeros(len(self.groups)).to(torch.device("cuda")) + else: + self.cumulated_individual_losses = torch.zeros(len(self.groups)).to("cpu") + + class CLMCrossEntropyLoss(Loss): def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEntropyLoss"): super().__init__(tag) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 8c48ceb1d..4645abd48 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -36,6 +36,7 @@ GPT2LLMCollateFnConfig, LLMDataLoaderConfig, MemMapDatasetConfig, + MultipleFunctionsLossConfig, NCELossConfig, OneCycleLRSchedulerConfig, PackedMemMapDatasetContinuousConfig, @@ -75,7 +76,7 @@ ProgressSubscriberFactory, ResultsSubscriberFactory, ) -from modalities.loss_functions import ClipLoss, CLMCrossEntropyLoss, NCELoss +from modalities.loss_functions import ClipLoss, CLMCrossEntropyLoss, MultipleFunctionsLoss, NCELoss from modalities.models.coca.coca_model import CoCa, CoCaConfig from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig @@ -162,6 +163,7 @@ class ComponentEntity: ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), ComponentEntity("loss", "nce_loss", NCELoss, NCELossConfig), ComponentEntity("loss", "clip_loss", ClipLoss, ClipLossConfig), + ComponentEntity("loss", "multiple_functions_loss", MultipleFunctionsLoss, MultipleFunctionsLossConfig), # optmizers ComponentEntity("optimizer", "adam", OptimizerFactory.get_adam, AdamOptimizerConfig), ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig), diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index e15e47c08..798b569a0 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -12,7 +12,7 @@ from modalities.dataloader.dataloader import LLMDataLoader from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate from modalities.logging_broker.publisher import MessagePublisher -from modalities.loss_functions import Loss +from modalities.loss_functions import Loss, MultipleFunctionsLoss from modalities.models.model import model_predict_batch from modalities.running_env.fsdp.reducer import Reducer from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -259,6 +259,25 @@ def train( "train loss last": ResultItem(train_loss_last_batch, decimal_places=2), } + # If there are multiple loss functions being used, + # this block computes and logs all the individual + # losses, averaged over the global batch size. 
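+            # (The total loss optimized in the training step is the weighted sum of the
+            # individual losses, e.g. w1 * loss1 + w2 * loss2, whereas the values logged
+            # below are the unweighted per-loss sums reduced across all ranks and then
+            # averaged over the global batch size.)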
+ if isinstance(loss_fun, MultipleFunctionsLoss): + global_batch_size = Reducer.reduce( + tensor=cumulated_losses[-1], operation=dist.ReduceOp.SUM, post_processing_fun=None + ) + reduced_individual_losses = Reducer.reduce( + tensor=loss_fun.cumulated_individual_losses, + operation=dist.ReduceOp.SUM, + post_processing_fun=lambda t: torch.stack( + [t[ind] / global_batch_size for ind in range(len(t))] + ), + ) + for ind, (loss, _) in enumerate(loss_fun.groups): + losses[f"train {loss.tag} avg"] = ResultItem(reduced_individual_losses[ind], decimal_places=2) + + loss_fun.reset_cumulated_individual_losses() + consumed_tokens = torch.Tensor([training_progress.num_seen_tokens_total]) metrics = { "consumed tokens": ResultItem(consumed_tokens, 0), From c291fcca8fbb51ca097568033c633698e89c56fc Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Fri, 27 Sep 2024 15:53:55 +0200 Subject: [PATCH 143/161] test: add more tests for loss functions --- tests/test_loss_functions.py | 102 ++++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/tests/test_loss_functions.py b/tests/test_loss_functions.py index 8825f15c3..86b6c0d7c 100644 --- a/tests/test_loss_functions.py +++ b/tests/test_loss_functions.py @@ -2,7 +2,7 @@ import torch from modalities.batch import InferenceResultBatch -from modalities.loss_functions import NCELoss, nce_loss +from modalities.loss_functions import ClipLoss, CLMCrossEntropyLoss, MultipleFunctionsLoss, NCELoss, nce_loss @pytest.fixture @@ -36,3 +36,103 @@ def test_nce_loss_correctness(embedding1, embedding2): bidirectional_loss = nce_loss(embedding1, embedding2, device="cpu", is_asymmetric=False, temperature=1.0) assert unidirectional_loss == pytest.approx(1.1300, 0.0001) assert bidirectional_loss == pytest.approx(2.2577, 0.0001) + + +@pytest.fixture +def clm_cross_entropy_loss_object() -> CLMCrossEntropyLoss: + return CLMCrossEntropyLoss(target_key="target_ids", prediction_key="logits") + + +@pytest.fixture +def clip_loss_object() -> ClipLoss: + return ClipLoss( + logit_scale_key="logit_scale", + prediction_keys=["image_cls", "image_text_cls"], + local_loss=False, + ) + + +@pytest.fixture +def clip_loss_forward_batch() -> InferenceResultBatch: + # BATCH SIZE, LENGTH OF SEQUENCE, EMBEDDING SIZE + predictions = { + "image_cls": torch.Tensor([[1, 2, 3], [4, 5, 6]]).to("cuda"), + "image_text_cls": torch.Tensor([[7, 8, 9], [10, 11, 12]]).to("cuda"), + "logit_scale": 0.07, + } + return InferenceResultBatch(targets={}, predictions=predictions) + + +@pytest.fixture +def setup_distributed(monkeypatch): + import torch.distributed as dist + + monkeypatch.setenv("RANK", "0") + monkeypatch.setenv("LOCAL_RANK", "0") + monkeypatch.setenv("WORLD_SIZE", "1") + monkeypatch.setenv("MASTER_ADDR", "localhost") + monkeypatch.setenv("MASTER_PORT", "9948") + + dist.init_process_group(backend="nccl") + yield + dist.destroy_process_group() + + +def test_clip_loss(clip_loss_object, clip_loss_forward_batch, setup_distributed): + + loss_fn = clip_loss_object + forward_batch = clip_loss_forward_batch + loss_fn(clip_loss_forward_batch) + + +@pytest.fixture +def multiple_functions_loss_object_with_two_losses( + clm_cross_entropy_loss_object, clip_loss_object +) -> MultipleFunctionsLoss: + return MultipleFunctionsLoss( + [clm_cross_entropy_loss_object, clip_loss_object], + corrsp_weights=[1.0, 1.0], + ) + + +def test_multiple_functions_loss_initialized_with_single_loss( + clm_cross_entropy_loss_object, +): + with pytest.raises(ValueError, match="Number of losses used should be 
more than 1."): + MultipleFunctionsLoss([clm_cross_entropy_loss_object], corrsp_weights=[1.0]) + + +def test_multiple_functions_loss_reset_cumulated_individual_losses( + multiple_functions_loss_object_with_two_losses, +): + + loss = multiple_functions_loss_object_with_two_losses + num_losses = len(loss.groups) + loss.cumulated_individual_losses = torch.randn(num_losses) + loss.reset_cumulated_individual_losses() + + assert (loss.cumulated_individual_losses, torch.zeros(num_losses)) + + +@pytest.fixture +def multiple_functions_loss_forward_batch() -> InferenceResultBatch: + + targets = {"target_ids": torch.Tensor([[1, 2, 1], [1, 1, 2]])} + predictions = { + "image_cls": torch.Tensor([[1, 2, 3], [4, 5, 6]]).to("cuda"), + "image_text_cls": torch.Tensor([[7, 8, 9], [10, 11, 12]]).to("cuda"), + "logit_scale": 0.07, + "logits": torch.Tensor( + [[[0.1, 0.2, 0.7], [0.3, 0.2, 0.5], [0.0, 0.3, 0.7]], [[0.1, 0.2, 0.7], [0.3, 0.2, 0.5], [0.0, 0.3, 0.7]]] + ), + } + + return InferenceResultBatch(targets=targets, predictions=predictions) + + +def test_multiple_functions_loss( + multiple_functions_loss_object_with_two_losses, + multiple_functions_loss_forward_batch, + setup_distributed, +): + multiple_functions_loss_object_with_two_losses(multiple_functions_loss_forward_batch) From f1dbe91112dff7a2f6bf0420e5c8af64a5953b52 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Fri, 27 Sep 2024 16:33:30 +0200 Subject: [PATCH 144/161] revert: add back default values for NCELoss --- src/modalities/loss_functions.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 36d6c5ee2..63e5c3cd1 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -3,7 +3,6 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from pydantic import BaseModel from torch.nn import CrossEntropyLoss from modalities.batch import InferenceResultBatch @@ -151,8 +150,8 @@ def __init__( self, prediction_key1: str, prediction_key2: str, - is_asymmetric: bool, - temperature: float, + is_asymmetric: bool = True, + temperature: float = 1.0, tag: str = "NCELoss", ): """ From 146682f73e54f8793a33311e829ad4d9ee1656d0 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Fri, 27 Sep 2024 16:53:27 +0200 Subject: [PATCH 145/161] refactor: use composition to wrap the pytorch DataLoader using LLMDataLoader, so that both LLMDataLoader and WebLoader inherit only from DataLoaderIF --- src/modalities/config/config.py | 10 ++-- src/modalities/config/instantiation_models.py | 6 +- src/modalities/config/pydanctic_if_types.py | 2 +- src/modalities/dataloader/dataloader.py | 57 ++++++++++++++++--- .../test_distributed_dataloader.py | 4 +- .../test_distributed_repeating_dataloader.py | 4 +- tests/dataloader/test_dataloader.py | 8 +-- tests/end2end_tests/test_fsdp_warmstart.py | 4 +- 8 files changed, 68 insertions(+), 27 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 63abfef3b..77a0419a8 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -16,9 +16,9 @@ PydanticCheckpointSavingExecutionIFType, PydanticCheckpointSavingStrategyIFType, PydanticCollateFnIFType, + PydanticDataLoaderIFType, PydanticDatasetIFType, PydanticFSDPModuleType, - PydanticLLMDataLoaderIFType, PydanticModelInitializationIFType, PydanticOptimizerIFType, PydanticPytorchDeviceType, @@ -343,7 +343,7 @@ class WebLoaderConfig(BaseModel): class 
RepeatingDataLoaderConfig(BaseModel): - dataloader: PydanticLLMDataLoaderIFType + dataloader: PydanticDataLoaderIFType reshuffle_after_epoch: Optional[bool] = False num_epochs: Annotated[int, Field(strict=True, ge=1)] @@ -353,15 +353,15 @@ class DummyProgressSubscriberConfig(BaseModel): class SimpleProgressSubscriberConfig(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType - eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + train_dataloader: PydanticDataLoaderIFType + eval_dataloaders: Optional[list[PydanticDataLoaderIFType]] = Field(default_factory=list) world_size: int global_num_seen_samples: int local_rank: int class RichProgressSubscriberConfig(BaseModel): - eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + eval_dataloaders: Optional[list[PydanticDataLoaderIFType]] = Field(default_factory=list) train_dataloader_tag: str num_seen_steps: Annotated[int, Field(strict=True, ge=0)] num_target_steps: Annotated[int, Field(strict=True, gt=0)] diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 91a30c259..4ac4a4bf8 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -6,9 +6,9 @@ from modalities.config.pydanctic_if_types import ( PydanticCheckpointSavingIFType, + PydanticDataLoaderIFType, PydanticDatasetIFType, PydanticGradientClipperIFType, - PydanticLLMDataLoaderIFType, PydanticLossIFType, PydanticLRSchedulerIFType, PydanticMessageSubscriberIFType, @@ -170,8 +170,8 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel scheduler: PydanticLRSchedulerIFType loss_fn: PydanticLossIFType | list[PydanticLossIFType] train_dataset: PydanticDatasetIFType - train_dataloader: PydanticLLMDataLoaderIFType - eval_dataloaders: list[PydanticLLMDataLoaderIFType] + train_dataloader: PydanticDataLoaderIFType + eval_dataloaders: list[PydanticDataLoaderIFType] progress_subscriber: PydanticMessageSubscriberIFType evaluation_subscriber: PydanticMessageSubscriberIFType checkpoint_saving: PydanticCheckpointSavingIFType diff --git a/src/modalities/config/pydanctic_if_types.py b/src/modalities/config/pydanctic_if_types.py index b6c5d555e..25d879737 100644 --- a/src/modalities/config/pydanctic_if_types.py +++ b/src/modalities/config/pydanctic_if_types.py @@ -56,7 +56,7 @@ def __get_pydantic_core_schema__( PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)] PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)] PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)] -PydanticLLMDataLoaderIFType = Annotated[DataLoaderIF, PydanticThirdPartyTypeIF(DataLoaderIF)] +PydanticDataLoaderIFType = Annotated[DataLoaderIF, PydanticThirdPartyTypeIF(DataLoaderIF)] PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)] PydanticLRSchedulerIFType = Annotated[LRScheduler, PydanticThirdPartyTypeIF(LRScheduler)] PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)] diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index ef2cedc39..f91563583 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -1,3 +1,4 @@ +import multiprocessing from typing import Iterable, Optional import webdataset as wd @@ -11,7 +12,7 @@ class DataLoaderIF: pass -class LLMDataLoader(DataLoader[T_co], 
DataLoaderIF): +class LLMDataLoader(DataLoaderIF): """LLMDataLoader is a custom DataLoader class that extends the PyTorch DataLoader class.""" def __init__( @@ -62,7 +63,9 @@ def __init__( None """ assert batch_sampler is not None and batch_size == 1 - super().__init__( + self._dataloader_tag = dataloader_tag + self._batch_size = batch_sampler.batch_size + self._torch_dataloader = DataLoader( dataset=dataset, batch_size=batch_size, shuffle=False, # shuffling must be implemented on a dataset level @@ -81,9 +84,6 @@ def __init__( pin_memory_device=pin_memory_device, ) - self._dataloader_tag = dataloader_tag - self._batch_size = batch_sampler.batch_size - @property def dataloader_tag(self) -> str: """ @@ -125,6 +125,47 @@ def batch_size(self, value: int): """ self._batch_size = value + def __len__(self): + return self._torch_dataloader.__len__() + + def __iter__(self): + return self._torch_dataloader.__iter__() + + @property + def dataset(self) -> Dataset[T_co]: + return self._torch_dataloader.dataset + + @property + def batch_sampler(self) -> ResumableBatchSampler: + return self._torch_dataloader.batch_sampler + + @property + def sampler(self) -> Sampler | Iterable | None: + return self._torch_dataloader.sampler + + @property + def collate_fn(self) -> _collate_fn_t: + return self._torch_dataloader.collate_fn + + @property + def multiprocessing_context(self) -> str | multiprocessing.context.BaseContext: + return self._torch_dataloader.multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + self._torch_dataloader.multiprocessing_context = multiprocessing_context + + @property + def _auto_collation(self): + return self._torch_dataloader._auto_collation + + @property + def _index_sampler(self): + return self._torch_dataloader._index_sampler + + def check_worker_number_rationality(self): + return self._torch_dataloader.check_worker_number_rationality() + @property def fast_forward_batch_id(self) -> int: """ @@ -133,15 +174,15 @@ def fast_forward_batch_id(self) -> int: Returns: int: fast forward batch ID """ - return self.batch_sampler.start_index + return self._torch_dataloader.batch_sampler.start_index -class RepeatingDataLoader(LLMDataLoader[T_co]): +class RepeatingDataLoader(LLMDataLoader): """ RepeatingDataLoader is a custom DataLoader class that repeats the given dataloader for the specified number of epochs.""" - def __init__(self, dataloader: LLMDataLoader[T_co], num_epochs: int, reshuffle_after_epoch: bool = False): + def __init__(self, dataloader: LLMDataLoader, num_epochs: int, reshuffle_after_epoch: bool = False): """ Initializes a RepeatingDataLoader object that repeats the given dataloader for the specified number of epochs. This is especially useful for DataLoader types that we wish to automatically restart upon completion. 
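
The composition approach introduced in this patch can be summarized with a small, self-contained
sketch (illustrative names only, not the actual modalities classes): instead of inheriting from
torch.utils.data.DataLoader, the wrapper holds a DataLoader instance and forwards only the calls it
needs, so LLMDataLoader and WebDataLoader have to share nothing more than the DataLoaderIF interface.

    from torch.utils.data import DataLoader, Dataset

    class DataLoaderIF:
        """Marker interface shared by all dataloader wrappers."""

    class ComposedDataLoader(DataLoaderIF):
        """Minimal sketch: wrap a torch DataLoader by composition instead of inheritance."""

        def __init__(self, dataloader_tag: str, dataset: Dataset, batch_size: int):
            self._dataloader_tag = dataloader_tag
            self._torch_dataloader = DataLoader(dataset=dataset, batch_size=batch_size)

        def __len__(self) -> int:
            # delegate length to the wrapped torch DataLoader
            return len(self._torch_dataloader)

        def __iter__(self):
            # delegate iteration to the wrapped torch DataLoader
            return iter(self._torch_dataloader)
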
diff --git a/tests/dataloader/distributed/test_distributed_dataloader.py b/tests/dataloader/distributed/test_distributed_dataloader.py index 0038d04a6..0d2b0b098 100644 --- a/tests/dataloader/distributed/test_distributed_dataloader.py +++ b/tests/dataloader/distributed/test_distributed_dataloader.py @@ -9,7 +9,7 @@ from modalities.__main__ import Main from modalities.config.config import ProcessGroupBackendType -from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType +from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType from modalities.running_env.cuda_env import CudaEnv from tests.dataloader.dummy_sequential_dataset import TestDataset, TestDatasetConfig @@ -18,7 +18,7 @@ class DataloaderInstantiationModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType @pytest.mark.skipif( diff --git a/tests/dataloader/distributed/test_distributed_repeating_dataloader.py b/tests/dataloader/distributed/test_distributed_repeating_dataloader.py index 7f40cc974..418793a43 100644 --- a/tests/dataloader/distributed/test_distributed_repeating_dataloader.py +++ b/tests/dataloader/distributed/test_distributed_repeating_dataloader.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from modalities.__main__ import Main -from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType +from modalities.config.config import ProcessGroupBackendType, PydanticDataLoaderIFType from modalities.running_env.cuda_env import CudaEnv from tests.dataloader.dummy_sequential_dataset import TestDataset, TestDatasetConfig @@ -17,7 +17,7 @@ class DataloaderInstantiationModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType @pytest.mark.skipif( diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 65139ce6a..9d6f171a9 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -10,7 +10,7 @@ from modalities.config.component_factory import ComponentFactory from modalities.config.config import load_app_config_dict -from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType +from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader from modalities.dataloader.dataset import Dataset from modalities.dataloader.samplers import ResumableBatchSampler @@ -49,7 +49,7 @@ def test_dataloader_from_config(dummy_config: dict): dummy_config["train_dataloader"]["config"]["skip_num_batches"] = start_index class DataloaderTestModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType registry = Registry(COMPONENTS) component_factory = ComponentFactory(registry=registry) @@ -167,7 +167,7 @@ def test_repeating_dataloader_with_shuffling(): def test_skipped_and_distributed_dataloader_from_config(): class DataloaderTestModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType skip_num_batches: int root_dir = Path(__file__).parents[0] @@ -244,7 +244,7 @@ class DataloaderTestModel(BaseModel): ) def test_dataloader_with_fixed_num_batches(global_rank): class DataloaderTestModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType fixed_num_batches: int class IdentityCollateFn(CollateFnIF): diff --git a/tests/end2end_tests/test_fsdp_warmstart.py 
b/tests/end2end_tests/test_fsdp_warmstart.py index 3261eb4b4..dac0b402c 100644 --- a/tests/end2end_tests/test_fsdp_warmstart.py +++ b/tests/end2end_tests/test_fsdp_warmstart.py @@ -11,7 +11,7 @@ from modalities.__main__ import Main, load_app_config_dict from modalities.batch import EvaluationResultBatch -from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType +from modalities.config.config import ProcessGroupBackendType, PydanticDataLoaderIFType from modalities.config.instantiation_models import TrainingComponentsInstantiationModel from modalities.dataloader.dataloader import LLMDataLoader from modalities.logging_broker.messages import Message @@ -46,7 +46,7 @@ class SaveAllResultSubscriberConfig(BaseModel): class TrainDataloaderInstantiationModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType @pytest.mark.skipif( From b985b9bbb5299d723e48953ef1399eb3fdcbffaf Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 30 Sep 2024 08:02:30 +0200 Subject: [PATCH 146/161] refactor: rename WebLoader to WebDataLoader --- .../training/config_coca_img_aud_vid_dataset.yaml | 4 ++-- src/modalities/config/config.py | 2 +- src/modalities/dataloader/dataloader.py | 2 +- src/modalities/dataloader/dataloader_factory.py | 8 ++++---- src/modalities/registry/components.py | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/config_files/training/config_coca_img_aud_vid_dataset.yaml b/config_files/training/config_coca_img_aud_vid_dataset.yaml index 6661ea57a..99eb0a189 100644 --- a/config_files/training/config_coca_img_aud_vid_dataset.yaml +++ b/config_files/training/config_coca_img_aud_vid_dataset.yaml @@ -273,7 +273,7 @@ val_dataset: train_dataloader: component_key: data_loader - variant_key: web_loader + variant_key: web_dataloader config: num_workers: 8 pin_memory: true @@ -289,7 +289,7 @@ train_dataloader: val_dataloader: component_key: data_loader - variant_key: web_loader + variant_key: web_dataloader config: num_workers: 8 pin_memory: true diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 77a0419a8..341358ba8 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -332,7 +332,7 @@ class LLMDataLoaderConfig(BaseModel): fixed_num_batches: Optional[int] = None -class WebLoaderConfig(BaseModel): +class WebDataLoaderConfig(BaseModel): dataloader_tag: str dataset: PydanticDatasetIFType batch_size: int diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index f91563583..ac62eec44 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -288,7 +288,7 @@ def __len__(self) -> int: return self.num_epochs * len(self.dataloader) -class WebLoader(DataLoaderIF): +class WebDataLoader(DataLoaderIF): def __init__( self, dataloader_tag: str, diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 5647143cc..7b923849d 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -3,7 +3,7 @@ from torch.utils.data import BatchSampler from torch.utils.data.dataset import Dataset -from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader, WebLoader +from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader, WebDataLoader from modalities.dataloader.samplers import ResumableBatchSampler from 
modalities.exceptions import ConfigError @@ -91,7 +91,7 @@ def get_repeating_dataloader( return dataloader @staticmethod - def get_web_loader( + def get_web_dataloader( dataloader_tag: str, dataset: Dataset, batch_size: int, @@ -99,8 +99,8 @@ def get_web_loader( num_workers: int, pin_memory: bool, drop_last: bool, - ) -> WebLoader: - dataloader = WebLoader( + ) -> WebDataLoader: + dataloader = WebDataLoader( dataloader_tag=dataloader_tag, dataset=dataset, batch_size=batch_size, diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 4645abd48..6f07036c3 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -52,7 +52,7 @@ StepLRSchedulerConfig, TorchCheckpointLoadingConfig, WandBEvaluationResultSubscriberConfig, - WebLoaderConfig, + WebDataLoaderConfig, WeightInitializedModelConfig, ) from modalities.dataloader.dataloader_factory import DataloaderFactory @@ -213,7 +213,7 @@ class ComponentEntity: ComponentEntity("collate_fn", "coca_collator", CoCaCollatorFn, CoCaCollateFnConfig), # data loaders ComponentEntity("data_loader", "default", DataloaderFactory.get_dataloader, LLMDataLoaderConfig), - ComponentEntity("data_loader", "web_loader", DataloaderFactory.get_web_loader, WebLoaderConfig), + ComponentEntity("data_loader", "web_dataloader", DataloaderFactory.get_web_dataloader, WebDataLoaderConfig), ComponentEntity( "data_loader", "repeating_data_loader", DataloaderFactory.get_repeating_dataloader, RepeatingDataLoaderConfig ), From 9113f8af655d630b67f944ba2a748537ffbbc2e3 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 30 Sep 2024 08:42:37 +0200 Subject: [PATCH 147/161] docs: update docs for WebDataLoader and MultimodalWebDataset --- .../dataloader/dataloader_factory.py | 19 +++- src/modalities/dataloader/dataset.py | 92 ++++++++++++++----- 2 files changed, 86 insertions(+), 25 deletions(-) diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 7b923849d..fe21ab679 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -4,6 +4,7 @@ from torch.utils.data.dataset import Dataset from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader, WebDataLoader +from modalities.dataloader.dataset import MultimodalWebDataset from modalities.dataloader.samplers import ResumableBatchSampler from modalities.exceptions import ConfigError @@ -93,13 +94,28 @@ def get_repeating_dataloader( @staticmethod def get_web_dataloader( dataloader_tag: str, - dataset: Dataset, + dataset: MultimodalWebDataset, batch_size: int, collate_fn: Callable, num_workers: int, pin_memory: bool, drop_last: bool, ) -> WebDataLoader: + """ + Returns a WebDataLoader object for a MultimodalWebDataset + + Args: + dataloader_tag (str): Tag for the dataloader + dataset (Dataset): The MultimodalWebDataset to be used + batch_size (int): batch size per device + collate_fn (Callable): Callable for shaping the batch + num_workers (int): Number of workers for the dataloader + pin_memory (bool): Flag indicating whether to pin memory + drop_last (bool): Flag indicating whether to drop the last non-full batch + + Returns: + WebDataLoader: A WebDataLoader object + """ dataloader = WebDataLoader( dataloader_tag=dataloader_tag, dataset=dataset, @@ -107,5 +123,6 @@ def get_web_dataloader( collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory, + drop_last=drop_last, ) return dataloader diff --git 
a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 624937b4f..c8af7e7af 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -710,19 +710,32 @@ def fixed_ratio_round_robin(*sources, samples_per_batch): class FixedRatioRoundRobinMix(IterableDataset): - """ - returns an iterator for a list of datasets; samples are yielded in a round robin manner - with a fixed ratio of samples per dataset. There is no random sampling, so the number of - samples per modality is guaranteed to be fixed per batch. - """ + def __init__( + self, + datasets: list[wds.WebDataset], + mixing_ratios: list[float], + batch_size: int, + ): + """An iterator for a list of datasets. + Samples are yielded in a round robin manner + with a fixed ratio of samples per dataset. There is no random sampling, so the number of + samples per modality is guaranteed to be fixed per batch. - def __init__(self, datasets, mixing_ratios, batch_size): + Args: + datasets (list[WebDataset]): a list of WebDatasets to be iterated over + mixing_ratios (list[float]): the ratio of samples from each dataset that should be present in a batch + batch_size (int): size of batch containing samples from all datasets in the specified ratio + """ self.datasets = datasets self.samples_per_batch = [int(batch_size * ratio) for ratio in mixing_ratios] + # ensure ratio sums up to 1.0 self.samples_per_batch[0] += batch_size - sum(self.samples_per_batch) def __iter__(self): - """Return an iterator over the sources.""" + """ + Returns: + an iterator over the source datasets + """ sources = [iter(d) for d in self.datasets] return fixed_ratio_round_robin(*sources, samples_per_batch=self.samples_per_batch) @@ -748,12 +761,16 @@ def __init__( """A multimodal dataset instance for the WebDataset. Args: - urls: A webdataset url. For example: "/data/path/{00000..00012.tar". - modality_key_mapping: Mapping from dataset keys to keys expected by the forward pass of the model. + urls (list[str] or str): A webdataset url. For example: "/data/path/{00000..00012}.tar". + modality_key_mapping (dict[str, tuple[str, str]]): Mapping from dataset keys to keys + expected by the forward pass of the model. For example: {ModalityEnum.IMAGE: ("jpg", "image"), ModalityEnum.TEXT: ("text", "caption")}} - modality_transforms: The transforms for each modality. - num_samples: The number of samples for each modality combination. - is_audio_video: Whether the dataset is a video dataset which contains audio + modality_transforms (dict[str, Transform]): The transforms for each modality as a dictionary. + is_audio_video (bool): Whether the dataset is a video dataset which contains audio + num_samples (int): The number of samples for each modality combination. 
+ + Returns: + None """ self.urls = urls self.is_audio_video = is_audio_video @@ -775,12 +792,12 @@ def __init__( ModalityEnum.AUDIO: wds.torch_audio, } - self.additional_extreacted_keys = [] + self.additional_extracted_keys = [] if ModalityEnum.TEXT in self.modality_transforms: - self.additional_extreacted_keys.append("attention_mask") + self.additional_extracted_keys.append("attention_mask") if ModalityEnum.AUDIO in self.modality_transforms or ModalityEnum.VIDEO in self.modality_transforms: - self.additional_extreacted_keys.append("audio_len") + self.additional_extracted_keys.append("audio_len") # Mapping between modality and transform self.modality_to_transform_fn = { @@ -793,6 +810,21 @@ def __init__( def prepare( self, shardshuffle: int = 100, resample: bool = True, repeat: bool = False, shuffle_buffer: int = 10_000 ): + """ + Prepares a WebDataset object as a pipeline that includes shuffling, decoding data, and transformations + + Args: + shardshuffle (int): Number of shards that should be used for shuffling. Defaults to 100. + resample (bool): Instead of iterating in order sample random shards. + This has the issue that the model will see sample multiple times but is significantly more + efficient. Defaults to True. + repeat (bool): Repeat the dataset. Defaults to False. + shuffle_buffer (Optional[int]): Number of samples that should be used for shuffling. Defaults to 10_000. + + Returns: + None + + """ self.web_dataset = wds.WebDataset( urls=self.urls, nodesplitter=self.dummy_nodesplitter if not resample else None, @@ -848,8 +880,8 @@ def _transform_video(self, sample): if sample[source_key][1] is not None and ModalityEnum.AUDIO in self.modality_transforms and self.is_audio_video: transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] sample["audio"], sample["audio_len"] = transform((sample[source_key][1], sample[source_key][2])) - if "audio" not in self.additional_extreacted_keys: - self.additional_extreacted_keys.append("audio") + if "audio" not in self.additional_extracted_keys: + self.additional_extracted_keys.append("audio") del sample[source_key] return sample @@ -865,7 +897,10 @@ def _flatten_sample(self, sample): return flatten_dict(sample) def _select_keys(self, sample): - select_keys = self.additional_extreacted_keys + [v[1] for v in self.modality_key_mapping.values()] + # only select the required keys from the sample + # i.e. the keys specified in modality_key_mapping + # and the additional_extracted_keys + select_keys = self.additional_extracted_keys + [v[1] for v in self.modality_key_mapping.values()] new_sample = {} for k, v in sample.items(): if k not in select_keys: @@ -915,14 +950,23 @@ def __init__( Args: builders: WebDatasetBuilder instances. - mixing_ratios: Mixing ratios of the different modality combinations. + batch_size (int): batch size per device + mixing_ratios (Optinal[list[float]]): Mixing ratios of the different modality combinations. For example: [0.3, 0.7] - shardshuffle: Number of sharfs that should be used for shuffling. Defaults to 100. - repeat: Repeat the dataset. Defaults to False. - resample: Instead if iterating in order sample random shards. - This has the issue that the model will see sample multiple times but if significantly more + shardshuffle (int): Number of shards that should be used for shuffling. Defaults to 100. + repeat (bool): Repeat the dataset. Defaults to False. + resample (bool): Instead of iterating in order sample random shards. 
+            This has the issue that the model will see samples multiple times but is significantly more
+            efficient. Defaults to True.
+        shuffle_buffer (Optional[int]): Number of samples that should be used for shuffling. Defaults to 10_000.
+
+        Raises:
+            NotImplementedError: if multiple builders are specified and at least one builder contains a
+                video dataset which contains audio
+            ValueError: if multiple builders are specified and batch size is None
+
+        Returns:
+            None
         """
         super().__init__()
         self.builders = builders

From 2a1cfa110d2bee2996fe0bd57412437dd47471 Mon Sep 17 00:00:00 2001
From: Santosh Thoduka
Date: Mon, 30 Sep 2024 11:54:56 +0200
Subject: [PATCH 148/161] docs: update docs for coca model

---
 src/modalities/models/coca/coca_model.py | 113 ++++++++++++++---------
 1 file changed, 67 insertions(+), 46 deletions(-)

diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py
index dd8697481..e551cb173 100644
--- a/src/modalities/models/coca/coca_model.py
+++ b/src/modalities/models/coca/coca_model.py
@@ -60,16 +60,31 @@ class CoCaConfig(BaseModel):

     Args:
         prediction_key (str): The key for the predictions.
-        vision_embd_prediction_key (str): The key for the vision embeddings.
         text_embd_prediction_key (str): The key for the text embeddings.
-        vision_cls_prediction_key (str): The key for the vision cls token.
-        text_cls_prediction_key (str): The key for the text cls token.
-        vision_encoder_config (VisionTransformerConfig): Configuration for the vision encoder.
+        logit_scale_prediction_key (str): The key for the logit scale
+        text_cls_prediction_key (Optional[str]): The key for the text cls token.
+        audio_embd_prediction_key (Optional[str]): The key for audio embeddings
+        image_embd_prediction_key (Optional[str]): The key for image embeddings
+        video_embd_prediction_key (Optional[str]): The key for video embeddings
+        audio_cls_prediction_key (Optional[str]): The key for the audio cls token
+        audio_text_cls_prediction_key (Optional[str]): The key for the text cls token associated with the audio samples
+        image_cls_prediction_key (Optional[str]): The key for the image cls token
+        image_text_cls_prediction_key (Optional[str]): The key for the text cls token associated with the image samples
+        video_cls_prediction_key (Optional[str]): The key for the video cls token
+        video_text_cls_prediction_key (Optional[str]): The key for the text cls token associated with the video samples
+        modality_keys (list[str]): sample keys in the input associated with the input modalities
+        individual_datasets (Optional[bool]): flag indicating whether
+            there are separate datasets for different modalities
+        is_audio_video (Optional[bool]): flag indicating whether the video samples contain audio
+        audio_encoder_config (Optional[AudioTransformerConfig]): config for the audio encoder. Defaults to None.
+        image_encoder_config (Optional[VisionTransformerConfig]): config for the image encoder. Defaults to None
+        video_encoder_config (Optional[VisionTransformerConfig]): config for the video encoder. Defaults to None
         text_decoder_config (TextDecoderConfig): Configuration for the text decoder.
         n_pool_head (int): Number of attention heads for pooling.
-        n_vision_queries (int): Number of vision queries.
+        n_queries (int): Number of vision queries.
         bias_attn_pool (bool): Flag indicating whether to use bias in attention pooling.
         epsilon_attn_pool (float): Epsilon value for attention pooling.
+            seed (Optional[int]): The random seed. Defaults to None


@@ -144,18 +159,40 @@ def __init__(

         Args:
             prediction_key (str): The key for the predictions.
-            vision_cls_prediction_key (str): The key for the vision cls token.
-            text_cls_prediction_key (str): The key for the text cls token.
-            vision_embd_prediction_key (str): The key for the vision embeddings.
             text_embd_prediction_key (str): The key for the text embeddings.
-
-            n_vision_queries (int): The number of vision queries.
-            n_pool_head (int): The number of pool heads.
+            logit_scale_prediction_key (str): The key for the logit scale
+            text_cls_prediction_key (Optional[str]): The key for the text cls token.
+            audio_embd_prediction_key (Optional[str]): The key for audio embeddings
+            image_embd_prediction_key (Optional[str]): The key for image embeddings
+            video_embd_prediction_key (Optional[str]): The key for video embeddings
+            audio_cls_prediction_key (Optional[str]): The key for the audio cls token
+            audio_text_cls_prediction_key (Optional[str]): The key for the text cls token
+                associated with the audio samples
+            image_cls_prediction_key (Optional[str]): The key for the image cls token
+            image_text_cls_prediction_key (Optional[str]): The key for the text cls token
+                associated with the image samples
+            video_cls_prediction_key (Optional[str]): The key for the video cls token
+            video_text_cls_prediction_key (Optional[str]): The key for the text cls token
+                associated with the video samples
+            modality_keys (list[str]): sample keys in the input associated with the input modalities
+            individual_datasets (Optional[bool]): flag indicating whether there are separate datasets
+                for different modalities
+            is_audio_video (Optional[bool]): flag indicating whether the video samples contain audio
+            audio_encoder_config (Optional[AudioTransformerConfig]): config for the audio encoder. Defaults to None.
+            image_encoder_config (Optional[VisionTransformerConfig]): config for the image encoder. Defaults to None
+            video_encoder_config (Optional[VisionTransformerConfig]): config for the video encoder. Defaults to None
+            text_decoder_config (TextDecoderConfig): Configuration for the text decoder.
+            n_pool_head (int): Number of attention heads for pooling.
+            n_queries (int): Number of vision queries.
             bias_attn_pool (bool): Flag indicating whether to use bias in attention pooling.
-            epsilon_attn_pool (float): The epsilon value for attention pooling.
-            vision_encoder_config (VisionTransformerConfig): The configuration for the vision encoder.
-            text_decoder_config (TextDecoderConfig): The configuration for the text decoder.
-            seed (int, optional): The random seed. Defaults to None.
+            epsilon_attn_pool (float): Epsilon value for attention pooling.
+            seed (Optional[int]): The random seed. Defaults to None
+
+        Raises:
+            ValueError: if none of the modality encoders are defined
+            ValueError: if using individual datasets and none of the text cls tokens
+                corresponding to the modalities is defined
+            ValueError: if training on a single dataset and text_cls_prediction_key is not defined

         Returns:
             None
@@ -289,10 +326,18 @@ def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         Forward pass of the CoCa model.

         Args:
-            inputs (dict[str, torch.Tensor]): Input dictionary containing the tensors.
+            inputs (dict[str, torch.Tensor]): Input dictionary containing the text and modality samples
+                In case of multiple modalities, the 'input_ids' key contains the token ids for
+                the text corresponding to all the modalities stacked together.
Thus the length (batch size) + of 'input_ids' will be equal to the sum of the lengths of the individual modality + samples. Returns: - dict[str, torch.Tensor]: Output dictionary. + dict[str, torch.Tensor]: Output dictionary containing + - cls token(s) for the modality or modalities + - text cls token(s) corresponding to the modality sample(s) + - logits from the text decoder + - logit_scale """ output = {} @@ -364,15 +409,7 @@ def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: return output def _forward_encode_image(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: - """ - Encodes the input image using the vision encoder. - - Args: - inputs (dict[str, torch.Tensor]): dictionary containing vision inputs. - - Returns: - tuple[torch.Tensor, torch.Tensor]: Tuple containing encoded vision embeddings and classification token. - """ + # returns a tuple containing the image embeddings and cls token image_embd = self.image_encoder(inputs)[self.image_embd_prediction_key] queries = repeat(self.image_queries, "n d -> b n d", b=image_embd.shape[0]) image_embd = self.image_attn_pool(queries, context=image_embd) @@ -380,6 +417,7 @@ def _forward_encode_image(self, inputs: dict[str, torch.Tensor]) -> tuple[torch. return image_embd, image_cls_token def _forward_encode_video(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + # returns a tuple containing the video embeddings and cls token video_embd = self.video_encoder(inputs)[self.video_embd_prediction_key] queries = repeat(self.video_queries, "n d -> b n d", b=video_embd.shape[0]) video_embd = self.video_attn_pool(queries, context=video_embd) @@ -387,6 +425,7 @@ def _forward_encode_video(self, inputs: dict[str, torch.Tensor]) -> tuple[torch. return video_embd, video_cls_token def _forward_encode_audio(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + # returns a tuple containing the audio embeddings and cls token audio_embd = self.audio_encoder(inputs)[self.audio_embd_prediction_key] queries = repeat(self.audio_queries, "n d -> b n d", b=audio_embd.shape[0]) audio_embd = self.audio_attn_pool(queries, context=audio_embd) @@ -394,16 +433,7 @@ def _forward_encode_audio(self, inputs: dict[str, torch.Tensor]) -> tuple[torch. return audio_embd, audio_cls_token def _forward_encode_text(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: - """ - Encodes the input text using the text decoder. - - Args: - inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - - Returns: - tuple[torch.Tensor, torch.Tensor]: A tuple containing the encoded text tensor - and the classification token tensor. - """ + # returns a tuple containing the encoded text tensor and the cls token text_embd = self.text_decoder(inputs)[self.text_embd_prediction_key] text_embd, text_cls_token = text_embd[:, :-1, :], F.normalize(text_embd[:, -1, :], dim=-1) return text_embd, text_cls_token @@ -411,16 +441,7 @@ def _forward_encode_text(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.T def _forward_decode( self, text_embd: torch.Tensor, modality_embd: list[torch.Tensor] | torch.Tensor ) -> torch.Tensor: - """ - Perform forward decoding using the given text and vision embeddings. - - Args: - text_embd (torch.Tensor): The text embeddings. - vision_embd (torch.Tensor): The vision embeddings. - - Returns: - torch.Tensor: The logits obtained from the multimodal decoder. 
- """ + # forward decode given the text and modality embedding(s) decoder_inputs = { self.text_embd_prediction_key: text_embd, "context": modality_embd, From b9bfaea95deb7c398b43c21fdcfe7b2643a086cb Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Tue, 1 Oct 2024 13:49:05 +0200 Subject: [PATCH 149/161] docs: update docs for vision transforms and make video transform parameters configurable --- .../config_coca_img_aud_vid_dataset.yaml | 8 ++ src/modalities/dataloader/dataset.py | 76 ++++++++++++++++++- .../vision_transformer_model.py | 5 +- 3 files changed, 84 insertions(+), 5 deletions(-) diff --git a/config_files/training/config_coca_img_aud_vid_dataset.yaml b/config_files/training/config_coca_img_aud_vid_dataset.yaml index 99eb0a189..4003e0e3d 100644 --- a/config_files/training/config_coca_img_aud_vid_dataset.yaml +++ b/config_files/training/config_coca_img_aud_vid_dataset.yaml @@ -95,6 +95,10 @@ train_video_transform: variant_key: video_transform config: is_training: True + hflip: 0.5 + color_jitter: [0.5, 0.5, 0.5, 0.5] + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] input_size: ${model_raw.config.video_encoder_config.img_size} num_frames: ${model_raw.config.video_encoder_config.num_video_frames} @@ -119,6 +123,10 @@ val_video_transform: variant_key: video_transform config: is_training: False + hflip: 0.0 + color_jitter: [0.0, 0.0, 0.0, 0.0] + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] input_size: ${model_raw.config.video_encoder_config.img_size} num_frames: ${model_raw.config.video_encoder_config.num_video_frames} diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index c8af7e7af..08387de57 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -440,7 +440,50 @@ class ImageTransformConfig(TransformConfig): # @register_component("transform", "image_transform", ImageTransformConfig) class ImageTransform(Transform): + """ImageTransform class.""" + def __init__(self, **kwargs): + """ + Initializes a Transform object for image transformations. + + The following argument descriptions are duplicated from: + https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/transforms_factory.py + + Args: + input_size (int, tuple[int,int], tuple[int, int, int]: + Target input size (channels, height, width) tuple or size scalar. + is_training (bool): Return training (random) transforms. + no_aug (bool): Disable augmentation for training (useful for debug). + train_crop_mode (Optional[str]): Training random crop mode ('rrc', 'rkrc', 'rkrr'). + scale (Optional[tuple[float, float]]) : Random resize scale range (crop area, < 1.0 => zoom in). + ratio (Optional[tuple[float, float]]): Random aspect ratio range + (crop ratio for RRC, ratio adjustment factor for RKR). + hflip (float): Horizontal flip probability. + vflip (float): Vertical flip probability. + color_jitter (float | tuple[float, ...]): Random color jitter component factors + (brightness, contrast, saturation, hue). + Scalar is applied as (scalar,) * 3 (no hue). + color_jitter_prob (Optional[float]): Apply color jitter with this + probability if not None (for SimlCLR-like aug). + grayscale_prob (float): Probability of converting image to grayscale (for SimCLR-like aug). + gaussian_blur_prob (float): Probability of applying gaussian blur (for SimCLR-like aug). + auto_augment (Optional[str]): Auto augment configuration string (see auto_augment.py). + interpolation (str): Image interpolation mode. 
+ mean (tuple[float, ...]): Image normalization mean. + std (tuple[float, ...]): Image normalization standard deviation. + re_prob (float): Random erasing probability. + re_mode (str): Random erasing fill mode. + re_count (int): Number of random erasing regions. + re_num_splits (int): Control split of random erasing across batch size. + crop_pct (Optional[float]): Inference crop percentage (output size / resize size). + crop_mode (Optional[str]): Inference crop mode. + One of ['squash', 'border', 'center']. Defaults to 'center' when None. + crop_border_pixels (Optional[int]): Inference crop border of + specified # pixels around edge of original image. + tf_preprocessing (bool): Use TF 1.0 inference preprocessing for testing model ports + use_prefetcher (bool): Pre-fetcher enabled. Do not convert image to tensor or normalize. + """ + self._timm_image_transform = create_transform(**kwargs) def __call__(self, image): @@ -465,6 +508,17 @@ def __init__( truncation: bool = True, return_attention_mask: bool = True, ): + """ + Args: + tokenizer (TokenizerWrapper): text tokenizer + max_length (int): maximum number of tokens. Default 77 + padding (str): padding strategy. Default "max_length" + truncation (bool): Flag which determines whether to apply truncation. Default True. + return_attention_mask (bool): Flag which determines whether the attention mask is returned. Default True. + + Returns: + None + """ self.tokenizer = tokenizer self.max_length = max_length self.padding = padding @@ -586,6 +640,10 @@ class VideoTransformConfig(TransformConfig): input_size: int | tuple[int, int] | tuple[int, int, int] = 224 is_training: bool = False num_frames: int = 16 + hflip: float = 0.0 + color_jitter: list[float] = [0.0, 0.0, 0.0, 0.0] + mean: list[float] = IMAGENET_DEFAULT_MEAN + std: list[float] = IMAGENET_DEFAULT_STD class VideoTransform(Transform): @@ -594,14 +652,23 @@ def __init__( input_size: int | tuple[int, int] | tuple[int, int, int] = 224, is_training: bool = False, num_frames: int = 16, + hflip: float = 0.0, + color_jitter: list[float] = [0.0, 0.0, 0.0, 0.0], + mean: list[float] = IMAGENET_DEFAULT_MEAN, + std: list[float] = IMAGENET_DEFAULT_STD, ): self.spatial_transform = transforms.Compose( [ transforms.RandomResizedCrop(input_size, antialias=True), - transforms.RandomHorizontalFlip(), - transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), + transforms.RandomHorizontalFlip(p=hflip), + transforms.ColorJitter( + brightness=color_jitter[0], + contrast=color_jitter[1], + saturation=color_jitter[2], + hue=color_jitter[3], + ), transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + transforms.Normalize(mean=mean, std=std), ] ) self.temporal_transform = RandomTemporalCrop(num_frames=num_frames) @@ -706,7 +773,8 @@ def fixed_ratio_round_robin(*sources, samples_per_batch): i = (i + 1) % len(sources) yield sample except StopIteration: - del sources[i] + # stop if any modality runs out of samples + break class FixedRatioRoundRobinMix(IterableDataset): diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py index ab09db8c3..407c60171 100644 --- a/src/modalities/models/vision_transformer/vision_transformer_model.py +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -24,12 +24,15 @@ class VisionTransformerConfig(BaseModel): attention_config (AttentionConfig, optional): The configuration for 
the attention mechanism. Defaults to None. n_head (int): The number of attention heads. Defaults to 8. n_embd (int): The dimensionality of the embedding. Defaults to 768. + ffn_hidden (int): The number of hidden units in the feed-forward network. Defaults to 3072. dropout (float): The dropout rate. Defaults to 0.0. patch_size (int): The size of the image patches. Defaults to 16. patch_stride (int): The stride of the image patches. Defaults to 16. n_img_channels (int): The number of image channels. Defaults to 3. add_cls_token (bool): Flag indicating whether to add a classification token. Defaults to True. bias (bool): Flag indicating whether to include bias terms. Defaults to True. + num_video_frames (int): the number of video frames in case of video input + n_latents: the number of latent queries used for the Perceiver block in case of video input. Defaults to 64. """ sample_key: str @@ -47,7 +50,7 @@ class VisionTransformerConfig(BaseModel): n_img_channels: Annotated[int, Field(ge=1)] = 3 add_cls_token: bool = True bias: bool = True - num_video_frames: Annotated[int, Field(ge=0)] = 1 # TODO: read this from dataloader/train config + num_video_frames: Annotated[int, Field(ge=0)] = 1 n_latents: Annotated[int, Field(ge=1)] = 64 From d6da5833b35112773c9fe4e897723feef8777928 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Tue, 1 Oct 2024 13:52:38 +0200 Subject: [PATCH 150/161] test: add test for webdataset dataset and dataloader --- tests/dataloader/test_webdataset.py | 140 ++++++++++++++++++ .../yaml_configs/web_dataloader.yaml | 111 ++++++++++++++ 2 files changed, 251 insertions(+) create mode 100644 tests/dataloader/test_webdataset.py create mode 100644 tests/dataloader/yaml_configs/web_dataloader.yaml diff --git a/tests/dataloader/test_webdataset.py b/tests/dataloader/test_webdataset.py new file mode 100644 index 000000000..3ed7c7762 --- /dev/null +++ b/tests/dataloader/test_webdataset.py @@ -0,0 +1,140 @@ +import io +import tarfile +from pathlib import Path + +import numpy as np +import pytest +import torch +import torchaudio +import webdataset as wds +from pydantic import BaseModel + +from modalities.__main__ import load_app_config_dict +from modalities.config.component_factory import ComponentFactory +from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry +from tests.conftest import _ROOT_DIR + + +def create_image_sample(): + img = np.random.randint(0, 255, size=(224, 224, 3)).astype(np.uint8) + img = wds.writer.imageencoder(img, format="JPG") + text = {"text0": "this is an image caption %d" % np.random.randint(10)} + return img, text + + +@pytest.fixture(scope="session") +def image_tar_path(tmp_path_factory): + data_path = str(tmp_path_factory.mktemp("data") / "images.tar") + dataset_sink = wds.TarWriter(data_path) + # 10 image samples + for idx in range(10): + img, text = create_image_sample() + dataset_sink.write( + { + "__key__": "%02d" % idx, + "jpg": img, + "json": text, + } + ) + dataset_sink.close() + return data_path + + +def create_audio_sample(): + sample_rate = 16000 + audio = torch.from_numpy(np.random.uniform(-1, 1, sample_rate)).unsqueeze(0) + audio_buf = io.BytesIO() + torchaudio.save(audio_buf, audio, sample_rate, format="wav") + audio_buf.seek(0) + text = "this is an audio caption %d" % np.random.randint(10) + text_f = io.BytesIO() + text_f.write(text.encode("utf-8")) + text_f.seek(0) + return audio_buf, text_f + + 
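+# Note: webdataset groups tar members that share the same prefix (up to the first dot)
+# into a single sample, so "00.wav" and "00.transcript.txt" below end up as the "wav"
+# and "transcript.txt" fields of sample "00"; these field names are what the
+# modality_key_mapping entries in the YAML config refer to.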
+@pytest.fixture(scope="session") +def audio_tar_path(tmp_path_factory): + data_path = str(tmp_path_factory.mktemp("data") / "audio.tar") + with tarfile.open(data_path, "w") as fp: + # 25 audio samples + for idx in range(25): + key = "%02d" % idx + wav, text = create_audio_sample() + info = tarfile.TarInfo(key + ".wav") + info.size = wav.getbuffer().nbytes + fp.addfile(info, wav) + info = tarfile.TarInfo(key + ".transcript.txt") + info.size = text.getbuffer().nbytes + fp.addfile(info, text) + return data_path + + +@pytest.mark.parametrize( + "mixing_ratios,resample,batch_size", + [ + ([0.9, 0.1], False, 10), # we run out of image samples after the second batch + ([0.9, 0.1], True, 10), # since we resample, there are enough samples for >2 batches + ([0.7, 0.3], False, 20), # the first batch won't have 0.7*20 samples + ([0.3, 0.6], False, 10), # ratios don't add up to 1 + ([0.8, 0.2], True, 100), + ], +) +def test_web_dataloader(image_tar_path, audio_tar_path, mixing_ratios, resample, batch_size): + class DataloaderTestModel(BaseModel): + train_dataloader: PydanticDataLoaderIFType + + config_file_path = _ROOT_DIR / Path("tests/dataloader/yaml_configs/web_dataloader.yaml") + config_dict = load_app_config_dict(config_file_path=config_file_path) + config_dict["image_dataset"]["config"]["urls"] = image_tar_path + config_dict["audio_dataset"]["config"]["urls"] = audio_tar_path + config_dict["train_dataset"]["config"]["mixing_ratios"] = mixing_ratios + config_dict["train_dataset"]["config"]["resample"] = resample + config_dict["train_dataset"]["config"]["batch_size"] = batch_size + config_dict["train_dataloader"]["config"]["batch_size"] = batch_size + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + components = component_factory.build_components(config_dict=config_dict, components_model_type=DataloaderTestModel) + + expected_images = int(mixing_ratios[0] * batch_size) + expected_audio = int(mixing_ratios[1] * batch_size) + # if ratios don't add up to 1, extra samples are added to first modality + remaining = batch_size - (expected_audio + expected_images) + expected_images += remaining + + loader = iter(components.train_dataloader) + + # image, audio + total_samples = [10, 25] + seen_samples = [0, 0] + + for idx in range(5): + batch_expected_images = expected_images + batch_expected_audio = expected_audio + try: + batch = next(loader) + except StopIteration: + break + + if not resample: + # if resample is False, the last batch may have less + # samples than expected if one of the modalities + # runs out of samples + if total_samples[0] - seen_samples[0] < expected_images: + expected_images - (total_samples[0] - seen_samples[0]) + batch_expected_images = total_samples[0] - seen_samples[0] + if total_samples[1] - seen_samples[1] < expected_audio: + expected_audio - (total_samples[1] - seen_samples[1]) + batch_expected_audio = total_samples[1] - seen_samples[1] + + assert batch.samples["images"].shape[0] == batch_expected_images + seen_samples[0] += batch.samples["images"].shape[0] + assert batch.samples["audio"].shape[0] == batch_expected_audio + seen_samples[1] += batch.samples["audio"].shape[0] + assert batch.samples["input_ids"].shape[0] == batch_expected_audio + batch_expected_images + for idx in range(2): + # reset if the complete dataset has been seen already + if seen_samples[idx] == total_samples[idx]: + seen_samples[idx] = 0 diff --git a/tests/dataloader/yaml_configs/web_dataloader.yaml b/tests/dataloader/yaml_configs/web_dataloader.yaml new file 
mode 100644 index 000000000..843d3d801 --- /dev/null +++ b/tests/dataloader/yaml_configs/web_dataloader.yaml @@ -0,0 +1,111 @@ +tokenizer: + component_key: tokenizer + variant_key: pretrained_hf_tokenizer + config: + pretrained_model_name_or_path: openai/clip-vit-base-patch32 + padding: true + max_length: 50 + +train_image_transform: + component_key: transform + variant_key: image_transform + config: + is_training: True + input_size: 224 + +train_audio_transform: + component_key: transform + variant_key: audio_transform + config: + is_training: True + block_size_audio_encoder: 500 + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + +text_transform: + component_key: transform + variant_key: text_transform + config: + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - images + - audio + - audio_len + - input_ids + target_keys: [] + text_sample_key: input_ids + text_target_key: logits + +image_dataset: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: None + modality_key_mapping: + TEXT: ["json_text0", "input_ids"] + IMAGE: ["jpg", "images"] + modality_transforms: + IMAGE: + instance_key: train_image_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + +audio_dataset: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: None + modality_key_mapping: + TEXT: ["transcript.txt", "input_ids"] # source and target keys + AUDIO: ["wav", "audio"] + modality_transforms: + AUDIO: + instance_key: train_audio_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + + +train_dataset: + component_key: dataset + variant_key: web_dataset + config: + builders: + - instance_key: image_dataset + pass_type: BY_REFERENCE + - instance_key: audio_dataset + pass_type: BY_REFERENCE + mixing_ratios: [0.9, 0.1] + batch_size: 10 + shardshuffle: 100 + repeat: false + resample: false + shuffle_buffer: 10_000 + +train_dataloader: + component_key: data_loader + variant_key: web_dataloader + config: + num_workers: 0 + pin_memory: true + drop_last: true + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_size: 10 + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE From 152ebf29b883750cd5e76a5f08f86eb7a42fa337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Kacz=C3=A9r?= Date: Mon, 7 Oct 2024 11:22:20 +0200 Subject: [PATCH 151/161] docs: add docstring to WebDataloader, typehints to _init_modality --- src/modalities/dataloader/dataloader.py | 14 ++++++++++++++ src/modalities/models/coca/coca_model.py | 5 ++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index ac62eec44..b083e057a 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -289,6 +289,8 @@ def __len__(self) -> int: class WebDataLoader(DataLoaderIF): + """WebDataLoader is a custom DataLoader class that wraps the webdataset.WebLoader class.""" + def __init__( self, dataloader_tag: str, @@ -299,6 +301,18 @@ def __init__( pin_memory: bool = False, drop_last: bool = False, ): + """Initializes WebDataLoader, which is a wrapper for webdataset.WebLoader. + + Args: + dataloader_tag (str): The tag for the dataloader. 
+ dataset (Dataset[T_co]): The dataset to load the data from. + batch_size (Optional[int], optional): The batch size. Defaults to 1. + num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 0. + collate_fn (Optional[_collate_fn_t], optional): The function used to collate the data samples. + Defaults to None. + pin_memory (bool, optional): Flag indicating whether to pin the memory. Defaults to False. + drop_last (bool, optional): Flag indicating whether to drop the last incomplete batch. Defaults to False. + """ self.num_batches = len(dataset) // batch_size + int(not drop_last) dataset = dataset.batched(batch_size, collation_fn=collate_fn) self.webloader = wd.WebLoader(dataset=dataset, batch_size=None, num_workers=num_workers, pin_memory=pin_memory) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index e551cb173..3d4aa7594 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -309,7 +309,10 @@ def __init__( # Logit scale for contrastive loss self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - def _init_modality(self, encoder_class, encoder_config, n_queries): + def _init_modality( + self, encoder_class: type, encoder_config: VisionTransformerConfig | AudioTransformerConfig, n_queries: int + ) -> tuple[VisionTransformer | AudioTransformer, nn.Parameter, AttentionPooling]: + # initialize modality encoder, returns a tuple containing the encoder, queries and attention pooling layer encoder = encoder_class(**dict(encoder_config)) queries = nn.Parameter(torch.randn(n_queries + 1, encoder_config.n_embd)) attn_pool = AttentionPooling( From 03539694d20f89456f3f9f10d7d4da50951355b3 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Mon, 7 Oct 2024 14:12:02 +0200 Subject: [PATCH 152/161] fix: mask creation for audio inputs --- .../models/audio_transformer/audio_transformer_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index 435075845..c62509068 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -374,10 +374,14 @@ def _get_attn_key_mask( lengths: torch.Tensor, ) -> torch.Tensor: # Generates an attention key mask based on input sequence lengths. 
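+        # The per-sample masks built below allow query positions within `length` to attend
+        # to the first `length` key positions, while padded query positions attend only to
+        # position 0, so no row of the resulting mask is all zeros (see the expected pattern
+        # in the audio transformer unit tests added in a later commit).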
+ stack = [] + for length in lengths: + ones = torch.ones(length, self.block_size) + ones[1:, length:] = 0 + stack.append(ones) return ( torch.nn.utils.rnn.pad_sequence( - [torch.ones(length, self.block_size) for length in lengths] - + [torch.ones(self.block_size, self.block_size)], + stack + [torch.zeros(self.block_size, self.block_size)], batch_first=True, ) .transpose(1, 2)[:-1] From a12ff8a9422f020f24c9c83f48451ba553fbbe56 Mon Sep 17 00:00:00 2001 From: manasMauryax Date: Mon, 7 Oct 2024 14:13:06 +0200 Subject: [PATCH 153/161] test: add tests for audio_transformer --- .../audio_transformer_model.py | 3 + tests/models/audio_transformer/__init__.py | 0 .../test_audio_transformer_model.py | 122 ++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 tests/models/audio_transformer/__init__.py create mode 100644 tests/models/audio_transformer/test_audio_transformer_model.py diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index c62509068..d40e63993 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -143,6 +143,9 @@ def forward( Returns: torch.Tensor: Output tensor of shape (B, T, D). """ + if x.shape[1] == 1: + raise ValueError("The time dimension of the input to the convolution module cannot be 1!") + x = self.ln(x) x = x.transpose(1, 2) x = self.glu(self.pointwise_conv_1(x)) diff --git a/tests/models/audio_transformer/__init__.py b/tests/models/audio_transformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/models/audio_transformer/test_audio_transformer_model.py b/tests/models/audio_transformer/test_audio_transformer_model.py new file mode 100644 index 000000000..09d3d672a --- /dev/null +++ b/tests/models/audio_transformer/test_audio_transformer_model.py @@ -0,0 +1,122 @@ +import pytest +import torch + +from modalities.models.audio_transformer.audio_transformer_model import ( + AudioTransformer, + ConformerBlock, + ConvolutionModule, +) +from modalities.nn.attention import AttentionConfig + + +@pytest.fixture +def params() -> dict: + return { + "sample_key": "audio", + "prediction_key": "audio_embeddings", + "block_size": 5, + "n_mels": 1, + "n_conformer_blocks": 1, + "n_embd": 1, + "n_heads": 1, + "attention_config": AttentionConfig(attention_engine_type="pytorch_flash_attention"), + "pointwise_conv_kernel_size": 1, + "depthwise_conv_kernel_size": 1, + "dropout": 0.1, + } + + +@pytest.fixture +def audio_transformer_model(params) -> AudioTransformer: + return AudioTransformer( + sample_key=params["sample_key"], + prediction_key=params["prediction_key"], + block_size=params["block_size"], + n_mels=params["n_mels"], + n_conformer_blocks=params["n_conformer_blocks"], + n_embd=params["n_embd"], + n_heads=params["n_heads"], + attention_config=params["attention_config"], + pointwise_conv_kernel_size=params["pointwise_conv_kernel_size"], + depthwise_conv_kernel_size=params["depthwise_conv_kernel_size"], + ffmodule_dropout=params["dropout"], + attn_dropout=params["dropout"], + convmodule_dropout=params["dropout"], + ) + + +@pytest.fixture +def invalid_forward_input() -> torch.Tensor: + return torch.randn((1, 1, 256)) + + +@pytest.fixture +def forward_input() -> dict[str, torch.Tensor]: + return {"x": torch.randn((1, 2, 1)), "mask": torch.ones((1, 2))} + + +def test_convolution_module_forward_return_shape( + params, + forward_input, +): + 
convolution = ConvolutionModule( + params["n_embd"], + params["pointwise_conv_kernel_size"], + params["depthwise_conv_kernel_size"], + params["dropout"], + ) + + out = convolution(forward_input["x"]) + + assert out.shape == (1, 2, 1) + + +def test_convolution_module_forward_raise( + params, + invalid_forward_input, +): + convolution = ConvolutionModule( + params["n_embd"], + params["pointwise_conv_kernel_size"], + params["depthwise_conv_kernel_size"], + params["dropout"], + ) + + with pytest.raises(ValueError, match="The time dimension of the input to the convolution module cannot be 1!"): + convolution(invalid_forward_input) + + +def test_conformer_forward(params, forward_input): + conformer = ConformerBlock( + params["n_embd"], + params["n_heads"], + params["attention_config"], + params["pointwise_conv_kernel_size"], + params["depthwise_conv_kernel_size"], + params["dropout"], + params["dropout"], + params["dropout"], + ) + + conformer(forward_input["x"], forward_input["mask"]) + + +def test_audio_transformer__get_attn_key_mask(audio_transformer_model): + lengths = torch.tensor([3]) + + CORRECT_MASK = torch.Tensor( + [ + [ + [ + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [1, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + ] + ] + ] + ) + + CREATED_MASK = audio_transformer_model._get_attn_key_mask(lengths) + assert torch.equal(CORRECT_MASK, CREATED_MASK) From 79c6c6b5f893118f2810487fd958d54270aec19f Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Mon, 7 Oct 2024 15:45:59 +0200 Subject: [PATCH 154/161] docs: misc. docstrings and type hints for VideoTransform, web dataset builder etc. --- src/modalities/config/instantiation_models.py | 2 +- src/modalities/dataloader/dataset.py | 179 +++++++++++++----- .../models/coca/multi_modal_decoder.py | 4 + .../vision_transformer_model.py | 66 ++++++- 4 files changed, 196 insertions(+), 55 deletions(-) diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 4ac4a4bf8..6a8b7c411 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -168,7 +168,7 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel wrapped_model: PydanticPytorchModuleType optimizer: PydanticOptimizerIFType scheduler: PydanticLRSchedulerIFType - loss_fn: PydanticLossIFType | list[PydanticLossIFType] + loss_fn: PydanticLossIFType train_dataset: PydanticDatasetIFType train_dataloader: PydanticDataLoaderIFType eval_dataloaders: list[PydanticDataLoaderIFType] diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 08387de57..c4b1a6f9c 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -5,11 +5,12 @@ import re from enum import Enum from pathlib import Path -from typing import Annotated, Optional +from typing import Annotated, Any, Optional import decord import jq import numpy as np +import PIL import torch import torchaudio import webdataset as wds @@ -438,11 +439,10 @@ class ImageTransformConfig(TransformConfig): separate: bool = False -# @register_component("transform", "image_transform", ImageTransformConfig) class ImageTransform(Transform): """ImageTransform class.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: """ Initializes a Transform object for image transformations. 
@@ -486,7 +486,7 @@ def __init__(self, **kwargs): self._timm_image_transform = create_transform(**kwargs) - def __call__(self, image): + def __call__(self, image: PIL.Image.Image) -> torch.Tensor: return self._timm_image_transform(image) @@ -498,7 +498,6 @@ class TextTransformConfig(TransformConfig): return_attention_mask: bool = True -# @register_component("transform", "text_transform", TextTransformConfig) class TextTransform(Transform): def __init__( self, @@ -507,7 +506,7 @@ def __init__( padding: str = "max_length", truncation: bool = True, return_attention_mask: bool = True, - ): + ) -> None: """ Args: tokenizer (TokenizerWrapper): text tokenizer @@ -525,7 +524,7 @@ def __init__( self.truncation = truncation self.return_attention_mask = return_attention_mask - def __call__(self, text): + def __call__(self, text: str) -> BatchEncoding: batch_encoding: BatchEncoding = self.tokenizer.tokenizer( text, max_length=self.max_length, @@ -574,7 +573,7 @@ def __init__( n_mels: int = 128, freq_domain_mask_length: int = 30, time_domain_mask_length: int = 100, - ): + ) -> None: """ Initializes the AudioTransform class. @@ -626,13 +625,44 @@ def __call__(self, raw_audio: tuple[torch.Tensor, int]) -> tuple[torch.Tensor, i return log_mel_spec, feats_len -class RandomTemporalCrop: - def __init__(self, num_frames): +class TemporalCrop: + """ + This module crops a video along the temporal dimension + """ + + def __init__( + self, + num_frames: int, + is_training: bool = False, + ) -> None: + """ + Initializes the TemporalCrop class + + Args: + num_frames (int): The length of the clip to be cropped + is_training (bool, optional): Whether the module is in training mode. Defaults to False. + + Returns: + None + """ self.num_frames = num_frames + self.is_training = is_training + + def __call__(self, video: torch.Tensor) -> torch.Tensor: + """ + Crops the video to a length of `num_frames`. If in training mode, the start of the crop is chosen randomly. + + Args: + video (torch.Tensor): the video to be cropped with dimensions T x H x W x C - def __call__(self, video): + Returns: + cropped video (torch.Tensor): the cropped video with dimensions num_frames x C x H x W + """ total_frames = len(video) - start = random.randint(0, total_frames - self.num_frames) + if self.is_training: + start = random.randint(0, total_frames - self.num_frames) + else: + start = 0 return video[start : start + self.num_frames].permute(0, 3, 1, 2) # F C H W @@ -647,6 +677,10 @@ class VideoTransformConfig(TransformConfig): class VideoTransform(Transform): + """ + A video transformation module that performs spatial and temporal transformations. 
+ """ + def __init__( self, input_size: int | tuple[int, int] | tuple[int, int, int] = 224, @@ -656,24 +690,65 @@ def __init__( color_jitter: list[float] = [0.0, 0.0, 0.0, 0.0], mean: list[float] = IMAGENET_DEFAULT_MEAN, std: list[float] = IMAGENET_DEFAULT_STD, - ): - self.spatial_transform = transforms.Compose( - [ - transforms.RandomResizedCrop(input_size, antialias=True), - transforms.RandomHorizontalFlip(p=hflip), - transforms.ColorJitter( - brightness=color_jitter[0], - contrast=color_jitter[1], - saturation=color_jitter[2], - hue=color_jitter[3], - ), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), - ] - ) - self.temporal_transform = RandomTemporalCrop(num_frames=num_frames) + ) -> None: + """ + Initializes the VideoTransform class - def __call__(self, video): + Args: + input_size (int | tuple[int, int] | tuple[int, int, int] ): target spatial size of video frames. + is_training (bool, optional): Whether the module is in training mode. Defaults to False. + When not in training mode, resize and center crop is used instead of RandomResizedCrop, + no horizontal flip nor color jitter is performed, and the temporal crop is deterministic. + num_frames (int): target number of frames in the transformed video. Defaults to 16. + hflip (float): probability of performing horizontal flip on the frames. Defaults to 0.0. + color_jitter (list[float]): Random color jitter component factors + (brightness, contrast, saturation, hue). + Defaults to 0.0 for all components. + mean (list[float]): Image normalization mean. Defaults to IMAGENET defaults. + std (list[float]): Image normalization standard deviation. Defaults to IMAGENET defaults. + + + Returns: + None + """ + if is_training: + self.spatial_transform = transforms.Compose( + [ + transforms.RandomResizedCrop(input_size, antialias=True), + transforms.RandomHorizontalFlip(p=hflip), + transforms.ColorJitter( + brightness=color_jitter[0], + contrast=color_jitter[1], + saturation=color_jitter[2], + hue=color_jitter[3], + ), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + else: + self.spatial_transform = transforms.Compose( + [ + transforms.Resize(input_size, antialias=True), + transforms.CenterCrop(input_size), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + self.temporal_transform = TemporalCrop(num_frames=num_frames, is_training=is_training) + + def __call__(self, video: tuple[torch.Tensor, torch.Tensor | None, int]) -> torch.Tensor: + """ + Performs spatial and temporal transformations on the input video + + Args: + video (tuple[torch.Tensor, torch.Tensor, ]): the first element is the video + to be transformed T x H' x W' x C. + The second and third elements are ignored (optional audio, audio sample rate). + + Returns: + transformed video (torch.Tensor): with dimensions num_frames x C x H x W + """ video = video[0] video = self.temporal_transform(video) return self.spatial_transform(video) @@ -691,7 +766,7 @@ def decord_video(key: str, data: bytes) -> None | tuple[torch.Tensor, Optional[t If an audio stream exists, it extracts the audio with a mean across channels (if there are multiple). It then uses Decord to decode uniformly sampled frames from the video. - Parameters: + Args: key (str): The key or identifier for the video data. data (bytes): The binary data of the video file. 
@@ -734,7 +809,7 @@ def torch_audio(key: str, data: bytes) -> None | tuple[torch.Tensor, int]: It first checks if the file extension is one of the supported formats. If there are multiple channels in the audio file, it averages them to produce a mono audio tensor. - Parameters: + Args: key (str): The key or identifier for the audio data. data (bytes): The binary data of the audio file. @@ -754,7 +829,22 @@ def torch_audio(key: str, data: bytes) -> None | tuple[torch.Tensor, int]: return (audio, sample_rate) -def fixed_ratio_round_robin(*sources, samples_per_batch): +def fixed_ratio_round_robin(*sources, samples_per_batch: list[int]): + """ + Iterator over a list of iterators. + Samples from each source iterator are selected in a round-robin fashion, with a fixed number + of samples from each iterator for a given batch, as defined by `samples_per_batch` + + + Args: + sources (list[iterator]): An arbitrary number of source iterators + samples_per_batch (list[int]): Number of samples from each source iterator + which should be present in one batch + + Yields: + sample: a sample from one of the iterators + """ + sources = list(sources) remaining_samples_in_batch = samples_per_batch.copy() i = 0 @@ -783,7 +873,7 @@ def __init__( datasets: list[wds.WebDataset], mixing_ratios: list[float], batch_size: int, - ): + ) -> None: """An iterator for a list of datasets. Samples are yielded in a round robin manner with a fixed ratio of samples per dataset. There is no random sampling, so the number of @@ -793,6 +883,9 @@ def __init__( datasets (list[WebDataset]): a list of WebDatasets to be iterated over mixing_ratios (list[float]): the ratio of samples from each dataset that should be present in a batch batch_size (int): size of batch containing samples from all datasets in the specified ratio + + Returns: + None """ self.datasets = datasets self.samples_per_batch = [int(batch_size * ratio) for ratio in mixing_ratios] @@ -816,7 +909,6 @@ class MultimodalWebDatasetBuilderConfig(BaseModel): num_samples: Annotated[int, Field(ge=1)] -# @register_component("dataset", "web_dataset_builder", MultimodalWebDatasetBuilderConfig) class MultimodalWebDatasetBuilder: def __init__( self, @@ -825,7 +917,7 @@ def __init__( modality_transforms: dict[str, Transform], is_audio_video: bool, num_samples: int, - ): + ) -> None: """A multimodal dataset instance for the WebDataset. 
Args: @@ -877,7 +969,7 @@ def __init__( def prepare( self, shardshuffle: int = 100, resample: bool = True, repeat: bool = False, shuffle_buffer: int = 10_000 - ): + ) -> None: """ Prepares a WebDataset object as a pipeline that includes shuffling, decoding data, and transformations @@ -924,7 +1016,7 @@ def prepare( self.web_dataset.append(wds.filters.map(self._select_keys)) - def _transform_text(self, sample): + def _transform_text(self, sample: dict[str, Any]) -> dict[str, Any]: source_key, target_key = self.modality_key_mapping[ModalityEnum.TEXT] transform: TextTransform = self.modality_transforms[ModalityEnum.TEXT] batch_encoding: BatchEncoding = transform(sample[source_key]) @@ -933,14 +1025,14 @@ def _transform_text(self, sample): sample["attention_mask"] = batch_encoding.attention_mask return sample - def _transform_image(self, sample): + def _transform_image(self, sample: dict[str, Any]) -> dict[str, Any]: source_key, target_key = self.modality_key_mapping[ModalityEnum.IMAGE] transform: TextTransform = self.modality_transforms[ModalityEnum.IMAGE] sample[target_key] = transform(sample[source_key]) del sample[source_key] return sample - def _transform_video(self, sample): + def _transform_video(self, sample: dict[str, Any]) -> dict[str, Any]: source_key, target_key = self.modality_key_mapping[ModalityEnum.VIDEO] transform: VideoTransform = self.modality_transforms[ModalityEnum.VIDEO] sample[target_key] = transform(sample[source_key]) @@ -953,7 +1045,7 @@ def _transform_video(self, sample): del sample[source_key] return sample - def _transform_audio(self, sample: dict): + def _transform_audio(self, sample: dict[str, Any]) -> dict[str, Any]: # Apply audio transforms to the input sample. source_key, target_key = self.modality_key_mapping[ModalityEnum.AUDIO] transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] @@ -961,10 +1053,10 @@ def _transform_audio(self, sample: dict): del sample[source_key] return sample - def _flatten_sample(self, sample): + def _flatten_sample(self, sample: dict[str, Any]) -> dict[str, Any]: return flatten_dict(sample) - def _select_keys(self, sample): + def _select_keys(self, sample: dict[str, Any]) -> dict[str, Any]: # only select the required keys from the sample # i.e. the keys specified in modality_key_mapping # and the additional_extracted_keys @@ -1002,7 +1094,6 @@ class MultimodalWebDatasetConfig(BaseModel): shuffle_buffer: Optional[int] = 10_000 -# @register_component("dataset", "web_dataset", MultimodalWebDatasetConfig) class MultimodalWebDataset(wds.DataPipeline, wds.compat.FluidInterface): def __init__( self, @@ -1013,7 +1104,7 @@ def __init__( repeat: bool = False, resample: bool = True, shuffle_buffer: Optional[int] = 10_000, - ): + ) -> None: """WebDataset for loading and combining multimodal datasets. Args: diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index 981ed5398..aca2cea21 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -39,6 +39,8 @@ def __init__( dropout (float): The dropout rate. ffn_hidden (int): The number of hidden units in the feed-forward network. with_context (bool): Flag indicating whether to include context in the decoder. + is_audio_video (bool): Flag indicating whether an additional cross attention block is required for + data that consists of both audio and video from the same source. attention_type (AttentionType): The type of attention mechanism to use. 
attention_config (AttentionConfig, optional): The configuration for the attention mechanism. Defaults to None. @@ -143,6 +145,8 @@ def __init__( n_head (int): The number of attention heads. n_embd (int): The dimension of the embeddings. ffn_hidden (int): The size of the feed-forward network hidden layer. + is_audio_video (bool): Flag indicating whether an additional cross attention block is required for + data that consists of both audio and video from the same source. dropout (float): The dropout rate. bias (bool): Flag indicating whether to include bias terms. activation (ActivationType): The activation function to use. diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py index 407c60171..dd48d72af 100644 --- a/src/modalities/models/vision_transformer/vision_transformer_model.py +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -122,6 +122,19 @@ def __init__( patch_size: int = 16, patch_stride: int = 16, ) -> None: + """ + Initializes a VideoPatchEmbedding object. + + + Args: + n_img_channels (int): Number of image channels. Defaults to 3. + n_embd (int): Number of embedding dimensions. Defaults to 768. + patch_size (int): Patch size for convolutional layer. Defaults to 16. + patch_stride (int): Patch stride for convolutional layer. Defaults to 16. + + Returns: + None + """ super().__init__() self.input_rearrange = Rearrange("b T c h w -> b c T h w") self.conv = nn.Conv3d( @@ -135,6 +148,16 @@ def __init__( self.rearrange = Rearrange("b c T h w -> b T (h w) c") # TODO: this might change when implementing dataloader def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the VideoPatchEmbedding. + + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ x = self.input_rearrange(x) x = self.conv(x) x = self.rearrange(x) @@ -194,12 +217,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -# TODO: extend to all modalities based on the original paper (https://arxiv.org/pdf/2103.03206)! -# TODO: extend this to work with video and images! class PerceiverTransformerBlock(nn.Module): """Perceiver Resampler - This is a transformer based architecture that performs cross and self attention to compress and embed video inputs. + This is a transformer based architecture that performs cross and self attention to compress and embed video + or other high-dimensional inputs. paper: 'Flamingo: a Visual Language Model for Few-Shot Learning' Link: https://github.com/mlfoundations/open_flamingo """ @@ -213,6 +235,21 @@ def __init__( dropout: float = 0.0, attention_config: AttentionConfig = None, ) -> None: + """ + Initializes a PerceiverTransformerBlock object. + + Args: + n_embd (int, optional): The dimensionality of the embedding layer. Defaults to 768. + n_head (int, optional): The number of attention heads. Defaults to 8. + ffn_hidden (int, optional): The number of hidden units in the feed-forward network. Defaults to 3072. + bias (bool, optional): Flag indicating whether to include bias terms. Defaults to True. + dropout (float, optional): The dropout rate. Defaults to 0.0. + attention_config (AttentionConfig, optional): The configuration for the attention mechanism. + Defaults to None. 
+ + Returns: + None + """ super().__init__() self.norm_latents = nn.LayerNorm(n_embd) self.norm = nn.LayerNorm(n_embd) @@ -225,6 +262,16 @@ def __init__( self.mlp = MLP(in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout) def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the PerceiverTransformerBlock module. + + Args: + x (torch.Tensor): Input tensor. + latents (torch.Tensor): input latent array tensor + + Returns: + torch.Tensor: Output tensor. + """ latents = self.norm_latents(latents) x = self.norm(x) context = torch.cat((x, latents), dim=-2) # video features and the latent together @@ -243,7 +290,7 @@ class VisionTransformer(nn.Module): Paper: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` Link: https://arxiv.org/abs/2010.11929 - This architecture is extended to encode videos using a perceiver resampler transformer model + This architecture is extended to encode videos using a Perceiver transformer model """ def __init__( @@ -263,7 +310,7 @@ def __init__( n_img_channels: int = 3, add_cls_token: bool = True, bias: bool = True, - num_video_frames: int = 1, # when dealing with video this is bigger than 1 + num_video_frames: int = 1, # 1: Image, >1: Video n_latents: int = 64, ) -> None: """ @@ -285,6 +332,8 @@ def __init__( n_img_channels (int, optional): The number of image channels. Defaults to 3. add_cls_token (bool, optional): Flag indicating whether to add a classification token. Defaults to True. bias (bool, optional): Flag indicating whether to include bias terms. Defaults to True. + num_video_frames (int): Number of frames. Defaults to 1. + n_latents (int, optional): Size of latent array. Defaults to 64. Returns: None @@ -315,9 +364,7 @@ def __init__( else: self.embedding_fn = ImagePatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride, add_cls_token) - self.positional_embedding_fn = nn.Embedding( - num_embeddings=self.block_size, embedding_dim=n_embd - ) # [S D] #TODO: this needs to be adjusted for video with cls_token + self.positional_embedding_fn = nn.Embedding(num_embeddings=self.block_size, embedding_dim=n_embd) # [S D] block_classes = {"Video": PerceiverTransformerBlock, "Image": VisionTransformerBlock} self.blocks = nn.ModuleList( @@ -367,7 +414,6 @@ def forward_videos(self, x: torch.Tensor) -> torch.Tensor: """ x = self.embedding_fn(x) # [b T S D] b, T = x.shape[:2] - # TODO: check this! x = self.dropout(x + self.positional_embedding_fn.weight) x = self.dropout(x + self.time_embd.repeat(b, 1, 1, 1)) x = self.rearrange(x) # [b T*S D] @@ -376,7 +422,7 @@ def forward_videos(self, x: torch.Tensor) -> torch.Tensor: latents = block(x, latents) return latents - def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: # TODO video adapt + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the VisionTransformer module. 
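The video path documented in the commit above compresses the flattened (frames x patches) token sequence into a small, fixed set of learned latents that cross-attend to the tokens, so the cost of the downstream multimodal decoder does not grow with video length. Below is a minimal, self-contained sketch of that Perceiver-style pooling pattern only; it is not the repository's `PerceiverTransformerBlock`, and the module and parameter names (`LatentPoolingBlock`, `n_latents`, `n_head`) are illustrative.

```python
import torch
import torch.nn as nn


class LatentPoolingBlock(nn.Module):
    """Compress a long token sequence into a fixed number of latents via cross-attention."""

    def __init__(self, n_embd: int = 768, n_head: int = 8, n_latents: int = 64):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(n_latents, n_embd))
        self.norm_latents = nn.LayerNorm(n_embd)
        self.norm_tokens = nn.LayerNorm(n_embd)
        self.norm_mlp = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd)
        )

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: [batch, num_frames * num_patches, n_embd], e.g. flattened video patch embeddings
        batch_size = tokens.shape[0]
        latents = self.latents.expand(batch_size, -1, -1)
        queries = self.norm_latents(latents)
        # keys/values are the normalized tokens concatenated with the latents themselves,
        # following the Flamingo-style resampler referenced in the docstrings above
        context = torch.cat((self.norm_tokens(tokens), queries), dim=1)
        latents = latents + self.attn(queries, context, context, need_weights=False)[0]
        latents = latents + self.mlp(self.norm_mlp(latents))
        return latents  # [batch, n_latents, n_embd]


if __name__ == "__main__":
    video_tokens = torch.randn(2, 16 * 196, 768)  # e.g. 16 frames of 14x14 patches
    pooled = LatentPoolingBlock()(video_tokens)
    print(pooled.shape)  # torch.Size([2, 64, 768])
```

Because the number of latents stays fixed (64 in this sketch), arbitrarily long frame sequences can be fed to the decoder at constant cost.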
From 23267e8d84e75b714f6fc106c0a8a86d1b6f7320 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Tue, 8 Oct 2024 10:08:51 +0200 Subject: [PATCH 155/161] fix: update simple progress subscriber config --- src/modalities/config/config.py | 8 +++---- .../subscriber_impl/progress_subscriber.py | 13 +++++++---- .../subscriber_impl/subscriber_factory.py | 22 ++++++------------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 341358ba8..0a6bf9eea 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -353,11 +353,11 @@ class DummyProgressSubscriberConfig(BaseModel): class SimpleProgressSubscriberConfig(BaseModel): - train_dataloader: PydanticDataLoaderIFType eval_dataloaders: Optional[list[PydanticDataLoaderIFType]] = Field(default_factory=list) - world_size: int - global_num_seen_samples: int - local_rank: int + train_dataloader_tag: str + num_seen_steps: Annotated[int, Field(strict=True, ge=0)] + num_target_steps: Annotated[int, Field(strict=True, gt=0)] + global_rank: Annotated[int, Field(strict=True, ge=0)] class RichProgressSubscriberConfig(BaseModel): diff --git a/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py b/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py index 9a991fe0c..f54a3f782 100644 --- a/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py @@ -14,11 +14,13 @@ class DummyProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): def consume_message(self, message: Message[ProgressUpdate]): pass - def consume_dict(self, mesasge_dict: dict[str, Any]): + def consume_dict(self, message_dict: dict[str, Any]): pass class SimpleProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): + """A subscriber object for the RichProgress observable.""" + def __init__( self, train_split_num_samples: dict[str, int], @@ -40,19 +42,22 @@ def consume_message(self, message: Message[ProgressUpdate]): prefix = "" if message.payload.experiment_status == ExperimentStatus.TRAIN: prefix = "Train" - completed_samples = batch_progress.global_train_sample_id + 1 + completed_samples = batch_progress.num_steps_done total_samples = self.train_split_num_samples[batch_progress.dataloader_tag] elif message.payload.experiment_status == ExperimentStatus.EVALUATION: prefix = "Evaluation" - completed_samples = batch_progress.global_dataset_sample_id + 1 + completed_samples = batch_progress.num_steps_done total_samples = self.eval_splits_num_samples[batch_progress.dataloader_tag] print( f"{prefix}[{batch_progress.dataloader_tag}] " - f"[{completed_samples}/{total_samples} ({completed_samples/total_samples:.01f}%)]" + f"[{completed_samples}/{total_samples} ({completed_samples*100/total_samples:.01f}%)]" ) + def consume_dict(self, mesasge_dict: dict[str, Any]): + raise NotImplementedError + class RichProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): """A subscriber object for the RichProgress observable.""" diff --git a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py index cbab71898..7f9d3b576 100644 --- a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py +++ b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py @@ -41,24 +41,16 @@ def get_rich_progress_subscriber( @staticmethod def get_simple_progress_subscriber( - train_dataloader: LLMDataLoader, 
eval_dataloaders: list[LLMDataLoader], - world_size: int, - global_num_seen_samples: int, - local_rank: int, + train_dataloader_tag: str, + num_seen_steps: int, + num_target_steps: int, + global_rank: int, ) -> SimpleProgressSubscriber: - if local_rank == 0: - skip_num_local_train_batches = global_num_seen_samples // world_size // train_dataloader.batch_size - train_split_num_samples = { - train_dataloader.dataloader_tag: (len(train_dataloader) + skip_num_local_train_batches) - * world_size - * train_dataloader.batch_size - } + if global_rank == 0: + train_split_num_samples = {train_dataloader_tag: (num_target_steps)} - eval_splits_num_samples = { - dataloader.dataloader_tag: len(dataloader) * world_size * dataloader.batch_size - for dataloader in eval_dataloaders - } + eval_splits_num_samples = {dataloader.dataloader_tag: len(dataloader) for dataloader in eval_dataloaders} subscriber = SimpleProgressSubscriber(train_split_num_samples, eval_splits_num_samples) else: From 29352b7545011c4b446a00d497eaf41303c1ed6a Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Fri, 11 Oct 2024 08:17:25 +0200 Subject: [PATCH 156/161] refactor: rename norm layers for easier regex for weight initialization and decay groups --- .../audio_transformer/audio_transformer_model.py | 16 ++++++++-------- src/modalities/models/coca/coca_model.py | 6 +++--- .../parameter_name_filters.py | 15 +++++++++++++-- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index d40e63993..7e9612533 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -102,7 +102,7 @@ def __init__( ) """ super().__init__() - self.ln = nn.LayerNorm(n_embd) + self.ln_1 = nn.LayerNorm(n_embd) self.pointwise_conv_1 = nn.Conv1d( n_embd, 2 * n_embd, @@ -117,7 +117,7 @@ def __init__( groups=n_embd, padding="same", ) - self.bn = nn.BatchNorm1d( + self.batch_norm = nn.BatchNorm1d( n_embd, ) self.swish = nn.SiLU() @@ -146,10 +146,10 @@ def forward( if x.shape[1] == 1: raise ValueError("The time dimension of the input to the convolution module cannot be 1!") - x = self.ln(x) + x = self.ln_1(x) x = x.transpose(1, 2) x = self.glu(self.pointwise_conv_1(x)) - x = self.swish(self.bn(self.depthwise_conv(x))) + x = self.swish(self.batch_norm(self.depthwise_conv(x))) x = self.pointwise_conv_2(x) return self.dropout(x.transpose(1, 2)) @@ -186,7 +186,7 @@ class instance. """ super().__init__() - self.ln1 = nn.LayerNorm(n_embd) + self.ln_1 = nn.LayerNorm(n_embd) self.entry_ffmodule = MLP( in_features=n_embd, act_fn=nn.SiLU, @@ -206,7 +206,7 @@ class instance. depthwise_conv_kernel_size, convmodule_dropout, ) - self.ln2 = nn.LayerNorm( + self.ln_2 = nn.LayerNorm( n_embd, ) self.exit_ffmodule = MLP( @@ -235,11 +235,11 @@ def forward( Returns: torch.Tensor: Output tensor of shape (B, T, D). 
""" - x = self.ln1(x) + x = self.ln_1(x) x = x + 0.5 * self.entry_ffmodule(x) x = x + self.attn(self.ln_mhsa(x), mask=mask) x = x + self.convmodule(x) - x = self.ln2(x) + x = self.ln_2(x) x = x + 0.5 * self.exit_ffmodule(x) return self.exit_ln(x) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index 3d4aa7594..20bbaceba 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -81,7 +81,7 @@ class CoCaConfig(BaseModel): video_encoder_config (Optional[VisionTransformerConfig]): config for the video encoder. Defaults to None text_decoder_config (TextDecoderConfig): Configuration for the text decoder. n_pool_head (int): Number of attention heads for pooling. - n_queries (int): Number of vision queries. + n_queries (int): Number of queries for attention pooling. bias_attn_pool (bool): Flag indicating whether to use bias in attention pooling. epsilon_attn_pool (float): Epsilon value for attention pooling. seed (Optional[int]): The random seed. Defaults to None @@ -183,7 +183,7 @@ def __init__( video_encoder_config (Optional[VisionTransformerConfig]): config for the video encoder. Defaults to None text_decoder_config (TextDecoderConfig): Configuration for the text decoder. n_pool_head (int): Number of attention heads for pooling. - n_queries (int): Number of vision queries. + n_queries (int): Number of queries for attention pooling. bias_attn_pool (bool): Flag indicating whether to use bias in attention pooling. epsilon_attn_pool (float): Epsilon value for attention pooling. seed (Optional[int]): The random seed. Defaults to None @@ -201,7 +201,7 @@ def __init__( "linear": [r"attention", r"\.attn", r"\.cross_attn", r"\.post_subsampler", r"_ffmodule", r"mlp"], "conv": [r"embedding_fn\.conv", r"project", r"\.subsampler", r"pointwise_conv", r"depthwise_conv"], "embedding": [r"wte", r"wpe", r"positional_embedding", r"time_embd"], - "norm": [r"norm", r"\.ln_", r"\.ln", r"\.bn", r"exit_ln"], + "norm": [r"norm", r"norm_latents", r"\.ln_", r"\.batch_norm", r"exit_ln"], "parameter": [r"_queries", r"logit_scale", r"\.latents", r"cls_token"], } super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) diff --git a/src/modalities/nn/model_initialization/parameter_name_filters.py b/src/modalities/nn/model_initialization/parameter_name_filters.py index df569094e..c15ac99bd 100644 --- a/src/modalities/nn/model_initialization/parameter_name_filters.py +++ b/src/modalities/nn/model_initialization/parameter_name_filters.py @@ -67,8 +67,10 @@ class RegexFilter(BaseModel): }, SupportWeightInitModels.COCA: { # we reject all bias and weight parameters belonging to norms + # optional .weight so that we include nn.Parameters WeightInitTypes.PLAIN: RegexFilter( - weights=[r"^(?!.*norm)(?!.*ln).*\.weight$"], biases=[r"^(?!.*norm)(?!.*ln).*\.bias$"] + weights=[r"^(?!.*norm)(?!.*ln)(?!.*batch_norm).*(.weight)?$"], + biases=[r"^(?!.*norm)(?!.*ln)(?!.*batch_norm).*.bias$"], ), # scaled init for residual layers: # https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf (pp 4) @@ -77,6 +79,15 @@ class RegexFilter(BaseModel): r"transformer\.h\.\d+\.attn\.c_proj\.weight", ] ), - WeightInitTypes.SCALED_EMBED: RegexFilter(weights=[], biases=[]), + WeightInitTypes.SCALED_EMBED: RegexFilter( + weights=[ + # embedding weights + r"\.wte\.weight", + r"\.wpe\.weight", + r"positional_embeddings\.weight", + r"positional_embedding_fn\.weight", + r"time_embd$", + ] + ), }, } From 
2aa4fc0fc4c98cbc86469e30485d4d98bbb04867 Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Fri, 11 Oct 2024 08:20:18 +0200 Subject: [PATCH 157/161] test: fix weight initialization and weight decay tests for coca --- tests/test_initialization.py | 34 +++++++---- tests/test_optimizer_factory.py | 25 +++++++- .../coca_config_initialization.yaml | 58 ++++++++++++++++--- 3 files changed, 97 insertions(+), 20 deletions(-) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index a169ade87..8e38ba113 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -82,7 +82,7 @@ def _load_coca(initialization_type: str, std: float | str) -> FSDP: coca_wrapped_model = ModelFactory.get_fsdp_wrapped_model( coca_model, sync_module_states=True, - block_names=["TransformerBlock", "VisionTransformerBlock"], + block_names=["TransformerBlock", "VisionTransformerBlock", "ConformerBlock"], mixed_precision_settings=MixedPrecisionSettings.FP_16, sharding_strategy=ShardingStrategy.NO_SHARD, ) @@ -111,9 +111,23 @@ def _load_model(model_name: str, initialization: str = "plain", std: float | str "other": [], } MAPPING_COCA = { - "embedding": [], # TODO + "embedding": [ + r"wte\.weight$", + r"wpe\.weight$", + r"positional_embeddings\.weight$", + "positional_embedding_fn.weight$", + "time_embd$", + ], "weight-projection": [r"c_proj\.weight$"], # TODO - "weight-norm": [r"norm[12]\.weight$", r"ln_[1234f]\.weight$"], # TODO + "weight-norm": [ + r"norm[12]?\.weight$", + r"norm_latents\.weight$", + r"ln_[1234f]\.weight$", + r"ln_mhsa.weight", + r"batch_norm.*weight$", + r"exit_ln.weight$", + r"attention_norm.weight$", + ], "weight-normal": [r"\.weight$"], "other": [r"conv", r".*(? dict[str, Optional[torch.T GPT2_WEIGHT_NORMAL = GPT2_ALL - GPT2_WEIGHT_PROJECTION - GPT2_EMBEDDING - GPT2_WEIGHT_NORM - GPT2_BIAS # 40107264 COCA_NLAYERS = 6 + 6 # text + multimodal -COCA_ALL = 184502784 -COCA_EMBEDDING = 0 # TODO -COCA_WEIGHT_PROJECTION = 14745600 -COCA_WEIGHT_NORM = 34560 -COCA_BIAS = 191232 -COCA_OTHER = 198912 -COCA_WEIGHT_NORMAL = 169332480 +COCA_ALL = 277424641 +COCA_EMBEDDING = 40118016 +COCA_WEIGHT_PROJECTION = 21233664 +COCA_WEIGHT_NORM = 768 * 79 +COCA_BIAS = 292608 +COCA_OTHER = 657409 +COCA_WEIGHT_NORMAL = 215062272 NR_PARAMETERS = { "gpt2": { diff --git a/tests/test_optimizer_factory.py b/tests/test_optimizer_factory.py index 4f273ad01..46ac8ee72 100644 --- a/tests/test_optimizer_factory.py +++ b/tests/test_optimizer_factory.py @@ -55,14 +55,14 @@ def _load_gpt2() -> FSDP: def _load_coca() -> FSDP: - config_file_path = _ROOT_DIR / Path("tests/models/coca/coca_config.yaml") + config_file_path = _ROOT_DIR / Path("tests/models/coca/coca_config_img_aud_vid.yaml") config_dict = load_app_config_dict(config_file_path=config_file_path) coca_config = CoCaConfig.model_validate(config_dict) coca_model = CoCa(**dict(coca_config)) coca_wrapped_model = ModelFactory.get_fsdp_wrapped_model( coca_model, sync_module_states=True, - block_names=["TransformerBlock", "VisionTransformerBlock"], + block_names=["TransformerBlock", "VisionTransformerBlock", "ConformerBlock"], mixed_precision_settings=MixedPrecisionSettings.FP_16, sharding_strategy=ShardingStrategy.NO_SHARD, ) @@ -73,7 +73,16 @@ def _load_coca() -> FSDP: GPT2_LINEAR = 66130944 GPT2_EMBEDDING = 768 * (50304 + 2048) # n_embd * (vocab_size + sequence_length) GPT2_LAYERNORM = 768 * 50 # n_embd * num_layer_norms -COCA_ALL = 184502784 + +COCA_LINEAR = 227321088 +COCA_CONV = 9226752 +# (n_embd * vocab_size) + +# (n_embd * (text_block_size + 
img_block_size + vid_block_size + aud_block_size + num_frames) +COCA_EMBEDDING = (768 * 50304) + (768 * ((1024 + 1) + 196 + 196 + 500 + 16)) +COCA_NORM = 768 * 152 # n_embd * norm layers +# 3 * (n_queries + 1) n_embd + logit_scale + (1(cls_token) * n_embd) + (n_latents * n_embd) +COCA_PARAMETER = (3 * 257 * 768) + 1 + (768) + (64 * 768) +COCA_ALL = COCA_LINEAR + COCA_CONV + COCA_EMBEDDING + COCA_NORM + COCA_PARAMETER @pytest.mark.skipif( @@ -92,6 +101,16 @@ def _load_coca() -> FSDP: ("gpt2", 1e-1, ["non-existing-group"], False, None, None), ("coca", 0, [], True, 0, COCA_ALL), ("coca", 1e-1, [], True, COCA_ALL, 0), + ("coca", 1e-1, ["embedding"], True, COCA_ALL - COCA_EMBEDDING, COCA_EMBEDDING), + ("coca", 1e-1, ["embedding", "norm"], True, COCA_ALL - COCA_EMBEDDING - COCA_NORM, COCA_EMBEDDING + COCA_NORM), + ( + "coca", + 1e-1, + ["embedding", "norm", "parameter"], + True, + COCA_LINEAR + COCA_CONV, + COCA_EMBEDDING + COCA_NORM + COCA_PARAMETER, + ), ("coca", 1e-1, ["non-existing-group"], False, None, None), ], ) diff --git a/tests/test_yaml_configs/coca_config_initialization.yaml b/tests/test_yaml_configs/coca_config_initialization.yaml index bda3fb253..42547001c 100644 --- a/tests/test_yaml_configs/coca_config_initialization.yaml +++ b/tests/test_yaml_configs/coca_config_initialization.yaml @@ -19,13 +19,39 @@ model_raw: variant_key: coca config: prediction_key: logits - vision_embd_prediction_key: vision_embeddings + audio_embd_prediction_key: audio_embeddings + image_embd_prediction_key: image_embeddings + video_embd_prediction_key: video_embeddings text_embd_prediction_key: text_embeddings - vision_cls_prediction_key: vision_cls + image_cls_prediction_key: image_cls + image_text_cls_prediction_key: image_text_cls + audio_cls_prediction_key: audio_cls + audio_text_cls_prediction_key: audio_text_cls + video_cls_prediction_key: video_cls + video_text_cls_prediction_key: video_text_cls text_cls_prediction_key: text_cls - vision_encoder_config: + modality_keys: + - audio + - images + - video + is_audio_video: false + individual_datasets: true + logit_scale_prediction_key: logit_scale + audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 500 + n_mels: 128 + n_embd: 768 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 + image_encoder_config: sample_key: images - prediction_key: vision_embeddings + prediction_key: image_embeddings img_size: 224 n_classes: Null # Disable vision transformer head n_layer: 6 @@ -39,6 +65,24 @@ model_raw: n_img_channels: 3 add_cls_token: False bias: True + video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 224 # 288 in the original coca + n_classes: Null # Disable vision transformer head + n_layer: 6 + attention_config: + attention_engine_type: default_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 # 18 in the original coca + patch_stride: 16 # 18 in the original coca + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 text_decoder_config: sample_key: input_ids prediction_key: logits @@ -55,7 +99,7 @@ model_raw: bias: true activation: swiglu epsilon: 1e-5 - n_pool_head: 8 - n_vision_queries: 256 + n_pool_head: 12 + n_queries: 256 bias_attn_pool: False - epsilon_attn_pool: 1e-5 \ No newline at end of file + epsilon_attn_pool: 1e-5 From c84bbda6c2935187e2472f8c91cef93e22d44403 Mon Sep 17 00:00:00 2001 From: 
Santosh Thoduka Date: Fri, 11 Oct 2024 11:07:16 +0200 Subject: [PATCH 158/161] fix: update directory name for getting started example --- tests/tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests.py b/tests/tests.py index c4e0bc258..5a4adea17 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -106,9 +106,9 @@ def main(cpu: bool = False, single_gpu: bool = False, multi_gpu: bool = False, d # getting started example print("\n=== RUN GETTING STARTED EXAMPLE ===") - run_getting_started_example_directory = _ROOT_DIR / "examples" / "getting_started" + run_getting_started_example_directory = _ROOT_DIR / "tutorials" / "getting_started" run_getting_started_example_script = ( - _ROOT_DIR / "examples" / "getting_started" / "run_getting_started_example.sh" + _ROOT_DIR / "tutorials" / "getting_started" / "run_getting_started_example.sh" ) assert isfile( run_getting_started_example_script From e33076fb56d60d9bd2c94f6c45a8c6bc6ec9370e Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Fri, 11 Oct 2024 15:36:57 +0200 Subject: [PATCH 159/161] docs: update changelog with info about CoCa PR --- CHANGELOG_DEV.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md index 36a41c9c9..32c54ed41 100644 --- a/CHANGELOG_DEV.md +++ b/CHANGELOG_DEV.md @@ -85,3 +85,19 @@ This PR mainly addresses the warmstart of model training, e.g., after GPU crashe **Breaking Changes** * the settings part of the configs have been completely refactored + + +## PR #263 CoCa model updates + +This PR adds the following updates to the CoCa model: + + +**General Changes** +* add the AudioTransformer model +* update the VisionTransformer model to support video +* add the MultimodalWebDataset dataset for loading audio-text, image-text, and video-text data in the WebDataset format +* add a multi-loss function for specifying a weighted sum of different losses (see the sketch below) +* update the CoCa model to include encoders for video and audio + +**Breaking Changes** +* the LLMDataLoader now wraps a PyTorch DataLoader object instead of inheriting from it.
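The weighted-sum multi-loss mentioned in the changelog combines several loss terms into a single scalar while keeping a running, unweighted total per term for logging; the tests touched later in this series reset such a cumulated_individual_losses buffer. The following is only a minimal sketch of that idea: the class name, constructor signature, and the batch argument are assumptions for illustration and do not reproduce the exact modalities API.

from typing import Callable

import torch
from torch import nn


class WeightedSumLoss(nn.Module):
    """Minimal sketch of a weighted-sum multi-loss; names and signature are illustrative."""

    def __init__(self, groups: list[tuple[Callable[..., torch.Tensor], float]]):
        super().__init__()
        self.groups = groups  # pairs of (loss_fn, weight)
        # running per-term totals so individual losses can be logged and reset
        self.cumulated_individual_losses = torch.zeros(len(groups))

    def reset_cumulated_individual_losses(self) -> None:
        self.cumulated_individual_losses = torch.zeros(len(self.groups))

    def forward(self, batch) -> torch.Tensor:
        weighted_terms = []
        for i, (loss_fn, weight) in enumerate(self.groups):
            term = loss_fn(batch)
            # accumulate the detached value for logging only; gradients flow through the weighted sum
            self.cumulated_individual_losses[i] += term.detach().cpu()
            weighted_terms.append(weight * term)
        return torch.stack(weighted_terms).sum()

Only the weighted sum participates in backpropagation; the per-term bookkeeping mirrors how the tests in the next patches check that the accumulated values can be reset to zeros.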
From 71a5bd193992456a7c5adfb0edff9afb871d597c Mon Sep 17 00:00:00 2001 From: Santosh Thoduka Date: Sat, 12 Oct 2024 17:14:01 +0200 Subject: [PATCH 160/161] chore: fix linting --- src/modalities/evaluator.py | 2 +- tests/test_loss_functions.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 6e51463bc..8556ce647 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List +from typing import Callable, List import torch import torch.distributed as dist diff --git a/tests/test_loss_functions.py b/tests/test_loss_functions.py index 86b6c0d7c..eccd05765 100644 --- a/tests/test_loss_functions.py +++ b/tests/test_loss_functions.py @@ -79,9 +79,7 @@ def setup_distributed(monkeypatch): def test_clip_loss(clip_loss_object, clip_loss_forward_batch, setup_distributed): - loss_fn = clip_loss_object - forward_batch = clip_loss_forward_batch loss_fn(clip_loss_forward_batch) @@ -105,18 +103,18 @@ def test_multiple_functions_loss_initialized_with_single_loss( def test_multiple_functions_loss_reset_cumulated_individual_losses( multiple_functions_loss_object_with_two_losses, ): - loss = multiple_functions_loss_object_with_two_losses num_losses = len(loss.groups) loss.cumulated_individual_losses = torch.randn(num_losses) loss.reset_cumulated_individual_losses() - assert (loss.cumulated_individual_losses, torch.zeros(num_losses)) + assert torch.equal( + loss.cumulated_individual_losses, torch.zeros(num_losses, device=loss.cumulated_individual_losses.device) + ) @pytest.fixture def multiple_functions_loss_forward_batch() -> InferenceResultBatch: - targets = {"target_ids": torch.Tensor([[1, 2, 1], [1, 1, 2]])} predictions = { "image_cls": torch.Tensor([[1, 2, 3], [4, 5, 6]]).to("cuda"), From de9baab0e574b7303870a84114baaabbbe0e9e89 Mon Sep 17 00:00:00 2001 From: Thomas Holz Date: Wed, 23 Oct 2024 10:21:13 +0000 Subject: [PATCH 161/161] docs: fix minor docstring inconsistencies --- src/modalities/dataloader/dataset.py | 5 +- .../audio_transformer_model.py | 58 +++---------------- 2 files changed, 10 insertions(+), 53 deletions(-) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index c4b1a6f9c..f2ef94d93 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -548,7 +548,7 @@ class AudioTransformConfig(TransformConfig): n_mels (int): Number of mel-frequency bands. Defaults to 128. freq_domain_mask_length (int): Length of frequency masking during training. Defaults to 30. time_domain_mask_length (int): Length of time masking during training. Defaults to 100. - block_size_audio_encoder (int): The target block size for audio encoding. + block_size_audio_encoder (int): Maximum allowed input length to the audio encoder. """ is_training: bool = False @@ -578,7 +578,7 @@ def __init__( Initializes the AudioTransform class. Args: - block_size_audio_encoder (int): The target block size for audio encoding. + block_size_audio_encoder (int): Maximum allowed input length to the audio encoder. is_training (bool, optional): Whether the module is in training mode. Defaults to False. n_mels (int, optional): Number of mel-frequency bands. Defaults to 128. freq_domain_mask_length (int, optional): Length of frequency masking. Defaults to 30. 
@@ -1046,7 +1046,6 @@ def _transform_video(self, sample: dict[str, Any]) -> dict[str, Any]: return sample def _transform_audio(self, sample: dict[str, Any]) -> dict[str, Any]: - # Apply audio transforms to the input sample. source_key, target_key = self.modality_key_mapping[ModalityEnum.AUDIO] transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] sample[target_key], sample["audio_len"] = transform(sample[source_key]) diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py index 7e9612533..66184f9f8 100644 --- a/src/modalities/models/audio_transformer/audio_transformer_model.py +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -14,45 +14,27 @@ class AudioTransformerConfig(BaseModel): This configuration class defines all necessary parameters to instantiate and configure an `AudioTransformer` model. - Args: + Attributes: sample_key (str): The key in the input dictionary that contains the audio samples. prediction_key (str): The key under which the model's output will be stored in the output dictionary. block_size (int): The size of each block for positional embeddings. Must be a positive integer. - n_mels (int): The number of mel-frequency bands used for input audio feature extraction. + n_mels (int): The number of mel-frequency bands used for input audio feature extraction. Must be a positive integer. n_embd (int): The embedding dimension used throughout the model. Must be a positive integer. n_heads (int): The number of attention heads in the conformer blocks. Must be a positive integer. - n_conformer_blocks (int): The number of conformer blocks to include in the transformer model. + n_conformer_blocks (int): The number of conformer blocks to include in the transformer model. Must be a positive integer. attention_config (AttentionConfig): Configuration object for attention mechanisms. - pointwise_conv_kernel_size (int): Kernel size for the pointwise convolutional layers in conformer blocks. + pointwise_conv_kernel_size (int): Kernel size for the pointwise convolutional layers in conformer blocks. Must be a positive integer. - depthwise_conv_kernel_size (int): Kernel size for the depthwise convolutional layers in conformer blocks. + depthwise_conv_kernel_size (int): Kernel size for the depthwise convolutional layers in conformer blocks. Must be a positive integer. - ffmodule_dropout (float, optional): Dropout rate for feed-forward modules in conformer blocks. + ffmodule_dropout (float, optional): Dropout rate for feed-forward modules in conformer blocks. Must be a float less than 1.0. Default is 0.1. - attn_dropout (float, optional): Dropout rate for attention mechanisms. Must be a float less than 1.0. + attn_dropout (float, optional): Dropout rate for attention mechanisms. Must be a float less than 1.0. Default is 0.1. - convmodule_dropout (float, optional): Dropout rate for depthwise convolutional layers in conformer blocks. + convmodule_dropout (float, optional): Dropout rate for depthwise convolutional layers in conformer blocks. Must be a float less than 1.0. Default is 0.1. - - Returns: - AudioTransformerConfig: A configuration object that can be used to instantiate an `AudioTransformer` model with\ - the specified parameters. 
- - Examples: - >>> audio_encoder_config = AudioTransformerConfig( - sample_key="audio", - prediction_key="audio_embeddings", - block_size=2_000, - n_mels=128, - n_embd=768, - n_heads=8, - n_conformer_blocks=2, - attention_config=AttentionConfig(attention_engine_type="default_attention"), - pointwise_conv_kernel_size=1, - depthwise_conv_kernel_size=31 - ) """ sample_key: str @@ -92,14 +74,6 @@ def __init__( pointwise_conv_kernel_size (int): The kernel size for both the first and second pointwise convolutions. depthwise_conv_kernel_size (int): The kernel size for the depthwise convolution. dropout (float): Dropout rate applied after each layer. Must be a float between 0 and 1. - - Examples: - >>> module = ConvolutionModule( - n_embd=768, - pointwise_conv_kernel_size=1, - depthwise_conv_kernel_size=31, - dropout=0.1 - ) """ super().__init__() self.ln_1 = nn.LayerNorm(n_embd) @@ -285,22 +259,6 @@ def __init__( attn_dropout (float): Dropout rate for attention mechanisms. Default is 0.1. convmodule_dropout (float): Dropout rate for depthwise convolutional layers in conformer blocks. Default is 0.1. - - Examples: - >>> audio_encoder_config = { - "sample_key": "audio", - "prediction_key": "audio_embeddings", - "block_size": 2000, - "n_mels": 128, - "n_embd": 768, - "n_heads": 8, - "n_conformer_blocks": 2, - "attention_config": { - "attention_engine_type": "default_attention" - }, - "pointwise_conv_kernel_size": 1, - "depthwise_conv_kernel_size": 31 - } """ super().__init__() self.sample_key = sample_key
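For orientation, the AudioTransformerConfig attributes documented above can be instantiated directly. The snippet below is a sketch only: the import path is inferred from the file path in this diff, the nested attention config is passed as a plain dict for pydantic to coerce, and the concrete values are illustrative (they mirror the audio encoder settings in the test YAML earlier in this series) rather than a recommended setup.

from modalities.models.audio_transformer.audio_transformer_model import AudioTransformerConfig

# Illustrative values: block_size sets the length of the positional embeddings and
# n_mels must match the mel-spectrogram features produced by the audio transform.
audio_encoder_config = AudioTransformerConfig(
    sample_key="audio",
    prediction_key="audio_embeddings",
    block_size=500,
    n_mels=128,
    n_embd=768,
    n_heads=4,
    n_conformer_blocks=3,
    attention_config={"attention_engine_type": "default_attention"},
    pointwise_conv_kernel_size=1,
    depthwise_conv_kernel_size=31,
)

The dropout fields (ffmodule_dropout, attn_dropout, convmodule_dropout) fall back to their documented default of 0.1 when omitted.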