27 commits
bbee65c
feat: added enhanced logging
le1nux Jan 2, 2025
90f5d5a
refactor: scaled the index creation and tokenization by supporting mu…
le1nux Jan 2, 2025
e876c97
refactor: finalized processing strategies
le1nux Jan 5, 2025
3fe7690
feat: introduced process controller
le1nux Jan 5, 2025
3d79cc4
feat: added custom process implementation
le1nux Jan 5, 2025
8f41a74
feat: added interface for the processing strategies
le1nux Jan 5, 2025
a05bd9f
refactor: bug fixing in TokenizationInstantiationModel
le1nux Jan 5, 2025
ac8f74e
refactor: API uses now process controller
le1nux Jan 5, 2025
95f30d4
feat: added multiple queue destinations for strategy
le1nux Jan 8, 2025
d5d59a7
feat: temp env var decorator
le1nux Jan 8, 2025
d242fd3
refactor: moved WorkerTypes
le1nux Jan 9, 2025
3222469
refactor: refactored tokenization strategies
le1nux Jan 9, 2025
be14b0d
feat: added ProcessorException and ProcessingStrategyDoneException
le1nux Jan 9, 2025
6f54287
refactor: generalized multiprocesssing system
le1nux Jan 10, 2025
ac22596
refactor: adapted the create_packed_data endpoint to fit the previous…
le1nux Jan 10, 2025
f2422a1
refactor: fixed issues in joining processors
le1nux Jan 10, 2025
ad55dfb
refactor: finalized processors and strategy setup
le1nux Jan 10, 2025
4da0959
feat: enhanced tokenization instantiation with pydantic model
le1nux Jan 12, 2025
004daeb
feat: added PreprocessingNumberConversion
le1nux Jan 12, 2025
8be1ad5
feat: integrated global index creation
le1nux Jan 13, 2025
438858f
chore: minor renaming fix
le1nux Jan 13, 2025
b5273b2
chore: fixed imports
le1nux Jan 13, 2025
ea694a1
chore: adapted configs after renamings
le1nux Jan 13, 2025
b956f90
chore: Merge branch 'main' into tokenization_at_scale
le1nux Jan 13, 2025
1a515ba
refactor: improved code quality in strategies.py
le1nux Jan 13, 2025
716a98c
refactor: integrated changes from PR #283 into tokenization strategy
le1nux Jan 13, 2025
3100919
chore: Merge branch 'tokenization_at_scale' of github.com:Modalities/…
le1nux Jan 13, 2025
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -12,7 +12,7 @@ repos:
rev: 23.9.1
hooks:
- id: black
language_version: python3.10
language_version: python3.11
stages: [pre-commit]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.278
4 changes: 2 additions & 2 deletions config_files/training/config_example_coca.yaml
@@ -27,7 +27,7 @@ settings:
sequence_length: 256
training_target:
num_target_tokens:
component_key: number_conversion
component_key: training_number_conversion
variant_key: num_tokens_from_num_steps
config:
num_steps: ${settings.training_target.num_target_steps}
@@ -36,7 +36,7 @@ settings:
sequence_length: ${settings.step_profile.sequence_length}
gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps}
num_target_steps: # for the batch progress subscriber
component_key: number_conversion
component_key: training_number_conversion
variant_key: num_steps_from_num_samples
config:
num_ranks: ${settings.cuda_env.world_size}
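The training_number_conversion components referenced in this config convert between optimizer steps, samples, and token counts when the config is resolved. As a rough illustration of what the num_tokens_from_num_steps variant presumably computes (the full field set is not visible in this hunk, so num_ranks and local_micro_batch_size are assumed here; this is a sketch, not the library's actual implementation):

def num_tokens_from_num_steps(
    num_steps: int,
    num_ranks: int,
    local_micro_batch_size: int,
    gradient_accumulation_steps: int,
    sequence_length: int,
) -> int:
    # Each optimizer step processes one local micro-batch per rank for every
    # gradient-accumulation step; every sample contributes sequence_length tokens.
    return num_steps * num_ranks * local_micro_batch_size * gradient_accumulation_steps * sequence_length

# e.g. 1000 steps on 8 ranks, micro-batch 4, 2 accumulation steps, 256-token sequences
print(num_tokens_from_num_steps(1000, 8, 4, 2, 256))  # 16384000 tokens
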
4 changes: 2 additions & 2 deletions config_files/training/config_lorem_ipsum.yaml
@@ -27,7 +27,7 @@ settings:
sequence_length: 256
training_target:
num_target_tokens:
component_key: number_conversion
component_key: training_number_conversion
variant_key: num_tokens_from_packed_mem_map_dataset_continuous
config:
dataset_path: ${settings.paths.train_dataset_path}
@@ -36,7 +36,7 @@ settings:
local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size}
gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps}
num_target_steps: # for the batch progress subscriber
component_key: number_conversion
component_key: training_number_conversion
variant_key: num_steps_from_num_tokens
config:
num_ranks: ${settings.cuda_env.world_size}
30 changes: 25 additions & 5 deletions src/modalities/__main__.py
@@ -13,7 +13,9 @@

from modalities.api import (
convert_pytorch_to_hf_checkpoint,
create_raw_data_index,
create_global_index,
create_local_index,
create_shuffled_global_index,
generate_text,
merge_packed_data_files,
pack_encoded_data,
@@ -34,6 +36,7 @@
from modalities.running_env.cuda_env import CudaEnv
from modalities.trainer import Trainer
from modalities.util import get_total_number_of_trainable_parameters, print_rank_0
from modalities.utils.logging import get_logger


@click.group()
@@ -123,15 +126,15 @@ def data():
pass


@data.command(name="create_raw_index")
@data.command(name="create_local_index")
@click.argument("src_path", type=Path)
@click.option(
"--index_path",
type=Path,
default=None,
help="output path for index. will use parent directory of src_path if none.",
)
def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path):
def CMD_entry_point_data_create_local_index(src_path: Path, index_path: Path):
"""Utility CMD for indexing the content of a large jsonl-file.
Background is the ability to further process the respective file without loading it,
while splitting its content line-based. This step is necessary in advance of further processing like tokenization.
@@ -144,11 +147,27 @@ def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path):
Raises:
ValueError: If the index file already exists.
"""
create_raw_data_index(src_path=src_path, index_path=index_path)
create_local_index(src_path=src_path, index_path=index_path)


@data.command(name="create_global_index")
@click.option("--file_list_path", type=Path, required=True)
@click.option("--root_index_path", type=Path, required=True)
@click.option("--global_index_root_path", type=Path, required=True)
def CMD_entry_point_create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path):
create_global_index(
file_list_path=file_list_path, root_index_path=root_index_path, global_index_root_path=global_index_root_path
)


@data.command(name="create_shuffled_global_index")
@click.option("--global_index_file_path", type=Path, required=True)
def CMD_entry_point_create_shuffled_global_index(global_index_file_path: Path):
create_shuffled_global_index(global_index_file_path=global_index_file_path)


@data.command(name="pack_encoded_data")
@click.argument("config_path", type=FilePath)
@click.option("--config_path", type=FilePath, required=True)
def CMD_entry_point_pack_encoded_data(config_path: FilePath):
"""Utility to encode an indexed, large jsonl-file.
(see also `create_index` for more information)
@@ -158,6 +177,7 @@ def CMD_entry_point_pack_encoded_data(config_path: FilePath):
Args:
config_path (FilePath): Path to the config file describing the tokenization setup.
"""
get_logger().info(f"Loading config from {config_path}.")
config_dict = load_app_config_dict(config_path)

pack_encoded_data(config_dict=config_dict)
203 changes: 173 additions & 30 deletions src/modalities/api.py
@@ -1,23 +1,42 @@
#!/usr/bin/env python

import multiprocessing as mp
import os
from enum import Enum
from pathlib import Path

from pydantic import FilePath

import modalities.dataloader.preprocessing.indexation.global_indexation as global_indexation
import modalities.inference.inference as inference
from modalities.checkpointing.checkpoint_conversion import CheckpointConversion
from modalities.config.component_factory import ComponentFactory
from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel
from modalities.dataloader.create_index import IndexGenerator
from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data
from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
from modalities.config.instantiation_models import TokenizationInstantiationModel
from modalities.dataloader.preprocessing.indexation.local_indexation import IndexGenerator
from modalities.dataloader.preprocessing.queued_processing.process_controller import PipelineStep, ProcessController
from modalities.dataloader.preprocessing.queued_processing.processors import Processor
from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import (
EmbeddedStreamData,
join_embedded_stream_data,
)
from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader
from modalities.dataloader.preprocessing.tokenization.strategies import ProcessingStrategyFactory, WorkerTypes
from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapter
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from modalities.utils.env_variables import temporary_env_vars_decorator
from modalities.utils.logging import get_logger


def create_raw_data_index(src_path: Path, index_path: Path):
class FileExistencePolicy(Enum):
SKIP = "skip"
ERROR = "error"
OVERRIDE = "override"


def create_local_index(
src_path: Path, index_path: Path, file_existence_policy: FileExistencePolicy = FileExistencePolicy.ERROR
):
"""Creates the index file for the content of a large jsonl-file. The index file
contains the byte-offsets and lengths of each line in the jsonl-file.
Background is the ability to further process the respective file without loading it,
@@ -31,17 +50,40 @@ def create_raw_data_index(src_path: Path, index_path: Path):
Raises:
ValueError: If the index file already exists.
"""
index_path = LargeFileLinesReader.default_index_path(src_path, index_path)
os.makedirs(index_path.parent, exist_ok=True)
index_path = LocalLargeFileLinesReader.default_index_path(src_path, index_path)
if index_path.exists():
raise ValueError("index already exists. delete it or specify different output folder.")
if file_existence_policy == FileExistencePolicy.SKIP:
get_logger(name="main").warning(f"Index already exists at {str(index_path)}. Skipping index creation.")
return
elif file_existence_policy == FileExistencePolicy.OVERRIDE:
get_logger(name="main").warning(f"Index already exists at {str(index_path)}. Overriding it.")
os.remove(index_path)
elif file_existence_policy == FileExistencePolicy.ERROR:
raise ValueError("index already exists. delete it or specify different output folder.")
else:
raise ValueError(f"Unknown file existence policy: {file_existence_policy}")

get_logger(name="main").info(
f"Reading raw data from {str(src_path)} and" f" writing index to {str(index_path)} ..."
)
os.makedirs(index_path.parent, exist_ok=True)

print(f"reading raw data from {src_path}")
print(f"writing index to {index_path}")
generator = IndexGenerator(src_path)
generator.create_index(index_path)


def create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path) -> Path:
global_index_file_path = global_indexation.create_global_index(
file_list_path, root_index_path, global_index_root_path
)
return global_index_file_path


def create_shuffled_global_index(global_index_file_path: Path) -> Path:
global_shuffled_index_file_path = global_indexation.create_shuffled_global_index(global_index_file_path)
return global_shuffled_index_file_path
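
Taken together, these functions suggest a three-stage indexing workflow: build a byte-offset index per raw jsonl file, aggregate those per-file indexes into one global index, and finally shuffle that global index. A minimal sketch of driving the API programmatically, with made-up paths (only the function names, signatures and the FileExistencePolicy enum come from this diff; what exactly file_list_path and root_index_path must contain is not shown here):

from pathlib import Path

from modalities.api import (
    FileExistencePolicy,
    create_global_index,
    create_local_index,
    create_shuffled_global_index,
)

# 1) Index every raw jsonl file locally (hypothetical data layout).
for src in Path("data/raw").glob("*.jsonl"):
    create_local_index(
        src_path=src,
        index_path=None,  # None lets default_index_path derive the location from src_path
        file_existence_policy=FileExistencePolicy.SKIP,  # keep an existing index instead of failing
    )

# 2) Aggregate the per-file indexes into a single global index ...
global_index_path = create_global_index(
    file_list_path=Path("data/file_list.txt"),
    root_index_path=Path("data/raw"),
    global_index_root_path=Path("data/global"),
)

# 3) ... and write a shuffled variant of it for training.
shuffled_index_path = create_shuffled_global_index(global_index_file_path=global_index_path)
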


def generate_text(config_file_path: FilePath):
"""Inference function to generate text with a given model.

@@ -70,6 +112,9 @@ def convert_pytorch_to_hf_checkpoint(
return hf_model


# not setting this can cause deadlocks when using hf's "FastTokenizers". See also:
# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/67254879#67254879
@temporary_env_vars_decorator({"TOKENIZERS_PARALLELISM": "false"})
def pack_encoded_data(config_dict: dict):
"""Packs and encodes an indexed, large jsonl-file.
(see also `create_index` for more information)
@@ -79,31 +124,129 @@ def pack_encoded_data(config_dict: dict):
Args:
config_dict (dict): Dictionary containing the configuration for the packed data generation.
"""

# TODO: if we want to use alternative entrypoints together with the ResolverRegistry,
# we can currently not rely on the existing class resolver.
# This is based on its connection to the overall `AppConfig`.
# One would requires an object of it to instantiate the ResolverRegistry.
# This could get resolved by implementing on own ResolverRegistry for each entrypoint or adapting the existing
# ResolverRegistry to work dynamically with any type-hinted config object from config.py.
registry = Registry(COMPONENTS)
component_factory = ComponentFactory(registry=registry)
components: PackedDatasetComponentsInstantiationModel = component_factory.build_components(
config_dict=config_dict, components_model_type=PackedDatasetComponentsInstantiationModel
instantion_model: TokenizationInstantiationModel = component_factory.build_components(
config_dict=config_dict, components_model_type=TokenizationInstantiationModel
)

# build the queues
reader_q, tokenizer_q, writer_q, logging_message_q = ProcessingStrategyFactory.get_process_queues(
reader_q_maxsize=instantion_model.reader_q_maxsize,
writer_q_maxsize=instantion_model.writer_q_maxsize,
tokenizer_q_maxsize=instantion_model.tokenizer_q_maxsize,
)

# build the workers
stop_event = mp.Event()
reader_q_key = "reader_q"
tokenizer_q_key = "tokenizer_q"
writer_q_key = "writer_q"
logging_message_q_key = "logging_message_q"

populating_worker = Processor(
out_qs={reader_q_key: reader_q, logging_message_q_key: logging_message_q},
in_q_timeout=instantion_model.in_q_timeout,
out_q_timeout=instantion_model.out_q_timeout,
strategy=ProcessingStrategyFactory.get_populating_strategy(
reader_q_key=reader_q_key,
logging_message_q_key=logging_message_q_key,
index_start=instantion_model.populate_worker_settings.index_start,
num_samples=instantion_model.populate_worker_settings.num_samples,
batch_size=instantion_model.populate_worker_settings.batch_size,
),
process_type=WorkerTypes.POPULATOR,
process_id=0,
logging_message_q_key=logging_message_q_key,
set_stop_event_on_processing_error=True,
stop_event=stop_event,
)

generator = PackedDataGenerator(
components.settings.src_path,
index_path=components.settings.index_path,
tokenizer=components.tokenizer,
eod_token=components.settings.eod_token,
jq_pattern=components.settings.jq_pattern,
number_of_processes=components.settings.num_cpus,
processing_batch_size=components.settings.processing_batch_size,
raw_samples_queue_size=components.settings.raw_samples_queue_size,
processed_samples_queue_size=components.settings.processed_samples_queue_size,
reader_settings = instantion_model.reader_worker_settings.reader_settings
reader_workers = [
Processor(
in_q=reader_q,
out_qs={tokenizer_q_key: tokenizer_q, logging_message_q_key: logging_message_q},
in_q_timeout=instantion_model.in_q_timeout,
out_q_timeout=instantion_model.out_q_timeout,
strategy=ProcessingStrategyFactory.get_reader_strategy(
reader_settings, tokenizer_q_key=tokenizer_q_key, logging_message_q_key=logging_message_q_key
),
process_type=WorkerTypes.READER,
process_id=i,
logging_message_q_key=logging_message_q_key,
set_stop_event_on_processing_error=False,
stop_event=stop_event,
)
for i in range(instantion_model.reader_worker_settings.num_workers)
]

tokenizer_workers = [
Processor(
in_q=tokenizer_q,
out_qs={writer_q_key: writer_q, logging_message_q_key: logging_message_q},
in_q_timeout=instantion_model.in_q_timeout,
out_q_timeout=instantion_model.out_q_timeout,
strategy=ProcessingStrategyFactory.get_tokenizer_strategy(
tokenizer_settings=instantion_model.tokenizer_worker_settings.tokenizer_settings,
writer_q_key=writer_q_key,
logging_message_q_key=logging_message_q_key,
),
process_type=WorkerTypes.TOKENIZER,
process_id=i,
logging_message_q_key=logging_message_q_key,
set_stop_event_on_processing_error=False,
stop_event=stop_event,
)
for i in range(instantion_model.tokenizer_worker_settings.num_workers)
]

writer_worker = Processor(
in_q=writer_q,
out_qs={logging_message_q_key: logging_message_q},
in_q_timeout=instantion_model.in_q_timeout,
out_q_timeout=instantion_model.out_q_timeout,
strategy=ProcessingStrategyFactory.get_writing_strategy(
ww_settings=instantion_model.writer_worker_settings, logging_message_q_key=logging_message_q_key
),
process_type=WorkerTypes.WRITER,
process_id=0,
logging_message_q_key=logging_message_q_key,
set_stop_event_on_processing_error=True,
stop_event=stop_event,
)

logging_worker = Processor(
in_q=logging_message_q,
out_qs={},
in_q_timeout=instantion_model.in_q_timeout,
out_q_timeout=instantion_model.out_q_timeout,
strategy=ProcessingStrategyFactory.get_progress_logging_strategy(
logging_interval=instantion_model.logging_worker_settings.logging_interval,
total_num_samples=instantion_model.logging_worker_settings.num_samples,
q_dict={
reader_q_key: reader_q,
tokenizer_q_key: tokenizer_q,
writer_q_key: writer_q,
logging_message_q_key: logging_message_q,
},
),
process_type=WorkerTypes.LOGGING,
process_id=0,
set_stop_event_on_processing_error=False,
stop_event=stop_event,
)
generator.run(components.settings.dst_path)

pipeline_steps = [
PipelineStep(name="populating", input_queue=None, processors=[populating_worker], poisonable=False),
PipelineStep(name="reading", input_queue=reader_q, processors=reader_workers, poisonable=True),
PipelineStep(name="tokenizing", input_queue=tokenizer_q, processors=tokenizer_workers, poisonable=True),
PipelineStep(name="writing", input_queue=writer_q, processors=[writer_worker], poisonable=True),
PipelineStep(name="logging", input_queue=logging_message_q, processors=[logging_worker], poisonable=True),
]

process_controller = ProcessController(pipeline_steps=pipeline_steps, stop_event=stop_event)
process_controller.run()
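
The worker and queue wiring above is a classic multiprocessing producer-consumer pipeline: a populator fills the reader queue, readers feed tokenizers, tokenizers feed a single writer, and a logging process observes all queues. A stripped-down, self-contained sketch of that pattern with sentinel-based shutdown (generic Python, not the actual Processor/PipelineStep/ProcessController classes from this PR):

import multiprocessing as mp

SENTINEL = None  # poison pill signalling "no more work"

def reader(in_q: mp.Queue, out_q: mp.Queue) -> None:
    while (item := in_q.get()) is not SENTINEL:
        out_q.put(item.upper())  # stand-in for "read raw sample"
    out_q.put(SENTINEL)          # forward the poison pill downstream

def tokenizer(in_q: mp.Queue, out_q: mp.Queue) -> None:
    while (item := in_q.get()) is not SENTINEL:
        out_q.put(item.split())  # stand-in for "tokenize"
    out_q.put(SENTINEL)

def writer(in_q: mp.Queue) -> None:
    while (item := in_q.get()) is not SENTINEL:
        print("write:", item)    # stand-in for "append to packed file"

if __name__ == "__main__":
    reader_q, tokenizer_q, writer_q = mp.Queue(), mp.Queue(), mp.Queue()
    steps = [
        mp.Process(target=reader, args=(reader_q, tokenizer_q)),
        mp.Process(target=tokenizer, args=(tokenizer_q, writer_q)),
        mp.Process(target=writer, args=(writer_q,)),
    ]
    for p in steps:
        p.start()
    for line in ["hello world", "lorem ipsum"]:  # the "populator"
        reader_q.put(line)
    reader_q.put(SENTINEL)
    for p in steps:
        p.join()

With several workers per stage, one sentinel per downstream worker has to be forwarded, which is presumably what the poisonable flag on PipelineStep abstracts away.
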


def merge_packed_data_files(src_paths: list[Path], target_path: Path):
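
The temporary_env_vars_decorator applied to pack_encoded_data further up sets TOKENIZERS_PARALLELISM=false only for the duration of the call, to avoid the fast-tokenizer fork deadlock referenced in the comment there. Its implementation is not part of this diff; a minimal sketch of how such a decorator can be written (an assumption, not the actual modalities.utils.env_variables code):

import functools
import os

def temporary_env_vars_decorator(env_vars: dict[str, str]):
    """Set the given environment variables around the wrapped call, then restore the old values."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            previous = {key: os.environ.get(key) for key in env_vars}
            os.environ.update(env_vars)
            try:
                return func(*args, **kwargs)
            finally:
                # restore or remove each variable, depending on whether it existed before
                for key, old_value in previous.items():
                    if old_value is None:
                        os.environ.pop(key, None)
                    else:
                        os.environ[key] = old_value
        return wrapper
    return decorator

@temporary_env_vars_decorator({"TOKENIZERS_PARALLELISM": "false"})
def do_tokenization() -> None:
    print(os.environ["TOKENIZERS_PARALLELISM"])  # "false" only inside the call
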