From bbee65cc1f709c5c068cd24040a5175e6a8cd9b1 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 2 Jan 2025 20:47:26 +0100 Subject: [PATCH 01/25] feat: added enhanced logging --- src/modalities/__main__.py | 4 +++- src/modalities/exceptions.py | 6 ++++++ src/modalities/utils/logging.py | 10 ++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 src/modalities/utils/logging.py diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 34acf6263..164209dcc 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -9,6 +9,7 @@ import click import click_pathlib +from modalities.utils.logging import get_logger from pydantic import BaseModel, FilePath from modalities.api import ( @@ -148,7 +149,7 @@ def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path): @data.command(name="pack_encoded_data") -@click.argument("config_path", type=FilePath) +@click.option("--config_path", type=FilePath, required=True) def CMD_entry_point_pack_encoded_data(config_path: FilePath): """Utility to encode an indexed, large jsonl-file. (see also `create_index` for more information) @@ -158,6 +159,7 @@ def CMD_entry_point_pack_encoded_data(config_path: FilePath): Args: config_path (FilePath): Path to the config file describing the tokenization setup. """ + get_logger().info(f"Loading config from {config_path}.") config_dict = load_app_config_dict(config_path) pack_encoded_data(config_dict=config_dict) diff --git a/src/modalities/exceptions.py b/src/modalities/exceptions.py index 0ac49dcc1..1bc34a255 100644 --- a/src/modalities/exceptions.py +++ b/src/modalities/exceptions.py @@ -24,3 +24,9 @@ class OptimizerError(Exception): class ConfigError(Exception): pass + +class EmptySampleError(RuntimeError): + pass + +class ReaderIndexationError(Exception): + pass \ No newline at end of file diff --git a/src/modalities/utils/logging.py b/src/modalities/utils/logging.py new file mode 100644 index 000000000..21eda110b --- /dev/null +++ b/src/modalities/utils/logging.py @@ -0,0 +1,10 @@ +import logging + +def get_logger(name: str = "main") -> logging.Logger: + logger = logging.getLogger(name) + if not logger.handlers: + logger.setLevel(logging.DEBUG) + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter('%(name)s - %(levelname)s - %(message)s')) + logger.addHandler(handler) + return logger \ No newline at end of file From 90f5d5a4361362cfba01d4c2f8f8f3518a3407e1 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 2 Jan 2025 20:49:45 +0100 Subject: [PATCH 02/25] refactor: scaled the index creation and tokenization by supporting multprocessing --- src/modalities/api.py | 124 ++++- src/modalities/config/instantiation_models.py | 72 ++- .../dataloader/create_packed_data.py | 441 ----------------- src/modalities/dataloader/dataset.py | 6 +- .../dataloader/large_file_lines_reader.py | 130 ----- .../dataloader/preprocessing/__init__.py | 0 .../preprocessing/indexation/__init__.py | 0 .../indexation}/create_index.py | 58 ++- .../preprocessing/tokenization/__init__.py | 0 .../tokenization/create_packed_data.py | 113 +++++ .../tokenization/embedded_stream_data.py | 123 +++++ .../tokenization/large_file_lines_reader.py | 291 +++++++++++ .../tokenization/tokenization_processes.py | 456 ++++++++++++++++++ src/modalities/utils/env_variables.py | 22 + tests/conftest.py | 8 +- .../test_large_file_lines_reader.py | 10 +- 16 files changed, 1220 insertions(+), 634 deletions(-) delete mode 100644 
src/modalities/dataloader/create_packed_data.py delete mode 100644 src/modalities/dataloader/large_file_lines_reader.py create mode 100644 src/modalities/dataloader/preprocessing/__init__.py create mode 100644 src/modalities/dataloader/preprocessing/indexation/__init__.py rename src/modalities/dataloader/{ => preprocessing/indexation}/create_index.py (62%) create mode 100644 src/modalities/dataloader/preprocessing/tokenization/__init__.py create mode 100644 src/modalities/dataloader/preprocessing/tokenization/create_packed_data.py create mode 100644 src/modalities/dataloader/preprocessing/tokenization/embedded_stream_data.py create mode 100644 src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py create mode 100644 src/modalities/dataloader/preprocessing/tokenization/tokenization_processes.py create mode 100644 src/modalities/utils/env_variables.py diff --git a/src/modalities/api.py b/src/modalities/api.py index 05f8ef2c2..867da0337 100644 --- a/src/modalities/api.py +++ b/src/modalities/api.py @@ -2,22 +2,45 @@ import os from pathlib import Path - +from typing import Optional + +from modalities.dataloader.preprocessing.tokenization.create_packed_data import PackedDataGenerator +from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import ( + EmbeddedStreamData, + join_embedded_stream_data, +) +from modalities.dataloader.preprocessing.tokenization.tokenization_processes import ( + ProcessFactory, + ProgressLoggingWorker, + get_required_num_of_bytes_to_repr, +) +from modalities.utils.logging import get_logger from pydantic import FilePath import modalities.inference.inference as inference from modalities.checkpointing.checkpoint_conversion import CheckpointConversion from modalities.config.component_factory import ComponentFactory from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel -from modalities.dataloader.create_index import IndexGenerator -from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data -from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader +from modalities.dataloader.preprocessing.indexation.create_index import IndexGenerator +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapter from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry +import multiprocessing as mp +import shutil + +from enum import Enum + +class FileExistencePolicy(Enum): + SKIP = "skip" + ERROR = "error" + OVERRIDE = "override" -def create_raw_data_index(src_path: Path, index_path: Path): + +def create_raw_data_index( + src_path: Path, index_path: Path, file_existence_policy: FileExistencePolicy = FileExistencePolicy.ERROR +): """Creates the index file for the content of a large jsonl-file. The index file contains the byte-offsets and lengths of each line in the jsonl-file. Background is the ability to further process the respective file without loading it, @@ -31,13 +54,24 @@ def create_raw_data_index(src_path: Path, index_path: Path): Raises: ValueError: If the index file already exists. 
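As a usage sketch for the new file_existence_policy parameter introduced above (the paths below are hypothetical), an existing index can now be skipped or overridden instead of always raising:

from pathlib import Path

from modalities.api import FileExistencePolicy, create_raw_data_index

# Hypothetical paths; SKIP leaves an existing index untouched,
# OVERRIDE deletes and rebuilds it, ERROR keeps the previous behavior.
create_raw_data_index(
    src_path=Path("data/train.jsonl"),
    index_path=Path("data/train.idx"),
    file_existence_policy=FileExistencePolicy.SKIP,
)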
""" - index_path = LargeFileLinesReader.default_index_path(src_path, index_path) - os.makedirs(index_path.parent, exist_ok=True) + index_path = LocalLargeFileLinesReader.default_index_path(src_path, index_path) if index_path.exists(): - raise ValueError("index already exists. delete it or specify different output folder.") + if file_existence_policy == FileExistencePolicy.SKIP: + get_logger(name="main").warning(f"Index already exists at {str(index_path)}. Skipping index creation.") + return + elif file_existence_policy == FileExistencePolicy.OVERRIDE: + get_logger(name="main").warning(f"Index already exists at {str(index_path)}. Overriding it.") + os.remove(index_path) + elif file_existence_policy == FileExistencePolicy.ERROR: + raise ValueError("index already exists. delete it or specify different output folder.") + else: + raise ValueError(f"Unknown file existence policy: {file_existence_policy}") + + get_logger(name="main").info( + f"Reading raw data from {str(src_path)} and" f" writing index to {str(index_path)} ..." + ) + os.makedirs(index_path.parent, exist_ok=True) - print(f"reading raw data from {src_path}") - print(f"writing index to {index_path}") generator = IndexGenerator(src_path) generator.create_index(index_path) @@ -88,22 +122,70 @@ def pack_encoded_data(config_dict: dict): # ResolverRegistry to work dynamically with any type-hinted config object from config.py. registry = Registry(COMPONENTS) component_factory = ComponentFactory(registry=registry) - components: PackedDatasetComponentsInstantiationModel = component_factory.build_components( + instantion_model: PackedDatasetComponentsInstantiationModel = component_factory.build_components( config_dict=config_dict, components_model_type=PackedDatasetComponentsInstantiationModel ) + # build the queues + reader_q, tokenizer_q, writer_q, logging_message_q = ProcessFactory.get_process_queues( + writer_q_maxsize=instantion_model.writer_q_maxsize, tokenizer_q_maxsize=instantion_model.tokenizer_q_maxsize + ) + + # build the workers + stop_event = mp.Event() + token_size_in_bytes = get_required_num_of_bytes_to_repr( + instantion_model.tokenizer_worker_settings.tokenizer_settings.tokenizer.vocab_size + ) + + reader_workers = ProcessFactory.get_reader_workers( + rw_settings=instantion_model.reader_worker_settings, + reader_q=reader_q, + tokenizer_q=tokenizer_q, + logging_message_q=logging_message_q, + stop_event=stop_event, + ) + + tokenizer_workers = ProcessFactory.get_tokenizer_workers( + tw_settings=instantion_model.tokenizer_worker_settings, + tokenizer_q=tokenizer_q, + writer_q=writer_q, + logging_message_q=logging_message_q, + token_size_in_bytes=token_size_in_bytes, + stop_event=stop_event, + ) + + writer_worker = ProcessFactory.get_writer_worker( + writer_q=writer_q, + logging_message_q=logging_message_q, + token_size_in_bytes=token_size_in_bytes, + ww_settings=instantion_model.writer_worker_settings, + stop_event=stop_event, + ) + + progress_logging_worker = ProgressLoggingWorker( + logging_message_q=logging_message_q, + reader_q=reader_q, + tokenizer_q=tokenizer_q, + writer_q=writer_q, + total_num_samples=instantion_model.num_samples, + stop_event=stop_event, + logging_interval=instantion_model.logging_interval, + ) + generator = PackedDataGenerator( - components.settings.src_path, - index_path=components.settings.index_path, - tokenizer=components.tokenizer, - eod_token=components.settings.eod_token, - jq_pattern=components.settings.jq_pattern, - number_of_processes=components.settings.num_cpus, - 
processing_batch_size=components.settings.processing_batch_size, - raw_samples_queue_size=components.settings.raw_samples_queue_size, - processed_samples_queue_size=components.settings.processed_samples_queue_size, + reader_workers=reader_workers, + tokenizer_workers=tokenizer_workers, + writer_worker=writer_worker, + progress_logging_worker=progress_logging_worker, + reader_q=reader_q, + tokenizer_q=tokenizer_q, + writer_q=writer_q, + logging_message_q=logging_message_q, + index_start=instantion_model.index_start, + num_samples=instantion_model.num_samples, + batch_size=instantion_model.batch_size, ) - generator.run(components.settings.dst_path) + generator.run() def merge_packed_data_files(src_paths: list[Path], target_path: Path): diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index edafccae5..d7df1ba71 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -2,7 +2,8 @@ from pathlib import Path from typing import Annotated, Any, Optional -from pydantic import BaseModel, ConfigDict, Field, FilePath, field_validator, model_validator, root_validator +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LargeFileLinesReaderTypes +from pydantic import BaseModel, ConfigDict, Field, FilePath, field_validator, model_validator, root_validator, validator from modalities.config.pydanctic_if_types import ( PydanticCheckpointSavingIFType, @@ -192,19 +193,62 @@ def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationMode class PackedDatasetComponentsInstantiationModel(BaseModel): - class PackedDatasetSettings(BaseModel): - src_path: FilePath - dst_path: Optional[Path] = None - index_path: Optional[FilePath] = None - jq_pattern: str - num_cpus: Annotated[int, Field(strict=True, ge=1)] = os.cpu_count() - eod_token: str - processing_batch_size: Annotated[int, Field(strict=True, ge=1)] - raw_samples_queue_size: Annotated[int, Field(strict=True, ge=1)] - processed_samples_queue_size: Annotated[int, Field(strict=True, ge=1)] - - tokenizer: PydanticTokenizerIFType - settings: PackedDatasetSettings + + class ReaderWorkerSettings(BaseModel): + class ReaderSettings(BaseModel): + class LocalReaderArgs(BaseModel): + raw_data_path: Path + index_path: Optional[Path] = None + encoding: Optional[str] = "utf-8" + + class GlobalReaderArgs(BaseModel): + global_inorder_index_path: Path + raw_data_file_list_path: Path + raw_data_root_path: Path + global_shuffle_index_path: Optional[Path] = None + encoding: Optional[str] = "utf-8" + + reader_type: LargeFileLinesReaderTypes + reader_args: LocalReaderArgs | GlobalReaderArgs + + num_reader_processes: Annotated[int, Field(strict=True, ge=1)] + reader_settings: ReaderSettings + + class TokenizerWorkerSettings(BaseModel): + class TokenizerSettings(BaseModel): + tokenizer: PydanticTokenizerIFType + eod_token: str + jq_pattern: str + + num_tokenizer_processes: Annotated[int, Field(strict=True, ge=1)] + tokenizer_settings: TokenizerSettings + + + class WriterWorkerSettings(BaseModel): + dst_path: Path + index_start: Annotated[int, Field(strict=True, ge=0)] + + + @field_validator("dst_path") + def ensure_path_does_not_exist(cls, value): + path = Path(value) # Convert to Path object if it's a string + if path.exists(): + raise ValueError(f"The filepath '{path}' already exists.") + return path + + paths: dict[str, Path] + reader_worker_settings: ReaderWorkerSettings + tokenizer_worker_settings: 
TokenizerWorkerSettings + writer_worker_settings: WriterWorkerSettings + tokenizer_q_maxsize: Annotated[int, Field(strict=True, ge=1)] + writer_q_maxsize: Annotated[int, Field(strict=True, ge=1)] + index_start: Annotated[int, Field(strict=True, ge=0)] + num_samples: Annotated[int, Field(strict=True, ge=1)] + batch_size: Annotated[int, Field(strict=True, ge=1)] + logging_interval: Annotated[int, Field(strict=True, ge=1)] + + + class TextGenerationInstantiationModel(BaseModel): diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py deleted file mode 100644 index 7dc5fce49..000000000 --- a/src/modalities/dataloader/create_packed_data.py +++ /dev/null @@ -1,441 +0,0 @@ -import logging -import math -import multiprocessing -import os -import pickle -import traceback -import warnings -from io import BufferedWriter -from pathlib import Path -from typing import Callable, Iterator, Optional - -import jq -import numpy as np -from pydantic import FilePath -from tqdm import tqdm - -from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader -from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper - -logger = logging.getLogger(__name__) - - -class EmptySampleError(RuntimeError): - pass - - -class PackedDataGenerator: - """Reads in a JSONL file and the corresponding index file and packs the dataset for LLM training.""" - - def __init__( - self, - src_path: FilePath, - tokenizer: TokenizerWrapper, - eod_token: str, - number_of_processes: int, - jq_pattern: str, - processing_batch_size: int, - raw_samples_queue_size: int, - processed_samples_queue_size: int, - index_path: Optional[FilePath] = None, - ): - """ - Initializes a PackedDataGenerator object. - - Args: - src_path (FilePath): Path to a JSONL file, which holds text data. - tokenizer (TokenizerWrapper): PretrainedTokenizer object used to tokenize the provided data in `src_path`. - eod_token (str): End-of-document token. - number_of_processes (int): Number of processes used for parallel processing. - jq_pattern (str): jq-pattern applied on every jsonl-entry. Results are afterwards tokenized and packed. - processing_batch_size (int): Size of the batches that the workers process. - raw_samples_queue_size (int): Maximum size of the raw samples queue. - processed_samples_queue_size (int): Maximum size of the processed samples queue. - index_path (Optional[FilePath], optional): Path to an index file, - which indicates the start character position - and length of samples given in `src_path`. If not defined, an index file next to `src_path` is picked, - by replacing its suffix with ".idx". Defaults to None. 
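The nested worker settings added to instantiation_models.py above are normally populated from the config dict resolved by ComponentFactory.build_components; as a minimal sketch (values are hypothetical, and only the reader-related part is shown), they can also be constructed directly:

from pathlib import Path

from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel
from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LargeFileLinesReaderTypes

ReaderWorkerSettings = PackedDatasetComponentsInstantiationModel.ReaderWorkerSettings
reader_worker_settings = ReaderWorkerSettings(
    num_reader_processes=2,  # hypothetical value
    reader_settings=ReaderWorkerSettings.ReaderSettings(
        reader_type=LargeFileLinesReaderTypes.LOCAL,
        reader_args=ReaderWorkerSettings.ReaderSettings.LocalReaderArgs(
            raw_data_path=Path("data/train.jsonl"),
            index_path=Path("data/train.idx"),
        ),
    ),
)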
- - Returns: - None - """ - self.src_path = src_path - self.tokenizer = tokenizer - self.eod_token = eod_token - self._token_size_in_bytes = self._get_required_num_of_bytes_to_repr(self.tokenizer.vocab_size) - encoded_eod_token = self.tokenizer.get_token_id(self.eod_token) - self._encoded_eos_token_as_bytes = self._encoded_token_to_bytes(encoded_eod_token) - self.jq_filter = jq.compile(jq_pattern) - self._number_of_processes = number_of_processes - self._reader = LargeFileLinesReader(src_path, index_path=index_path) # reads string with utf-8 encoding - self._total_num_of_tokens = 0 - self._raw_samples_queue = multiprocessing.Queue(maxsize=raw_samples_queue_size) - self.processed_samples_queue = multiprocessing.Queue(maxsize=processed_samples_queue_size) - self._exception_buffer = [] - self.processing_batch_size = processing_batch_size - - @staticmethod - def _get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: - """ - Calculates the required number of bytes to represent an integer. - - Args: - int_to_get_repr (int): The integer to get the representation for. - - Returns: - int: The number of bytes required to represent the integer. - """ - # we currently only support token sizes of 1, 2 and 4 bytes, as implemented here: - # https://github.com/Modalities/modalities/blob/fix_char_bytes_indexation_mismatch/src/modalities/dataloader/dataset.py#L202 - num_bytes = math.ceil(math.log2(int_to_get_repr) / 8) - if num_bytes == 1: - return 1 - elif num_bytes == 2: - return 2 - elif num_bytes <= 4: - return 4 - else: - raise ValueError("Currently only support token byte sizes of 1, 2, and 4.") - - def _encoded_token_to_bytes(self, encoded_token: int) -> bytes: - """ - Converts an encoded token to its byte representaion. - - Args: - encoded_token (int): The encoded token to be converted. - - Returns: - bytes: The byte representation of the token. - - """ - return encoded_token.to_bytes(self._token_size_in_bytes, byteorder="little", signed=False) - - def _default_destination_path(self, destination_path: Optional[Path] = None) -> Path: - """ - Returns the default destination path for the packed data. - - Args: - destination_path (Path, optional): The specific destination path. Defaults to None. - - Returns: - Path: The default destination path for the packed data. - """ - if destination_path is None: - default_destination_path = Path(self.src_path.parent, f"{self.src_path.stem}.pbin") - print( - f"No specific Destination Path provided. " - f"Pointing to destination next to input data at: {default_destination_path}" - ) - return default_destination_path - return Path(destination_path) - - def run(self, dst_path: Optional[Path] = None): - """ - Packs data and saves it to (default) dst_path. - - Args: - dst_path (Optional[Path]): The destination path to save the packed data. - If not provided, a default destination path will be used. - - Raises: - ValueError: If the file already exists at the destination path. - Exception: If an exception occurs during the data packing process. - - Returns: - None - """ - assert self._total_num_of_tokens == 0, f"This {self.__name__} was already used and is exhausted. Use another!" - dst_path = self._default_destination_path(destination_path=dst_path) - - dst_path.parent.mkdir(parents=True, exist_ok=True) - if dst_path.exists(): - raise ValueError(f"file already exists at destination path '{dst_path}'.") - - self._exception_buffer = [] - try: - # not setting this can cause deadlocks when using hf's "FastTokenizers". 
See also: - # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/67254879#67254879 - os.environ["TOKENIZERS_PARALLELISM"] = "false" - self._launch_parallelized_workers(dst_path) - finally: - os.unsetenv("TOKENIZERS_PARALLELISM") - - if self._exception_buffer: - raise self._exception_buffer[0] - - def _launch_parallelized_workers(self, dst_path: Path): - # Launches workers in parallel for reading, writing, and processing data. - # The data is stored in the provided destination path. - - reader = multiprocessing.Process(target=self._reader_thread()) - reader.start() - - writer = multiprocessing.Process(target=self._writer_thread(dst_path)) - writer.start() - processor_threads = [ - multiprocessing.Process(target=self._process_thread, args=(i,)) for i in range(self._number_of_processes) - ] - for p in processor_threads: - p.start() - for p in processor_threads: - p.join() - self._stop_processing() - writer.join() - - def _stop_processing(self): - # Stops the processing of samples by putting None in the processed_samples_queue. - self.processed_samples_queue.put(None) - - def _generator_for_tokens_to_get_written(self): - # Generator function that yields batches of processed samples. - - while True: - if self._check_for_parallel_errors(): - return - batch = self.processed_samples_queue.get() - if batch is None: - break - yield batch - - def _check_for_parallel_errors(self) -> bool: - # Checks if there are any errors in the exception buffer. - return bool(self._exception_buffer) - - def _writer_thread(self, dst_path: Path) -> Callable: - # Returns a callable writer function that writes a batch - # received from the processed_samples_queue to the destination file. - - def writer(): - # writes a batch received from the processed_samples_queue to the destination file - def _write_batch( - batch: list[tuple[int, bytes]], prev_line_id: int, curr_offset: int, index_list: list, f: BufferedWriter - ) -> tuple[int, int]: - # write the tokens for each document - for line_id, tokens_as_bytes in batch: - if prev_line_id + 1 != line_id: - raise ValueError( - f"Line IDs are not consecutive. 
Expected {prev_line_id + 1}, but got {line_id}" - ) - f.write(tokens_as_bytes) - segment_length = len(tokens_as_bytes) - index_list.append((curr_offset, segment_length)) - curr_offset += segment_length - prev_line_id = line_id - return prev_line_id, curr_offset - - index_list = [] - with dst_path.open("wb") as f: - # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) - # not possible to prepend header after determining size of data section - f.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) - f.write( - self._token_size_in_bytes.to_bytes( - EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little" - ) - ) - # The offset only applies to the data section, not the header - # When we load the file, we add the header size to the offset - curr_offset = 0 - - # write data section (tokens) - pbar = tqdm(total=len(self._reader), desc="Processed batches") - prev_line_id = -1 - batch_dict = {} - for batch in self._generator_for_tokens_to_get_written(): - line_id = batch[0][0] - batch_dict[line_id] = batch - - while prev_line_id + 1 in batch_dict: - batch = batch_dict.pop(prev_line_id + 1) - prev_line_id, curr_offset = _write_batch(batch, prev_line_id, curr_offset, index_list, f) - pbar.update(len(batch)) - # write index - f.write(pickle.dumps(index_list)) - - self._update_data_length_in_pre_allocated_header(dst_path, index_list) - - return writer - - def _reader_thread(self) -> Callable: - # returns a reader function that reads lines from the reader and puts them into a queue. - def reader(): - batch = [] - for line_id, line in tqdm(enumerate(self._reader), desc="Reading jsonl", disable=True): - batch.append((line_id, line)) - if len(batch) % self.processing_batch_size == 0: - self._raw_samples_queue.put(batch) - batch = [] - - # add the remaining samples - if len(batch) > 0: - self._raw_samples_queue.put(batch) - - for _ in range(self._number_of_processes): - self._raw_samples_queue.put(None) - - return reader - - def _process_thread(self, process_id: int): - # Process the lines in a batch and put the processed samples into the processed_samples_queue. - if self._check_for_parallel_errors(): - return - - while True: - if self._check_for_parallel_errors(): - return - batch = self._raw_samples_queue.get() - if batch is None: - break - - try: - batch_processed = [] - for line_id, line in batch: - processed_line = self._process_line(line, process_id) - batch_processed.append((line_id, processed_line)) - self.processed_samples_queue.put(batch_processed) - except EmptySampleError: - warnings.warn( - f"Encountered empty sample in line {line_id} of file {self.src_path} within process {process_id}" - ) - except Exception as exception: - warnings.warn( - f"Could not process line {line_id} in {self.src_path} within process {process_id}. " - f"Raised the following error: {exception=}" - ) - traceback.print_exc() - - def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: list[tuple[int, int]]): - # Update the length of the data section in the pre-allocated header of the destination file. - # The data segment length is sum of the starting position and the length of the last document. 
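To make that header bookkeeping concrete, here is a small worked sketch (the index entries and the 4-byte token width are hypothetical) of the length that gets patched into the pre-allocated header, mirroring the constants on EmbeddedStreamData:

# Offsets are relative to the data section; the last entry determines its total size.
index_list = [(0, 400), (400, 600), (1000, 250)]
data_section_length = index_list[-1][0] + index_list[-1][1]  # 1250 bytes

# 8 bytes for the data-section length, 4 bytes for the token width, both little-endian.
header = data_section_length.to_bytes(8, byteorder="little") + (4).to_bytes(4, byteorder="little")
assert len(header) == 12  # == EmbeddedStreamData.HEADER_SIZE_IN_BYTES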
- length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] - data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes( - EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little" - ) - with dst_path.open("rb+") as fout: - fout.seek(0) - fout.write(data_section_length_in_bytes) - - def _process_line(self, line: str, process_id: int) -> bytes: - # extracts the text via the jq_filter and applies tokenization to the extract text - jq_retrieved_text = self.jq_filter.input_text(line).first() - if jq_retrieved_text is None: - raise ValueError(f"jq was not able to find anything using the expression: {self.jq_filter}") - tokens = self.tokenizer.tokenize(jq_retrieved_text) - if len(tokens) == 0: - raise EmptySampleError("Received empty sample...") - return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes - - -class EmbeddedStreamData: - # amount of bytes to represent number of all tokens in dataset. - # If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation. - # Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides - DATA_SECTION_LENGTH_IN_BYTES = 8 - TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4 - HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES - - def __init__(self, data_path: Path, load_index: Optional[bool] = True): - """ - Initializes an EmbeddedStreamData object. - - Args: - data_path (Path): The path to the packed data file. - load_index (bool, optional): Whether to load the index. Defaults to True. - - Raises: - FileNotFoundError: If the packed data file is not found at the specified path. - - """ - self._data_path = data_path - if not self._data_path.is_file(): - raise FileNotFoundError( - f"Packed Data was not found at {self._data_path.absolute()}." - f"Create on in advance by using `modalities data pack_encoded_data`." - ) - - with self._data_path.open("rb") as f: - # get number of bytes in data section - data_section_length_in_bytes = f.read(self.DATA_SECTION_LENGTH_IN_BYTES) - self.data_len = int.from_bytes(data_section_length_in_bytes, byteorder="little") - - # get number of bytes for encoding a single token - f.seek(self.DATA_SECTION_LENGTH_IN_BYTES) - token_size_as_bytes = f.read(self.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES) - self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="little", signed=False) - - # get index - if load_index: - f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len) - pkl_encoded_index = f.read() - # contains the start offset and length of each segment - # as byte positions in the data section - self._index_base: list[tuple[int, int]] = pickle.loads(pkl_encoded_index) - else: - self._index_base = None - - # initialize memmapped data section - self._data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) - - @property - def index_base(self) -> list[tuple[int, int]]: - if self._index_base is None: - raise ValueError("Index was not loaded. Set `load_index=True` during initialization.") - return self._index_base - - @property - def data(self) -> np.ndarray: - return self._data - - -def join_embedded_stream_data(stream_data: list[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): - """ - Joins the embedded stream data into a single file. - - Args: - stream_data (list[EmbeddedStreamData]): A list of EmbeddedStreamData objects representing the stream data. 
- target_file (Path): The target file to write the joined data to. - chunk_size (int, optional): The size of each data chunk. Defaults to 2048. - - Raises: - FileExistsError: If the target file already exists. - - Returns: - None - """ - if target_file.exists(): - raise FileExistsError(f'Target File at "{target_file}" exists!') - data_len = sum(d.data_len for d in stream_data) - assert len({d.token_size_in_bytes for d in stream_data}) == 1, ( - "Found different token representation sizes. This could indicate the usage of different tokenizers. " - "Not supported!" - ) - token_size_in_bytes = stream_data[0].token_size_in_bytes - - num_data_chunks = sum(math.ceil(d.data_len / chunk_size) for d in stream_data) - data_stream_generator = (d.data[i : i + chunk_size] for d in stream_data for i in range(0, d.data_len, chunk_size)) - - num_entries = sum(len(d.index_base) for d in stream_data) - - def index_stream_generator() -> Iterator[tuple[int, int]]: - # generates a stream of index offsets and segment lengths. - curr_offset = 0 - for embedded_stream_data in stream_data: - for entry_offset, segment_length in embedded_stream_data.index_base: - yield entry_offset + curr_offset, segment_length - curr_offset += embedded_stream_data.data_len - curr_offset -= embedded_stream_data.HEADER_SIZE_IN_BYTES - - with target_file.open("wb") as fout: - fout.write(data_len.to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) - fout.write( - token_size_in_bytes.to_bytes(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little") - ) - for data_chunk in tqdm(data_stream_generator, total=num_data_chunks, desc="Writing Data Chunks..."): - fout.write(data_chunk) - - joint_index = [entry for entry in tqdm(index_stream_generator(), total=num_entries, desc="Concatenating Index")] - pickled_index = pickle.dumps(joint_index) - pickled_index_as_chunks = (pickled_index[i : i + chunk_size] for i in range(0, len(pickled_index), chunk_size)) - num_index_chunks = math.ceil(len(pickled_index) / chunk_size) - for index_chunk in tqdm(pickled_index_as_chunks, total=num_index_chunks, desc="Writing Index Chunks..."): - fout.write(index_chunk) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 11ba3d2dd..577d459fd 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -5,6 +5,8 @@ from typing import Optional import jq +from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import EmbeddedStreamData +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader import numpy as np from pydantic import BaseModel from torch.utils.data.dataset import Dataset as TorchdataSet @@ -13,8 +15,6 @@ from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper -from ..dataloader.large_file_lines_reader import LargeFileLinesReader -from .create_packed_data import EmbeddedStreamData class Dataset(TorchdataSet): @@ -163,7 +163,7 @@ def __init__( """ super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) - self.reader = LargeFileLinesReader(self.raw_data_path, index_path=index_path) + self.reader = LocalLargeFileLinesReader(self.raw_data_path, index_path=index_path) self.jq_filter = jq.compile(jq_pattern) self.tokenizer = tokenizer diff --git a/src/modalities/dataloader/large_file_lines_reader.py b/src/modalities/dataloader/large_file_lines_reader.py deleted file mode 100644 index 6488cdd1b..000000000 --- 
a/src/modalities/dataloader/large_file_lines_reader.py +++ /dev/null @@ -1,130 +0,0 @@ -import mmap -import pickle -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Optional - - -class BaseReader(ABC): - @abstractmethod - def __len__(self) -> int: - raise NotImplementedError - - @abstractmethod - def __getitem__(self, key: int) -> str | list[str]: - raise NotImplementedError - - -class LargeFileLinesReader(BaseReader): - """LargeFileLinesReader class that read lines from a large file efficiently.""" - - def __init__( - self, - raw_data_path: Path, - index_path: Optional[Path] = None, - encoding: Optional[str] = "utf-8", - use_sample_length_from_index: bool = True, - ): - """ - Initializes a LargeFileLinesReader object. - - Args: - raw_data_path (Path): Path to a jsonl file, which holds text data. - index_path (Optional[Path]): Path to an index file, which indicates the start character/byte position - and length of samples given in `raw_data_path`. - If not defined, an index next to `raw_data_path` is picked, - by replacing its suffix with ".idx". - encoding (Optional[str]): The encoding of the file (default: "utf-8"). - If encoding is None, the raw data is read as bytes. - use_sample_length_from_index (bool): If True, the sample length is taken from the index file - i.e., the (offset, sample_length) pairs. If False, the sample length is calculated - as the difference between the starting point of the next and the current sample. - Returns: - None - """ - self.encoding = encoding - self.raw_data_path = raw_data_path - self.index_path = self.default_index_path(self.raw_data_path, index_path) - self.use_sample_length_from_index = use_sample_length_from_index - - if not self.raw_data_path.is_file(): - raise FileNotFoundError("Raw data file does not exist") - if not self.index_path.is_file(): - raise FileNotFoundError("Index file does not exist. Use `modalities data create_raw_index` to create one.") - - with self.index_path.open("rb") as f: - self.index = pickle.load(f) - - self.raw_data_fd = self.raw_data_path.open("rb") - self.mmapped_data_file = mmap.mmap(self.raw_data_fd.fileno(), 0, access=mmap.ACCESS_READ) - - def close(self): - self.mmapped_data_file.close() - self.raw_data_fd.close() - - @staticmethod - def default_index_path(raw_data_path: Path, index_path: Optional[Path] = None) -> Path: - """ - Returns the default index path for the given raw data path. - - Args: - raw_data_path (Path): The path to the raw data file. - index_path (Optional[Path]): The path to the index file (default: None). - - Returns: - Path: The default index path. - - Note: - If `index_path` is not provided, the default index path is generated by - appending the extension ".idx" to the stem of the `raw_data_path`. - """ - if index_path is None: - default_index_path = Path(raw_data_path.parent, f"{raw_data_path.stem}.idx") - print(f"No specific Index Path provided. Pointing to index next to input data at: {default_index_path}") - return default_index_path - return index_path - - def __len__(self) -> int: - """ - Returns the length of the index. - - Returns: - int: The length of the index. - """ - return len(self.index) - - def __getitem__(self, key: int) -> str | bytes: - """ - Retrieves an item from the LargeFileLinesReader. - - Args: - key (int): The index used to retrieve the item. - - Returns: - str | bytes: The item retrieved from the LargeFileLinesReader. - - Raises: - IndexError: If the key is out of range. 
- - """ - - offset, sample_length_in_bytes = self.index[key] - - # If use_sample_length_from_index = False, we calculate the sample length as the difference between the - # starting point of the next and the current sample. - # This allows for reading in the entire sample including the newline character. - if not self.use_sample_length_from_index: - if key + 1 < len(self.index): - sample_length_in_bytes = self.index[key + 1][0] - self.index[key][0] - else: - sample_length_in_bytes = len(self.mmapped_data_file) - offset - - return self._read_from_raw_file(offset, sample_length_in_bytes) - - def _read_from_raw_file(self, offset: int, sample_length_in_bytes: int) -> str | bytes: - # Reads a specified number of bytes from a raw file starting from a given offset. - data = self.mmapped_data_file[offset : offset + sample_length_in_bytes] - if self.encoding is not None: - data_decoded = data.decode(self.encoding) - return data_decoded - return data diff --git a/src/modalities/dataloader/preprocessing/__init__.py b/src/modalities/dataloader/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/dataloader/preprocessing/indexation/__init__.py b/src/modalities/dataloader/preprocessing/indexation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/dataloader/create_index.py b/src/modalities/dataloader/preprocessing/indexation/create_index.py similarity index 62% rename from src/modalities/dataloader/create_index.py rename to src/modalities/dataloader/preprocessing/indexation/create_index.py index 55e573e15..17266b91c 100644 --- a/src/modalities/dataloader/create_index.py +++ b/src/modalities/dataloader/preprocessing/indexation/create_index.py @@ -3,14 +3,16 @@ import pickle as pkl import queue import threading -import warnings +import time from pathlib import Path +import jq +from modalities.utils.logging import get_logger from tqdm import tqdm class IndexGenerator: - def __init__(self, src_file: Path, drop_faulty_entries: bool = False): + def __init__(self, src_file: Path, drop_faulty_entries: bool = False, jq_pattern: str = ".text"): """ Initializes an IndexGenerator object. Reads a JSONL file as a binary file, and iterates through it character by character. @@ -26,6 +28,7 @@ def __init__(self, src_file: Path, drop_faulty_entries: bool = False): None """ self.src_file = src_file + self.jq_pattern = jq_pattern self.drop_faulty_entries = drop_faulty_entries with self.src_file.open(mode="rb") as fin: # Move the cursor to the end of the file @@ -50,6 +53,7 @@ def create_index(self, target_path_for_index_file: Path): Returns: None """ + start_time = time.time() self._exception_buffer = [] reader = threading.Thread(target=self._reader_thread) reader.start() @@ -58,9 +62,21 @@ def create_index(self, target_path_for_index_file: Path): reader.join() processor.join() if self._exception_buffer: + get_logger(name="main").warning( + f"Index creation failed for {target_path_for_index_file}. Exception buffer: {self._exception_buffer}" + ) raise self._exception_buffer[0] - print(f"Created index of length {len(self._index_map)}") - target_path_for_index_file.write_bytes(pkl.dumps(self._index_map)) + + if len(self._index_map) == 0: + get_logger(name="main").warning(f"Could not create index! 
No entries found in {self.src_file}") + else: + end_time = time.time() + get_logger(name="main").info( + f"Created index {target_path_for_index_file} of length {len(self._index_map)} " + f"at {len(self._index_map) / (end_time - start_time)} iterations/s." + ) + target_path_for_index_file.write_bytes(pkl.dumps(self._index_map)) + get_logger(name="main").info(f"Wrote index {target_path_for_index_file} to disc.") def _indexer_thread(self): # This method is responsible for indexing the lines in the queue and parsing them as JSON. @@ -78,33 +94,40 @@ def queue_generator(): break yield line - def parse_line_as_json(line_start_idx: int, line: str): + def parse_line_as_json(line_id: int, line_start_byte_pos: int, line: bytes, jq_filter): # Parses a line as JSON and appends the sample index, i.e., # the line start index and length to the index map. # If the line is faulty and `drop_faulty_entries` is set to True, a warning is issued. - try: # check if line is a valid json - json.loads(line) - self._index_map.append((line_start_idx, len(line))) - except Exception as low_level_err: + line_string = line.decode("utf-8") + jq_retrieved_text = jq_filter.input_text(line_string).first() + if jq_retrieved_text is not None: + if len(jq_retrieved_text) > 0: + self._index_map.append((line_start_byte_pos, len(line))) + else: + get_logger(name="main").warning(f'Faulty line {line_id} (no text) in {str(self.src_file)}, skipping...') + else: if self.drop_faulty_entries: - warnings.warn(f'faulty line "{line}", skipping...') + get_logger(name="main").warning(f'Faulty line {line_id} (parsing error) in {str(self.src_file)}, skipping...') else: - err = ValueError(f'faulty line "{line}", skipping...') - err.__cause__ = low_level_err + get_logger(name="main").warning(f'Faulty line {line_id} (parsing error), stopping...') + err = ValueError(f'Faulty line "{line} in {str(self.src_file)}') self._exception_buffer.append(err) + jq_filter = jq.compile(self.jq_pattern) self._index_map = [] - for line_start_idx, line in tqdm(queue_generator(), desc="Processed Lines"): + for line_id, line_start_byte_pos, line in tqdm(queue_generator(), desc="Processed Lines", disable=True): if self._check_for_parallel_errors(): return - parse_line_as_json(line_start_idx, line) + parse_line_as_json(line_id, line_start_byte_pos, line, jq_filter) def _reader_thread(self): # Reads lines from the source file and puts them into a queue. # This method is executed in a separate thread. It reads lines from the source file until # the end of the file is reached. Each line is put into a queue along with its cursor position. If any # errors are detected, the method returns immediately. 
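As a usage sketch of the reworked IndexGenerator (file paths are hypothetical; jq_pattern defaults to ".text", matching the jq-based validation added above):

from pathlib import Path

from modalities.dataloader.preprocessing.indexation.create_index import IndexGenerator

generator = IndexGenerator(Path("data/train.jsonl"), drop_faulty_entries=True)
generator.create_index(Path("data/train.idx"))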
+ get_logger(name="main").info(f"Reading the jsonl file {self.src_file}...") + num_read_documents = 0 with open(self.src_file, "rb") as fin: while True: cursor = fin.tell() @@ -114,10 +137,13 @@ def _reader_thread(self): if fin.tell() == self._total_num_bytes: if line.endswith(b"\n"): line = line[:-1] - self._queue_of_raw_lines.put((cursor, line)) + self._queue_of_raw_lines.put((num_read_documents, cursor, line)) + num_read_documents += 1 break line_without_newline_char = line[:-1] - self._queue_of_raw_lines.put((cursor, line_without_newline_char)) + self._queue_of_raw_lines.put((num_read_documents, cursor, line_without_newline_char)) + num_read_documents += 1 + get_logger(name="main").info(f"Finished reading the jsonl file {self.src_file} (read {num_read_documents}).") self._queue_of_raw_lines.put(None) def _check_for_parallel_errors(self) -> bool: diff --git a/src/modalities/dataloader/preprocessing/tokenization/__init__.py b/src/modalities/dataloader/preprocessing/tokenization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/dataloader/preprocessing/tokenization/create_packed_data.py b/src/modalities/dataloader/preprocessing/tokenization/create_packed_data.py new file mode 100644 index 000000000..79c8f3401 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/create_packed_data.py @@ -0,0 +1,113 @@ +import multiprocessing as mp +import time + + +from modalities.dataloader.preprocessing.tokenization.tokenization_processes import ( + ProgressLoggingWorker, + ReaderWorker, + TokenizerWorker, + WriterWorker, +) +from modalities.utils.env_variables import temporary_env_var +from modalities.utils.logging import get_logger +import tqdm +import time + + + + +class PackedDataGenerator: + """Reads in a JSONL file and the corresponding index file and packs the dataset for LLM training.""" + + def __init__( + self, + reader_workers: list[ReaderWorker], + tokenizer_workers: list[TokenizerWorker], + writer_worker: WriterWorker, + progress_logging_worker: ProgressLoggingWorker, + reader_q: mp.Queue, + tokenizer_q: mp.Queue, + writer_q: mp.Queue, + logging_message_q: mp.Queue, + index_start: int, + num_samples: int, + batch_size: int, + ): + self.reader_workers = reader_workers + self.tokenizer_workers = tokenizer_workers + self.writer_worker = writer_worker + self.progress_logging_worker = progress_logging_worker + self.reader_q = reader_q + self.tokenizer_q = tokenizer_q + self.writer_q = writer_q + self.logging_message_q = logging_message_q + self._index_start = index_start + self._num_samples = num_samples + self.batch_size = batch_size + self._exception_buffer = [] + + if num_samples == -1: + # TODO accessing the reader directly is not nice, but we need to know the total number of samples + total_num_samples = len(self.reader_workers[0]._reader) + num_samples = total_num_samples - index_start + + def run(self): + # Not setting TOKENIZERS_PARALLELISM to false can cause deadlocks when using hf's "FastTokenizers". 
See also: + # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/67254879#67254879 + with temporary_env_var("TOKENIZERS_PARALLELISM", "false"): + start_time = time.time() + # populate the reader queue with the sample_ids that we want to tokenize + self._populate_reader_q( + index_start=self._index_start, + num_samples=self._num_samples, + num_reader_processes=len(self.reader_workers), + ) + + # start the progress logging worker + self.progress_logging_worker.start() + + # start the reader proceseses + for reader_worker in tqdm.tqdm(self.reader_workers, desc="Starting reader workers"): + reader_worker.start() + + # start the tokenizer processes + for tokenizer_worker in tqdm.tqdm(self.tokenizer_workers, desc="Starting tokenizer workers"): + tokenizer_worker.start() + + # start the writer process + self.writer_worker.start() + + # wait for all processes to finish + for reader_worker in tqdm.tqdm(self.reader_workers, desc="Stopping for reader workers"): + reader_worker.join() + + # stop the tokenizer processes + for _ in self.tokenizer_workers: + self.tokenizer_q.put(None) + for tokenizer_worker in tqdm.tqdm(self.tokenizer_workers, desc="Stopping tokenizer workers"): + tokenizer_worker.join() + + # stop the writer process + get_logger().info("Stopping writer worker.") + self.writer_q.put(None) + self.writer_worker.join() + + # stop the logging worker process + get_logger().info("Stopping progress logging worker.") + self.logging_message_q.put(None) + self.progress_logging_worker.join() + + end_time = time.time() + get_logger().info(f"Tokenization took {end_time - start_time} seconds.") + + if self._exception_buffer: + raise self._exception_buffer[0] + + + def _populate_reader_q(self, index_start: int, num_samples: int, num_reader_processes: int): + # populate the reader queue with the line_ids that we want to tokenize + + for i in tqdm.tqdm(range(index_start, index_start + num_samples, self.batch_size), desc="Filling up reader queue with line ids"): + self.reader_q.put((i, self.batch_size)) + for i in range(num_reader_processes): + self.reader_q.put(None) diff --git a/src/modalities/dataloader/preprocessing/tokenization/embedded_stream_data.py b/src/modalities/dataloader/preprocessing/tokenization/embedded_stream_data.py new file mode 100644 index 000000000..c7bb74d12 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/embedded_stream_data.py @@ -0,0 +1,123 @@ +import math +import pickle +from pathlib import Path +from typing import Iterator, Optional + + +import numpy as np +from tqdm import tqdm + + +class EmbeddedStreamData: + # amount of bytes to represent number of all tokens in dataset. + # If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation. + # Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides + DATA_SECTION_LENGTH_IN_BYTES = 8 + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4 + HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES + + def __init__(self, data_path: Path, load_index: Optional[bool] = True): + """ + Initializes an EmbeddedStreamData object. + + Args: + data_path (Path): The path to the packed data file. + load_index (bool, optional): Whether to load the index. Defaults to True. + + Raises: + FileNotFoundError: If the packed data file is not found at the specified path. 
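A short sketch of how such a packed file can be inspected afterwards (the .pbin path is hypothetical; the attributes are the ones defined in this class):

from pathlib import Path

from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import EmbeddedStreamData

stream = EmbeddedStreamData(Path("data/train.pbin"))
offset, length = stream.index_base[0]                          # byte range of the first document
first_doc_token_bytes = stream.data[offset : offset + length]  # memory-mapped slice of the data section
print(stream.token_size_in_bytes, stream.data_len)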
+ + """ + self._data_path = data_path + if not self._data_path.is_file(): + raise FileNotFoundError( + f"Packed Data was not found at {self._data_path.absolute()}." + f"Create on in advance by using `modalities data pack_encoded_data`." + ) + + with self._data_path.open("rb") as f: + # get number of bytes in data section + data_section_length_in_bytes = f.read(self.DATA_SECTION_LENGTH_IN_BYTES) + self.data_len = int.from_bytes(data_section_length_in_bytes, byteorder="little") + + # get number of bytes for encoding a single token + f.seek(self.DATA_SECTION_LENGTH_IN_BYTES) + token_size_as_bytes = f.read(self.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES) + self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="little", signed=False) + + # get index + if load_index: + f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len) + pkl_encoded_index = f.read() + # contains the start offset and length of each segment + # as byte positions in the data section + self._index_base: list[tuple[int, int]] = pickle.loads(pkl_encoded_index) + else: + self._index_base = None + + # initialize memmapped data section + self._data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) + + @property + def index_base(self) -> list[tuple[int, int]]: + if self._index_base is None: + raise ValueError("Index was not loaded. Set `load_index=True` during initialization.") + return self._index_base + + @property + def data(self) -> np.ndarray: + return self._data + + +def join_embedded_stream_data(stream_data: list[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): + """ + Joins the embedded stream data into a single file. + + Args: + stream_data (list[EmbeddedStreamData]): A list of EmbeddedStreamData objects representing the stream data. + target_file (Path): The target file to write the joined data to. + chunk_size (int, optional): The size of each data chunk. Defaults to 2048. + + Raises: + FileExistsError: If the target file already exists. + + Returns: + None + """ + if target_file.exists(): + raise FileExistsError(f'Target File at "{target_file}" exists!') + data_len = sum(d.data_len for d in stream_data) + assert len({d.token_size_in_bytes for d in stream_data}) == 1, ( + "Found different token representation sizes. This could indicate the usage of different tokenizers. " + "Not supported!" + ) + token_size_in_bytes = stream_data[0].token_size_in_bytes + + num_data_chunks = sum(math.ceil(d.data_len / chunk_size) for d in stream_data) + data_stream_generator = (d.data[i : i + chunk_size] for d in stream_data for i in range(0, d.data_len, chunk_size)) + + num_entries = sum(len(d.index_base) for d in stream_data) + + def index_stream_generator() -> Iterator[tuple[int, int]]: + # generates a stream of index offsets and segment lengths. 
+ curr_offset = 0 + for embedded_stream_data in stream_data: + for entry_offset, segment_length in embedded_stream_data.index_base: + yield entry_offset + curr_offset, segment_length + curr_offset += embedded_stream_data.data_len + curr_offset -= embedded_stream_data.HEADER_SIZE_IN_BYTES + + with target_file.open("wb") as fout: + fout.write(data_len.to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) + fout.write( + token_size_in_bytes.to_bytes(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little") + ) + for data_chunk in tqdm(data_stream_generator, total=num_data_chunks, desc="Writing Data Chunks..."): + fout.write(data_chunk) + + joint_index = [entry for entry in tqdm(index_stream_generator(), total=num_entries, desc="Concatenating Index")] + pickled_index = pickle.dumps(joint_index) + pickled_index_as_chunks = (pickled_index[i : i + chunk_size] for i in range(0, len(pickled_index), chunk_size)) + num_index_chunks = math.ceil(len(pickled_index) / chunk_size) + for index_chunk in tqdm(pickled_index_as_chunks, total=num_index_chunks, desc="Writing Index Chunks..."): + fout.write(index_chunk) diff --git a/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py new file mode 100644 index 000000000..10bec2881 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py @@ -0,0 +1,291 @@ +from dataclasses import dataclass +import mmap +import pickle +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional +from modalities.exceptions import ReaderIndexationError +import numpy as np +from enum import Enum + +@dataclass +class Sample: + # If the index is not shuffled, then the incrementeal_line_id + # points to the position in the dataset + # If the index is shuffled, then the incremental_line_id + # points to the position in the shuffled index and the + # shuffled_line_id points to the position in the original index + incremental_line_id: int + raw_data_path: Path + offset: int + sample_length_in_bytes: int + content_raw: str | bytes + content_tokenized: Optional[bytes] = None + shuffled_line_id: Optional[int] = None + + +class BaseReader(ABC): + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError + + @abstractmethod + def __getitem__(self, key: int) -> Sample: + raise NotImplementedError + + +class LocalLargeFileLinesReader(BaseReader): + """LargeFileLinesReader class that read lines from a large file efficiently.""" + + def __init__( + self, + raw_data_path: Path, + index_path: Optional[Path] = None, + encoding: Optional[str] = "utf-8", + use_sample_length_from_index: bool = True, + ): + """ + Initializes a LargeFileLinesReader object. + + Args: + raw_data_path (Path): Path to a jsonl file, which holds text data. + index_path (Optional[Path]): Path to an index file, which indicates the start character/byte position + and length of samples given in `raw_data_path`. + If not defined, an index next to `raw_data_path` is picked, + by replacing its suffix with ".idx". + encoding (Optional[str]): The encoding of the file (default: "utf-8"). + If encoding is None, the raw data is read as bytes. + use_sample_length_from_index (bool): If True, the sample length is taken from the index file + i.e., the (offset, sample_length) pairs. If False, the sample length is calculated + as the difference between the starting point of the next and the current sample. 
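A usage sketch for the renamed LocalLargeFileLinesReader (paths are hypothetical; the index is the one produced by `modalities data create_raw_index`). Instead of a plain string, __getitem__ now returns the Sample dataclass defined above:

from pathlib import Path

from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader

reader = LocalLargeFileLinesReader(Path("data/train.jsonl"), index_path=Path("data/train.idx"))
sample = reader[0]
print(sample.incremental_line_id, sample.offset, sample.sample_length_in_bytes)
print(sample.content_raw[:80])
reader.close()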
+ Returns: + None + """ + self.encoding = encoding + self.raw_data_path = raw_data_path + self.index_path = self.default_index_path(self.raw_data_path, index_path) + self.use_sample_length_from_index = use_sample_length_from_index + + if not self.raw_data_path.is_file(): + raise FileNotFoundError("Raw data file does not exist") + if not self.index_path.is_file(): + raise FileNotFoundError("Index file does not exist. Use `modalities data create_raw_index` to create one.") + + with self.index_path.open("rb") as f: + self.index = pickle.load(f) + + self.raw_data_fd = self.raw_data_path.open("rb") + self.mmapped_data_file = mmap.mmap(self.raw_data_fd.fileno(), 0, access=mmap.ACCESS_READ) + + def close(self): + self.mmapped_data_file.close() + self.raw_data_fd.close() + + @staticmethod + def default_index_path(raw_data_path: Path, index_path: Optional[Path] = None) -> Path: + """ + Returns the default index path for the given raw data path. + + Args: + raw_data_path (Path): The path to the raw data file. + index_path (Optional[Path]): The path to the index file (default: None). + + Returns: + Path: The default index path. + + Note: + If `index_path` is not provided, the default index path is generated by + appending the extension ".idx" to the stem of the `raw_data_path`. + """ + if index_path is None: + default_index_path = Path(raw_data_path.parent, f"{raw_data_path.stem}.idx") + print(f"No specific Index Path provided. Pointing to index next to input data at: {default_index_path}") + return default_index_path + return index_path + + def __len__(self) -> int: + """ + Returns the length of the index. + + Returns: + int: The length of the index. + """ + return len(self.index) + + def __getitem__(self, key: int) -> Sample: + """ + Retrieves an item from the LargeFileLinesReader. + + Args: + key (int): The index used to retrieve the item. + + Returns: + Sample: The item retrieved from the LargeFileLinesReader. + + Raises: + IndexError: If the key is out of range. + + """ + + offset, sample_length_in_bytes = self.index[key] + + # If use_sample_length_from_index = False, we calculate the sample length as the difference between the + # starting point of the next and the current sample. + # This allows for reading in the entire sample including the newline character. + if not self.use_sample_length_from_index: + if key + 1 < len(self.index): + sample_length_in_bytes = self.index[key + 1][0] - self.index[key][0] + else: + sample_length_in_bytes = len(self.mmapped_data_file) - offset + + content = self._read_from_raw_file(offset, sample_length_in_bytes) + return Sample( + raw_data_path=self.raw_data_path, + incremental_line_id=key, + shuffled_line_id=key, # TODO so far we don't support shuffling here! + offset=offset, + sample_length_in_bytes=sample_length_in_bytes, + content_raw=content, + ) + + def _read_from_raw_file(self, offset: int, sample_length_in_bytes: int) -> str | bytes: + # Reads a specified number of bytes from a raw file starting from a given offset. 
+ data = self.mmapped_data_file[offset : offset + sample_length_in_bytes] + if self.encoding is not None: + data_decoded = data.decode(self.encoding) + return data_decoded + return data + + +class GlobalLargeFileLinesReader(BaseReader): + """LargeFileLinesReader class that read lines from a large file efficiently.""" + + def __init__( + self, + global_inorder_index_path: Path, + raw_data_file_list_path: Path, + raw_data_root_path: Path, + global_shuffle_index_path: Optional[Path] = None, + encoding: Optional[str] = "utf-8", + ): + self.global_inorder_index_path = global_inorder_index_path + self.raw_data_file_list_path = raw_data_file_list_path + self.raw_data_root_path = raw_data_root_path + self.global_shuffle_index_path = global_shuffle_index_path + self.encoding = encoding + + # create the raw data file path list (the JSONL files) + # the file paths are relative to the raw_data_root_path + with open(self.raw_data_file_list_path, "r", encoding="utf-8") as f: + self.relative_raw_data_file_paths = [line.strip() for line in f.readlines()] + + self.relative_to_absolute_raw_data_file_paths = { + rel_file_path: raw_data_root_path / rel_file_path for rel_file_path in self.relative_raw_data_file_paths + } + + # open memmap / index files + num_rows, _, _ = np.memmap(self.global_inorder_index_path, dtype="int64", mode="r")[0:3] + + self.global_index_inorder = np.memmap( + self.global_inorder_index_path, dtype="int64", mode="r", shape=(num_rows, 3) + ) + if self.global_shuffle_index_path is not None: + self.global_shuffle_index = np.memmap(self.global_shuffle_index_path, dtype="int64", mode="r") + else: + self.global_shuffle_index = None + # the 0th element in the global_index_inorder contains the meta data (num_rows, num_cols, is_shuffled) + # therefore we have to skip the first element when iterating. + # Note, when we iterate via the global_shuffle_index, we don't have to do this, + # as the the shuffled index does not contain index 0. + self.global_index_inorder = self.global_index_inorder[1:] + + def close(self): + pass + + def __len__(self) -> int: + """ + Returns the length of the index. + + Returns: + int: The length of the index. + """ + if self.global_shuffle_index is not None: + return len(self.global_shuffle_index) + else: + return len(self.global_index_inorder) + + def __getitem__(self, key: int) -> Sample: + """ + Retrieves an item from the LargeFileLinesReader. + + Args: + key (int): The index used to retrieve the item. + + Returns: + Sample: The item retrieved from the LargeFileLinesReader. + + Raises: + IndexError: If the key is out of range. 
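+
+        Note:
+            If a global shuffle index is configured, `key` is first mapped through it to obtain the
+            position in the in-order index; otherwise `key` indexes the in-order index directly.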
+ + """ + try: + if self.global_shuffle_index is not None: + mapped_key = self.global_shuffle_index[key] + else: + mapped_key = key + file_index, offset, sample_length_in_bytes = self.global_index_inorder[mapped_key] + rel_file_path = self.relative_raw_data_file_paths[file_index] + abs_raw_file_path = self.relative_to_absolute_raw_data_file_paths[rel_file_path] + except Exception as e: + raise ReaderIndexationError(f"Error while reading sample with key {key}: {e}") from e + + with open(abs_raw_file_path, "rb") as fd: + raw_data_mmap = mmap.mmap(fd.fileno(), 0, access=mmap.ACCESS_READ) + content = raw_data_mmap[offset : offset + sample_length_in_bytes] + if self.encoding is not None: + content = content.decode(self.encoding) + return Sample( + incremental_line_id=key, + shuffled_line_id=mapped_key, + raw_data_path=abs_raw_file_path, + offset=offset, + sample_length_in_bytes=sample_length_in_bytes, + content_raw=content, + ) + + +class LargeFileLinesReaderTypes(Enum): + LOCAL = "LOCAL" + GLOBAL = "GLOBAL" + + +class LargeFileLinesReaderFactory: + + @staticmethod + def get_local_reader( + raw_data_path: Path, + index_path: Optional[Path] = None, + encoding: Optional[str] = "utf-8", + ) -> LocalLargeFileLinesReader: + return LocalLargeFileLinesReader( + raw_data_path=raw_data_path, + index_path=index_path, + encoding=encoding, + use_sample_length_from_index=True, + ) + + @staticmethod + def get_global_reader( + global_inorder_index_path: Path, + raw_data_file_list_path: Path, + raw_data_root_path: Path, + global_shuffle_index_path: Optional[Path] = None, + encoding: Optional[str] = "utf-8", + ) -> GlobalLargeFileLinesReader: + return GlobalLargeFileLinesReader( + global_inorder_index_path=global_inorder_index_path, + raw_data_file_list_path=raw_data_file_list_path, + raw_data_root_path=raw_data_root_path, + global_shuffle_index_path=global_shuffle_index_path, + encoding=encoding, + ) diff --git a/src/modalities/dataloader/preprocessing/tokenization/tokenization_processes.py b/src/modalities/dataloader/preprocessing/tokenization/tokenization_processes.py new file mode 100644 index 000000000..8123cd6df --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/tokenization_processes.py @@ -0,0 +1,456 @@ +from dataclasses import dataclass +from enum import Enum +import math +import multiprocessing as mp +import os +import pickle +import time +import traceback +from typing import Any, Callable, Type +import warnings +from io import BufferedWriter +from pathlib import Path +from multiprocessing.synchronize import Event +from data_quality_ablations.utils.logging import get_logger +import jq +from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel +from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import EmbeddedStreamData +from modalities.exceptions import EmptySampleError +from pydantic import BaseModel +from tqdm import tqdm +import queue + +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import ( + BaseReader, + LargeFileLinesReaderFactory, + LargeFileLinesReaderTypes, + Sample, +) +from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper + + +def get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: + """ + Calculates the required number of bytes to represent an integer. + + Args: + int_to_get_repr (int): The integer to get the representation for. + + Returns: + int: The number of bytes required to represent the integer. 
+ """ + # we currently only support token sizes of 1, 2 and 4 bytes, as implemented here: + # https://github.com/Modalities/modalities/blob/fix_char_bytes_indexation_mismatch/src/modalities/dataloader/dataset.py#L202 + num_bytes = math.ceil(math.log2(int_to_get_repr) / 8) + if num_bytes == 1: + return 1 + elif num_bytes == 2: + return 2 + elif num_bytes <= 4: + return 4 + else: + raise ValueError("Currently only support token byte sizes of 1, 2, and 4.") + + +class ReaderWorker(mp.Process): + def __init__( + self, + reader_type: Type[BaseReader], + reader_args: BaseModel, + reader_q: mp.Queue, + tokenizer_q: mp.Queue, + logging_message_q: mp.Queue, + process_id: int, + stop_event: Event, + ): + super().__init__() + self._reader_q = reader_q + self._tokenizer_q = tokenizer_q + self._logging_message_q = logging_message_q + self._reader_type = reader_type + self._reader_args = reader_args + self._stop_event = stop_event + self.process_id = process_id + + def run(self): + reader = self._reader_type(**self._reader_args.model_dump()) + batch = [] + num_samples_read = 0 + while not self._stop_event.is_set(): + try: + # we set the timout here, such that the worker can check if the stop_event is set + item = self._reader_q.get(timeout=3) + except queue.Empty: + continue + if item is None: + print(f"Reading worker with pid {mp.current_process().pid} exiting, Read {num_samples_read} samples") + break + sample_id, batch_size = item + + + batch: list[Sample] = [reader[sample_id + i] for i in range(batch_size)] + self._tokenizer_q.put(batch) + self._logging_message_q.put(ProgressMessage(WorkerTypes.READER, self.process_id, len(batch))) + num_samples_read += len(batch) + + if not self._stop_event.is_set(): + # add the remaining samples + if len(batch) > 0: + self._tokenizer_q.put(batch) + self._logging_message_q.put(ProgressMessage(WorkerTypes.READER, self.process_id, len(batch))) + + + +class TokenizerWorker(mp.Process): + def __init__( + self, + tokenizer: TokenizerWrapper, + eod_token: str, + token_size_in_bytes: int, + tokenizer_q: mp.Queue, + logging_message_q: mp.Queue, + writer_q: mp.Queue, + jq_pattern: str, + process_id: int, + stop_event: Event, + ): + super().__init__() + self._jq_filter = jq.compile(jq_pattern) + self.tokenizer = tokenizer + self.eod_token = eod_token + self._token_size_in_bytes = token_size_in_bytes + encoded_eod_token = self.tokenizer.get_token_id(self.eod_token) + self._encoded_eos_token_as_bytes = self._encoded_token_to_bytes(encoded_eod_token) + self._tokenizer_q = tokenizer_q + self._writer_q = writer_q + self._logging_message_q = logging_message_q + self._process_id = process_id + self._stop_event = stop_event + + def run(self): + # Process the lines in a batch and put the processed samples into the writer_q. 
+ + while not self._stop_event.is_set(): + try: + batch: list[Sample] = self._tokenizer_q.get(timeout=10) + except queue.Empty: + continue + if batch is None: + break + + try: + batch_processed = [] + for sample in batch: + processed_line = self._process_line(sample.content_raw) + sample.content_tokenized = processed_line + batch_processed.append(sample) + self._writer_q.put(batch_processed) + self._logging_message_q.put(ProgressMessage(WorkerTypes.TOKENIZER, self._process_id, len(batch))) + except EmptySampleError: + warnings.warn( + f"Encountered empty sample in line {sample.shuffled_line_id} in file {sample.raw_data_path} within process {self._process_id}" + ) + except Exception as exception: + warnings.warn( + f"Could not process line {sample.shuffled_line_id} in file {sample.raw_data_path} within process {self._process_id}. " + f"Raised the following error: {exception=}" + ) + traceback.print_exc() + + def _process_line(self, line: str) -> bytes: + # extracts the text via the jq_filter and applies tokenization to the extract text + jq_retrieved_text = self._jq_filter.input_text(line).first() + if jq_retrieved_text is None: + raise ValueError(f"jq was not able to find anything using the expression: {self._jq_filter}") + tokens = self.tokenizer.tokenize(jq_retrieved_text) + if len(tokens) == 0: + raise EmptySampleError("Received empty sample...") + return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes + + def _encoded_token_to_bytes(self, encoded_token: int) -> bytes: + """ + Converts an encoded token to its byte representaion. + + Args: + encoded_token (int): The encoded token to be converted. + + Returns: + bytes: The byte representation of the token. + + """ + return encoded_token.to_bytes(self._token_size_in_bytes, byteorder="little", signed=False) + + +class WriterWorker(mp.Process): + def __init__( + self, token_size_in_bytes: int, writer_q: mp.Queue, logging_message_q: mp.Queue, dst_path: Path, stop_event: Event, index_start: int, + process_id: int + ): + super().__init__() + self._token_size_in_bytes = token_size_in_bytes + self._dst_path = dst_path + self._writer_q = writer_q + self._logging_message_q = logging_message_q + self._stop_event = stop_event + self._index_start = index_start + self.process_id = process_id + + def run(self): + index_list = [] + if not self._dst_path.parent.exists(): + self._dst_path.parent.mkdir(parents=True, exist_ok=True) + with self._dst_path.open("wb") as f: + # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) + # not possible to prepend header after determining size of data section + f.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) + f.write( + self._token_size_in_bytes.to_bytes( + EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little" + ) + ) + # The offset only applies to the data section, not the header + # When we load the file, we add the header size to the offset + curr_offset = 0 + + # write data section (tokens) + prev_line_id = self._index_start - 1 + batch_dict = {} + while not self._stop_event.is_set(): + try: + batch: list[Sample] = self._writer_q.get(timeout=3) + except queue.Empty: + continue + if batch is None: + break + line_id = batch[0].incremental_line_id + batch_dict[line_id] = batch + + while prev_line_id + 1 in batch_dict: + batch = batch_dict.pop(prev_line_id + 1) + prev_line_id, curr_offset = WriterWorker._write_batch( + batch, prev_line_id, curr_offset, index_list, f + ) + 
self._logging_message_q.put(ProgressMessage(WorkerTypes.WRITER, self.process_id, len(batch))) + + # write index + f.write(pickle.dumps(index_list)) + if not self._stop_event.is_set() and len(index_list) > 0 and len(batch_dict) == 0: + self._update_data_length_in_pre_allocated_header(self._dst_path, index_list) + else: + # if the process was stopped due to a stop event or the index list is empty, we remove the file + get_logger(name="main").warning(f"Removing file {self._dst_path} due to empty index list or stop event or non-empty batch_dict. " + f"stop_event: {self._stop_event.is_set()}, index_list: {len(index_list)}, batch_dict: {batch_dict.keys()}") + os.remove(self._dst_path) + + # writes a batch received from the writer_q to the destination file + def _write_batch( + batch: list[Sample], prev_line_id: int, curr_offset: int, index_list: list, f: BufferedWriter + ) -> tuple[int, int]: + # write the tokens for each document + for sample in batch: + if prev_line_id + 1 != sample.incremental_line_id: + raise ValueError( + f"Line IDs are not consecutive. Expected {prev_line_id + 1}, but got {sample.incremental_line_id}" + ) + f.write(sample.content_tokenized) + segment_length = len(sample.content_tokenized) + index_list.append((curr_offset, segment_length)) + curr_offset += segment_length + prev_line_id = sample.incremental_line_id + return prev_line_id, curr_offset + + @staticmethod + def _update_data_length_in_pre_allocated_header(dst_path: Path, index_list: list[tuple[int, int]]): + # Update the length of the data section in the pre-allocated header of the destination file. + # The data segment length is sum of the starting position and the length of the last document. + length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] + data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes( + EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little" + ) + with dst_path.open("rb+") as fout: + fout.seek(0) + fout.write(data_section_length_in_bytes) + + +class WorkerTypes(Enum): + READER = "READER" + TOKENIZER = "TOKENIZER" + WRITER = "WRITER" + +@dataclass +class ProgressMessage: + worker_type: WorkerTypes + process_id: int + num_samples: int + + +class ProgressLoggingWorker(mp.Process): + def __init__(self, logging_message_q: mp.Queue, logging_interval: int, reader_q: mp.Queue, tokenizer_q: mp.Queue, writer_q: mp.Queue, total_num_samples: int, stop_event: Event): + super().__init__() + self._logging_message_q = logging_message_q + self._logging_interval = logging_interval + self._reader_q = reader_q + self._tokenizer_q = tokenizer_q + self._writer_q = writer_q + self._stop_event = stop_event + self._worker_to_pid_to_num_samples: dict[WorkerTypes, dict[int, int]] = {} + + self._total_num_samples = total_num_samples + self._worker_type_to_processed_num_samples = {worker_type: 0 for worker_type in WorkerTypes} + + def _add_progress_message(self, progress_message: ProgressMessage): + if progress_message.worker_type not in self._worker_to_pid_to_num_samples: + self._worker_to_pid_to_num_samples[progress_message.worker_type] = {} + + if progress_message.process_id not in self._worker_to_pid_to_num_samples[progress_message.worker_type]: + self._worker_to_pid_to_num_samples[progress_message.worker_type][progress_message.process_id] = 0 + + self._worker_to_pid_to_num_samples[progress_message.worker_type][progress_message.process_id] += progress_message.num_samples + self._worker_type_to_processed_num_samples[progress_message.worker_type] += 
progress_message.num_samples + + + def _log_and_reset(self, passed_time: int): + logging_message = f"\n==================Progress report (last {passed_time}s) ==================\n" + + logging_message += f"Total progress: \n" + for worker_type, processed_num_samples in self._worker_type_to_processed_num_samples.items(): + logging_message += f"\t{worker_type.name}: {processed_num_samples}/{self._total_num_samples} samples ({processed_num_samples/self._total_num_samples*100}%)\n" + + + logging_message += "\n" + logging_message += f"Aggregated Throughput: \n" + + for worker_type, pid_to_num_samples in self._worker_to_pid_to_num_samples.items(): + total_samples = sum(pid_to_num_samples.values()) + logging_message += f"\t{worker_type.name} workers: {total_samples/passed_time} samples/s.\n" + logging_message += "\n" + logging_message += f"Worker Throughput: \n" + for worker_type, pid_to_num_samples in self._worker_to_pid_to_num_samples.items(): + logging_message += f"{worker_type.name} workers:\n" + for pid, num_samples in pid_to_num_samples.items(): + logging_message += f"\t{worker_type.name} {pid}: {num_samples/passed_time} samples/s.\n" + logging_message += "\n" + logging_message += "\n" + + logging_message += "Queues: \n" + logging_message += f"\tReader queue: {self._reader_q.qsize()} batches (approx.)\n" + logging_message += f"\tTokenizer queue: {self._tokenizer_q.qsize()} batches (approx.)\n" + logging_message += f"\tWriter queue: {self._writer_q.qsize()} batches (approx.)\n" + + get_logger().info(logging_message) + + # reset values + for worker_type in self._worker_to_pid_to_num_samples.keys(): + self._worker_to_pid_to_num_samples[worker_type] = {pid: 0 for pid in self._worker_to_pid_to_num_samples[worker_type].keys()} + + + def run(self): + last_logged = time.time() + last_step = False + while not self._stop_event.is_set(): + try: + progress_message: ProgressMessage = self._logging_message_q.get(timeout=1) + if progress_message is None: + last_step = True + break + self._add_progress_message(progress_message) + except queue.Empty: + continue + finally: + passed_time = time.time() - last_logged + if passed_time > self._logging_interval or last_step: + self._log_and_reset(passed_time) + last_logged = time.time() + + +class ProcessFactory: + + @staticmethod + def get_reader_workers( + rw_settings: PackedDatasetComponentsInstantiationModel.ReaderWorkerSettings, + reader_q: mp.Queue, + tokenizer_q: mp.Queue, + logging_message_q: mp.Queue, + stop_event: Event, + ) -> list[tuple[Type[Callable], BaseModel]]: + # create readers + reader_type = rw_settings.reader_settings.reader_type + if reader_type == LargeFileLinesReaderTypes.LOCAL: + readers = [ + (LargeFileLinesReaderFactory.get_local_reader, rw_settings.reader_settings.reader_args) + for _ in range(rw_settings.num_reader_processes) + ] + + elif reader_type == LargeFileLinesReaderTypes.GLOBAL: + readers = [ + (LargeFileLinesReaderFactory.get_global_reader, rw_settings.reader_settings.reader_args) + for _ in range(rw_settings.num_reader_processes) + ] + else: + raise ValueError(f"Reader type {reader_type} is not supported.") + + # create reader workers + reader_workers = [ + ReaderWorker( + reader_type= reader_type, + reader_args = reader_args, + reader_q=reader_q, + tokenizer_q=tokenizer_q, + logging_message_q=logging_message_q, + stop_event=stop_event, + process_id=pid, + + ) + for pid, (reader_type, reader_args) in enumerate(readers) + ] + + return reader_workers + + def get_tokenizer_workers( + tokenizer_q: mp.Queue, + writer_q: mp.Queue, 
+ logging_message_q: mp.Queue, + token_size_in_bytes: int, + tw_settings: PackedDatasetComponentsInstantiationModel.TokenizerWorkerSettings, + stop_event: Event, + ) -> list[TokenizerWorker]: + tokenizer_settings = tw_settings.tokenizer_settings + tokenizer_workers = [ + TokenizerWorker( + process_id=i, + stop_event=stop_event, + tokenizer_q=tokenizer_q, + writer_q=writer_q, + logging_message_q=logging_message_q, + token_size_in_bytes=token_size_in_bytes, + **tokenizer_settings.model_dump(), + ) + for i in range(tw_settings.num_tokenizer_processes) + ] + return tokenizer_workers + + def get_writer_worker( + writer_q: mp.Queue, + logging_message_q: mp.Queue, + token_size_in_bytes: int, + ww_settings: PackedDatasetComponentsInstantiationModel.WriterWorkerSettings, + stop_event: Event, + ) -> WriterWorker: + writer_worker = WriterWorker( + writer_q=writer_q, + logging_message_q=logging_message_q, + token_size_in_bytes=token_size_in_bytes, + dst_path=ww_settings.dst_path, + index_start=ww_settings.index_start, + stop_event=stop_event, + process_id=0, + ) + return writer_worker + + @staticmethod + def get_process_queues(tokenizer_q_maxsize: int, writer_q_maxsize) -> tuple[mp.Queue, mp.Queue, mp.Queue]: + reader_q = mp.Queue() # containes line_ids to be read + tokenizer_q = mp.Queue(maxsize=tokenizer_q_maxsize) # contains (line_id, line) pairs to be tokenized + writer_q = mp.Queue(maxsize=writer_q_maxsize) # contains (line_id, tokenized_line) to be written to disc + logging_message_q = mp.Queue() + return reader_q, tokenizer_q, writer_q, logging_message_q diff --git a/src/modalities/utils/env_variables.py b/src/modalities/utils/env_variables.py new file mode 100644 index 000000000..15c4630d5 --- /dev/null +++ b/src/modalities/utils/env_variables.py @@ -0,0 +1,22 @@ +import os +from contextlib import contextmanager + +@contextmanager +def temporary_env_var(key, value): + """ + Temporarily set an environment variable. + + Args: + key (str): The environment variable name. + value (str): The temporary value to set. 
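+
+    Example (illustrative; the variable name and the callee are placeholders):
+        >>> with temporary_env_var("TOKENIZERS_PARALLELISM", "false"):
+        ...     run_tokenization()  # hypothetical function executed with the patched environment
+        >>> # afterwards the variable is restored, or removed again if it was not set before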
+ """ + original_value = os.environ.get(key) # Store the original value (if any) + os.environ[key] = value # Set the temporary value + try: + yield # Allow code execution within the context + finally: + # Restore the original value or delete the key if it wasn't set originally + if original_value is None: + del os.environ[key] + else: + os.environ[key] = original_value \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index c05cb2c80..7214dac84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,9 @@ from modalities.checkpointing.checkpoint_saving import CheckpointSaving from modalities.config.config import load_app_config_dict -from modalities.dataloader.create_index import IndexGenerator +from modalities.dataloader.preprocessing.indexation.create_index import IndexGenerator from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader from modalities.evaluator import Evaluator from modalities.logging_broker.publisher import MessagePublisher from modalities.loss_functions import Loss @@ -67,7 +67,7 @@ def dummy_data_path(tmpdir) -> DataPathCollection: source_raw_dummy_data_path = _ROOT_DIR / Path("./data/lorem_ipsum.jsonl") dummy_data_path = Path(tmpdir, source_raw_dummy_data_path.name) dummy_data_path.write_text(source_raw_dummy_data_path.read_text()) - index_path = LargeFileLinesReader.default_index_path(dummy_data_path) + index_path = LocalLargeFileLinesReader.default_index_path(dummy_data_path) index_path.unlink(missing_ok=True) return DataPathCollection(raw_data_path=dummy_data_path, index_path=index_path) @@ -77,7 +77,7 @@ def dummy_data_path_long(tmpdir) -> DataPathCollection: source_raw_dummy_data_path = _ROOT_DIR / Path("./data/lorem_ipsum_long.jsonl") dummy_data_path = Path(tmpdir, source_raw_dummy_data_path.name) dummy_data_path.write_text(source_raw_dummy_data_path.read_text()) - index_path = LargeFileLinesReader.default_index_path(dummy_data_path) + index_path = LocalLargeFileLinesReader.default_index_path(dummy_data_path) index_path.unlink(missing_ok=True) return DataPathCollection(raw_data_path=dummy_data_path, index_path=index_path) diff --git a/tests/dataloader/test_large_file_lines_reader.py b/tests/dataloader/test_large_file_lines_reader.py index 47afd9074..914014f5a 100644 --- a/tests/dataloader/test_large_file_lines_reader.py +++ b/tests/dataloader/test_large_file_lines_reader.py @@ -7,7 +7,7 @@ import pytest from modalities.dataloader.create_index import IndexGenerator -from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader from tests.conftest import DataPathCollection @@ -75,7 +75,7 @@ def generate_data_index_file(data_path: Path, **kwargs): ) def test_large_file_lines_reader_text(indexed_dummy_data_path: DataPathCollection, use_sample_length_from_index: bool): raw_data_path = indexed_dummy_data_path.raw_data_path - reader = LargeFileLinesReader( + reader = LocalLargeFileLinesReader( raw_data_path, use_sample_length_from_index=use_sample_length_from_index, encoding="utf-8" ) assert raw_data_path.read_text().count("\n") == 12 @@ -106,10 +106,10 @@ def test_large_file_lines_reader_binary_text_equivalence( indexed_dummy_data_path: DataPathCollection, use_sample_length_from_index: bool ): raw_data_path = 
indexed_dummy_data_path.raw_data_path - reader_binary = LargeFileLinesReader( + reader_binary = LocalLargeFileLinesReader( raw_data_path, use_sample_length_from_index=use_sample_length_from_index, encoding=None ) - reader_text = LargeFileLinesReader( + reader_text = LocalLargeFileLinesReader( raw_data_path, use_sample_length_from_index=use_sample_length_from_index, encoding="utf-8" ) @@ -124,4 +124,4 @@ def test_large_file_lines_reader_missing_source_data(dummy_data_path: DataPathCo raw_data_path.unlink(missing_ok=True) assert not raw_data_path.exists() with pytest.raises(FileNotFoundError): - LargeFileLinesReader(raw_data_path, dummy_data_path.index_path) + LocalLargeFileLinesReader(raw_data_path, dummy_data_path.index_path) From e876c977c550cd949948f9bbec8873fc9876f0a2 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sun, 5 Jan 2025 11:25:07 +0100 Subject: [PATCH 03/25] refactor: finalized processing strategies --- .pre-commit-config.yaml | 2 +- .../tokenization/tokenization_processes.py | 456 ------------------ .../tokenization/tokenization_strategies.py | 423 ++++++++++++++++ 3 files changed, 424 insertions(+), 457 deletions(-) delete mode 100644 src/modalities/dataloader/preprocessing/tokenization/tokenization_processes.py create mode 100644 src/modalities/dataloader/preprocessing/tokenization/tokenization_strategies.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9b1b18e4..7dc1c8a95 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: rev: 23.9.1 hooks: - id: black - language_version: python3.10 + language_version: python3.11 stages: [pre-commit] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.0.278 diff --git a/src/modalities/dataloader/preprocessing/tokenization/tokenization_processes.py b/src/modalities/dataloader/preprocessing/tokenization/tokenization_processes.py deleted file mode 100644 index 8123cd6df..000000000 --- a/src/modalities/dataloader/preprocessing/tokenization/tokenization_processes.py +++ /dev/null @@ -1,456 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -import math -import multiprocessing as mp -import os -import pickle -import time -import traceback -from typing import Any, Callable, Type -import warnings -from io import BufferedWriter -from pathlib import Path -from multiprocessing.synchronize import Event -from data_quality_ablations.utils.logging import get_logger -import jq -from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel -from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import EmbeddedStreamData -from modalities.exceptions import EmptySampleError -from pydantic import BaseModel -from tqdm import tqdm -import queue - -from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import ( - BaseReader, - LargeFileLinesReaderFactory, - LargeFileLinesReaderTypes, - Sample, -) -from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper - - -def get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: - """ - Calculates the required number of bytes to represent an integer. - - Args: - int_to_get_repr (int): The integer to get the representation for. - - Returns: - int: The number of bytes required to represent the integer. 
- """ - # we currently only support token sizes of 1, 2 and 4 bytes, as implemented here: - # https://github.com/Modalities/modalities/blob/fix_char_bytes_indexation_mismatch/src/modalities/dataloader/dataset.py#L202 - num_bytes = math.ceil(math.log2(int_to_get_repr) / 8) - if num_bytes == 1: - return 1 - elif num_bytes == 2: - return 2 - elif num_bytes <= 4: - return 4 - else: - raise ValueError("Currently only support token byte sizes of 1, 2, and 4.") - - -class ReaderWorker(mp.Process): - def __init__( - self, - reader_type: Type[BaseReader], - reader_args: BaseModel, - reader_q: mp.Queue, - tokenizer_q: mp.Queue, - logging_message_q: mp.Queue, - process_id: int, - stop_event: Event, - ): - super().__init__() - self._reader_q = reader_q - self._tokenizer_q = tokenizer_q - self._logging_message_q = logging_message_q - self._reader_type = reader_type - self._reader_args = reader_args - self._stop_event = stop_event - self.process_id = process_id - - def run(self): - reader = self._reader_type(**self._reader_args.model_dump()) - batch = [] - num_samples_read = 0 - while not self._stop_event.is_set(): - try: - # we set the timout here, such that the worker can check if the stop_event is set - item = self._reader_q.get(timeout=3) - except queue.Empty: - continue - if item is None: - print(f"Reading worker with pid {mp.current_process().pid} exiting, Read {num_samples_read} samples") - break - sample_id, batch_size = item - - - batch: list[Sample] = [reader[sample_id + i] for i in range(batch_size)] - self._tokenizer_q.put(batch) - self._logging_message_q.put(ProgressMessage(WorkerTypes.READER, self.process_id, len(batch))) - num_samples_read += len(batch) - - if not self._stop_event.is_set(): - # add the remaining samples - if len(batch) > 0: - self._tokenizer_q.put(batch) - self._logging_message_q.put(ProgressMessage(WorkerTypes.READER, self.process_id, len(batch))) - - - -class TokenizerWorker(mp.Process): - def __init__( - self, - tokenizer: TokenizerWrapper, - eod_token: str, - token_size_in_bytes: int, - tokenizer_q: mp.Queue, - logging_message_q: mp.Queue, - writer_q: mp.Queue, - jq_pattern: str, - process_id: int, - stop_event: Event, - ): - super().__init__() - self._jq_filter = jq.compile(jq_pattern) - self.tokenizer = tokenizer - self.eod_token = eod_token - self._token_size_in_bytes = token_size_in_bytes - encoded_eod_token = self.tokenizer.get_token_id(self.eod_token) - self._encoded_eos_token_as_bytes = self._encoded_token_to_bytes(encoded_eod_token) - self._tokenizer_q = tokenizer_q - self._writer_q = writer_q - self._logging_message_q = logging_message_q - self._process_id = process_id - self._stop_event = stop_event - - def run(self): - # Process the lines in a batch and put the processed samples into the writer_q. 
- - while not self._stop_event.is_set(): - try: - batch: list[Sample] = self._tokenizer_q.get(timeout=10) - except queue.Empty: - continue - if batch is None: - break - - try: - batch_processed = [] - for sample in batch: - processed_line = self._process_line(sample.content_raw) - sample.content_tokenized = processed_line - batch_processed.append(sample) - self._writer_q.put(batch_processed) - self._logging_message_q.put(ProgressMessage(WorkerTypes.TOKENIZER, self._process_id, len(batch))) - except EmptySampleError: - warnings.warn( - f"Encountered empty sample in line {sample.shuffled_line_id} in file {sample.raw_data_path} within process {self._process_id}" - ) - except Exception as exception: - warnings.warn( - f"Could not process line {sample.shuffled_line_id} in file {sample.raw_data_path} within process {self._process_id}. " - f"Raised the following error: {exception=}" - ) - traceback.print_exc() - - def _process_line(self, line: str) -> bytes: - # extracts the text via the jq_filter and applies tokenization to the extract text - jq_retrieved_text = self._jq_filter.input_text(line).first() - if jq_retrieved_text is None: - raise ValueError(f"jq was not able to find anything using the expression: {self._jq_filter}") - tokens = self.tokenizer.tokenize(jq_retrieved_text) - if len(tokens) == 0: - raise EmptySampleError("Received empty sample...") - return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes - - def _encoded_token_to_bytes(self, encoded_token: int) -> bytes: - """ - Converts an encoded token to its byte representaion. - - Args: - encoded_token (int): The encoded token to be converted. - - Returns: - bytes: The byte representation of the token. - - """ - return encoded_token.to_bytes(self._token_size_in_bytes, byteorder="little", signed=False) - - -class WriterWorker(mp.Process): - def __init__( - self, token_size_in_bytes: int, writer_q: mp.Queue, logging_message_q: mp.Queue, dst_path: Path, stop_event: Event, index_start: int, - process_id: int - ): - super().__init__() - self._token_size_in_bytes = token_size_in_bytes - self._dst_path = dst_path - self._writer_q = writer_q - self._logging_message_q = logging_message_q - self._stop_event = stop_event - self._index_start = index_start - self.process_id = process_id - - def run(self): - index_list = [] - if not self._dst_path.parent.exists(): - self._dst_path.parent.mkdir(parents=True, exist_ok=True) - with self._dst_path.open("wb") as f: - # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) - # not possible to prepend header after determining size of data section - f.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) - f.write( - self._token_size_in_bytes.to_bytes( - EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little" - ) - ) - # The offset only applies to the data section, not the header - # When we load the file, we add the header size to the offset - curr_offset = 0 - - # write data section (tokens) - prev_line_id = self._index_start - 1 - batch_dict = {} - while not self._stop_event.is_set(): - try: - batch: list[Sample] = self._writer_q.get(timeout=3) - except queue.Empty: - continue - if batch is None: - break - line_id = batch[0].incremental_line_id - batch_dict[line_id] = batch - - while prev_line_id + 1 in batch_dict: - batch = batch_dict.pop(prev_line_id + 1) - prev_line_id, curr_offset = WriterWorker._write_batch( - batch, prev_line_id, curr_offset, index_list, f - ) - 
self._logging_message_q.put(ProgressMessage(WorkerTypes.WRITER, self.process_id, len(batch))) - - # write index - f.write(pickle.dumps(index_list)) - if not self._stop_event.is_set() and len(index_list) > 0 and len(batch_dict) == 0: - self._update_data_length_in_pre_allocated_header(self._dst_path, index_list) - else: - # if the process was stopped due to a stop event or the index list is empty, we remove the file - get_logger(name="main").warning(f"Removing file {self._dst_path} due to empty index list or stop event or non-empty batch_dict. " - f"stop_event: {self._stop_event.is_set()}, index_list: {len(index_list)}, batch_dict: {batch_dict.keys()}") - os.remove(self._dst_path) - - # writes a batch received from the writer_q to the destination file - def _write_batch( - batch: list[Sample], prev_line_id: int, curr_offset: int, index_list: list, f: BufferedWriter - ) -> tuple[int, int]: - # write the tokens for each document - for sample in batch: - if prev_line_id + 1 != sample.incremental_line_id: - raise ValueError( - f"Line IDs are not consecutive. Expected {prev_line_id + 1}, but got {sample.incremental_line_id}" - ) - f.write(sample.content_tokenized) - segment_length = len(sample.content_tokenized) - index_list.append((curr_offset, segment_length)) - curr_offset += segment_length - prev_line_id = sample.incremental_line_id - return prev_line_id, curr_offset - - @staticmethod - def _update_data_length_in_pre_allocated_header(dst_path: Path, index_list: list[tuple[int, int]]): - # Update the length of the data section in the pre-allocated header of the destination file. - # The data segment length is sum of the starting position and the length of the last document. - length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] - data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes( - EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little" - ) - with dst_path.open("rb+") as fout: - fout.seek(0) - fout.write(data_section_length_in_bytes) - - -class WorkerTypes(Enum): - READER = "READER" - TOKENIZER = "TOKENIZER" - WRITER = "WRITER" - -@dataclass -class ProgressMessage: - worker_type: WorkerTypes - process_id: int - num_samples: int - - -class ProgressLoggingWorker(mp.Process): - def __init__(self, logging_message_q: mp.Queue, logging_interval: int, reader_q: mp.Queue, tokenizer_q: mp.Queue, writer_q: mp.Queue, total_num_samples: int, stop_event: Event): - super().__init__() - self._logging_message_q = logging_message_q - self._logging_interval = logging_interval - self._reader_q = reader_q - self._tokenizer_q = tokenizer_q - self._writer_q = writer_q - self._stop_event = stop_event - self._worker_to_pid_to_num_samples: dict[WorkerTypes, dict[int, int]] = {} - - self._total_num_samples = total_num_samples - self._worker_type_to_processed_num_samples = {worker_type: 0 for worker_type in WorkerTypes} - - def _add_progress_message(self, progress_message: ProgressMessage): - if progress_message.worker_type not in self._worker_to_pid_to_num_samples: - self._worker_to_pid_to_num_samples[progress_message.worker_type] = {} - - if progress_message.process_id not in self._worker_to_pid_to_num_samples[progress_message.worker_type]: - self._worker_to_pid_to_num_samples[progress_message.worker_type][progress_message.process_id] = 0 - - self._worker_to_pid_to_num_samples[progress_message.worker_type][progress_message.process_id] += progress_message.num_samples - self._worker_type_to_processed_num_samples[progress_message.worker_type] += 
progress_message.num_samples - - - def _log_and_reset(self, passed_time: int): - logging_message = f"\n==================Progress report (last {passed_time}s) ==================\n" - - logging_message += f"Total progress: \n" - for worker_type, processed_num_samples in self._worker_type_to_processed_num_samples.items(): - logging_message += f"\t{worker_type.name}: {processed_num_samples}/{self._total_num_samples} samples ({processed_num_samples/self._total_num_samples*100}%)\n" - - - logging_message += "\n" - logging_message += f"Aggregated Throughput: \n" - - for worker_type, pid_to_num_samples in self._worker_to_pid_to_num_samples.items(): - total_samples = sum(pid_to_num_samples.values()) - logging_message += f"\t{worker_type.name} workers: {total_samples/passed_time} samples/s.\n" - logging_message += "\n" - logging_message += f"Worker Throughput: \n" - for worker_type, pid_to_num_samples in self._worker_to_pid_to_num_samples.items(): - logging_message += f"{worker_type.name} workers:\n" - for pid, num_samples in pid_to_num_samples.items(): - logging_message += f"\t{worker_type.name} {pid}: {num_samples/passed_time} samples/s.\n" - logging_message += "\n" - logging_message += "\n" - - logging_message += "Queues: \n" - logging_message += f"\tReader queue: {self._reader_q.qsize()} batches (approx.)\n" - logging_message += f"\tTokenizer queue: {self._tokenizer_q.qsize()} batches (approx.)\n" - logging_message += f"\tWriter queue: {self._writer_q.qsize()} batches (approx.)\n" - - get_logger().info(logging_message) - - # reset values - for worker_type in self._worker_to_pid_to_num_samples.keys(): - self._worker_to_pid_to_num_samples[worker_type] = {pid: 0 for pid in self._worker_to_pid_to_num_samples[worker_type].keys()} - - - def run(self): - last_logged = time.time() - last_step = False - while not self._stop_event.is_set(): - try: - progress_message: ProgressMessage = self._logging_message_q.get(timeout=1) - if progress_message is None: - last_step = True - break - self._add_progress_message(progress_message) - except queue.Empty: - continue - finally: - passed_time = time.time() - last_logged - if passed_time > self._logging_interval or last_step: - self._log_and_reset(passed_time) - last_logged = time.time() - - -class ProcessFactory: - - @staticmethod - def get_reader_workers( - rw_settings: PackedDatasetComponentsInstantiationModel.ReaderWorkerSettings, - reader_q: mp.Queue, - tokenizer_q: mp.Queue, - logging_message_q: mp.Queue, - stop_event: Event, - ) -> list[tuple[Type[Callable], BaseModel]]: - # create readers - reader_type = rw_settings.reader_settings.reader_type - if reader_type == LargeFileLinesReaderTypes.LOCAL: - readers = [ - (LargeFileLinesReaderFactory.get_local_reader, rw_settings.reader_settings.reader_args) - for _ in range(rw_settings.num_reader_processes) - ] - - elif reader_type == LargeFileLinesReaderTypes.GLOBAL: - readers = [ - (LargeFileLinesReaderFactory.get_global_reader, rw_settings.reader_settings.reader_args) - for _ in range(rw_settings.num_reader_processes) - ] - else: - raise ValueError(f"Reader type {reader_type} is not supported.") - - # create reader workers - reader_workers = [ - ReaderWorker( - reader_type= reader_type, - reader_args = reader_args, - reader_q=reader_q, - tokenizer_q=tokenizer_q, - logging_message_q=logging_message_q, - stop_event=stop_event, - process_id=pid, - - ) - for pid, (reader_type, reader_args) in enumerate(readers) - ] - - return reader_workers - - def get_tokenizer_workers( - tokenizer_q: mp.Queue, - writer_q: mp.Queue, 
- logging_message_q: mp.Queue, - token_size_in_bytes: int, - tw_settings: PackedDatasetComponentsInstantiationModel.TokenizerWorkerSettings, - stop_event: Event, - ) -> list[TokenizerWorker]: - tokenizer_settings = tw_settings.tokenizer_settings - tokenizer_workers = [ - TokenizerWorker( - process_id=i, - stop_event=stop_event, - tokenizer_q=tokenizer_q, - writer_q=writer_q, - logging_message_q=logging_message_q, - token_size_in_bytes=token_size_in_bytes, - **tokenizer_settings.model_dump(), - ) - for i in range(tw_settings.num_tokenizer_processes) - ] - return tokenizer_workers - - def get_writer_worker( - writer_q: mp.Queue, - logging_message_q: mp.Queue, - token_size_in_bytes: int, - ww_settings: PackedDatasetComponentsInstantiationModel.WriterWorkerSettings, - stop_event: Event, - ) -> WriterWorker: - writer_worker = WriterWorker( - writer_q=writer_q, - logging_message_q=logging_message_q, - token_size_in_bytes=token_size_in_bytes, - dst_path=ww_settings.dst_path, - index_start=ww_settings.index_start, - stop_event=stop_event, - process_id=0, - ) - return writer_worker - - @staticmethod - def get_process_queues(tokenizer_q_maxsize: int, writer_q_maxsize) -> tuple[mp.Queue, mp.Queue, mp.Queue]: - reader_q = mp.Queue() # containes line_ids to be read - tokenizer_q = mp.Queue(maxsize=tokenizer_q_maxsize) # contains (line_id, line) pairs to be tokenized - writer_q = mp.Queue(maxsize=writer_q_maxsize) # contains (line_id, tokenized_line) to be written to disc - logging_message_q = mp.Queue() - return reader_q, tokenizer_q, writer_q, logging_message_q diff --git a/src/modalities/dataloader/preprocessing/tokenization/tokenization_strategies.py b/src/modalities/dataloader/preprocessing/tokenization/tokenization_strategies.py new file mode 100644 index 000000000..389546853 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/tokenization_strategies.py @@ -0,0 +1,423 @@ +import math +import multiprocessing as mp +import os +import pickle +import time +from dataclasses import dataclass +from enum import Enum +from io import BufferedWriter +from pathlib import Path +from typing import Optional, Type + +import jq +import tqdm +from data_quality_ablations.utils.logging import get_logger +from pydantic import BaseModel + +from modalities.config.component_factory import ComponentFactory +from modalities.config.instantiation_models import TokenizationInstantiationModel +from modalities.dataloader.preprocessing.queued_processing.processing_strategy_if import ProcessingStrategyIF +from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import EmbeddedStreamData +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import ( + BaseReader, + LargeFileLinesReaderFactory, + LargeFileLinesReaderTypes, + Sample, +) +from modalities.exceptions import EmptySampleError +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry +from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper + + +def get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: + """ + Calculates the required number of bytes to represent an integer. + + Args: + int_to_get_repr (int): The integer to get the representation for. + + Returns: + int: The number of bytes required to represent the integer. 
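+
+    Example (illustrative values):
+        >>> get_required_num_of_bytes_to_repr(250)     # fits into a single byte
+        1
+        >>> get_required_num_of_bytes_to_repr(50_257)  # e.g. a GPT-2-sized vocabulary
+        2
+        >>> get_required_num_of_bytes_to_repr(100_000)
+        4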
+ """ + # we currently only support token sizes of 1, 2 and 4 bytes, as implemented here: + # https://github.com/Modalities/modalities/blob/fix_char_bytes_indexation_mismatch/src/modalities/dataloader/dataset.py#L202 + num_bytes = math.ceil(math.log2(int_to_get_repr) / 8) + if num_bytes == 1: + return 1 + elif num_bytes == 2: + return 2 + elif num_bytes <= 4: + return 4 + else: + raise ValueError("Currently only support token byte sizes of 1, 2, and 4.") + + +def populate_reader_q( + reader_q: mp.Queue, index_start: int, num_samples: int, num_reader_processes: int, batch_size: int +): + # populate the reader queue with the line_ids that we want to tokenize + + for i in tqdm.tqdm( + range(index_start, index_start + num_samples, batch_size), desc="Filling up reader queue with line ids" + ): + reader_q.put(ReadingJob(sample_id=i, batch_size=batch_size)) + for _ in range(num_reader_processes): + reader_q.put(None) + + +@dataclass +class ReadingJob: + sample_id: int + batch_size: int + + +class WorkerTypes(Enum): + READER = "READER" + TOKENIZER = "TOKENIZER" + WRITER = "WRITER" + + +@dataclass +class ProgressMessage: + worker_type: WorkerTypes + num_samples: int + process_type: Optional[str] = None + process_id: Optional[str] = None + + +class ReadingStrategy(ProcessingStrategyIF): + def __init__( + self, reader_type: Type[BaseReader], reader_args: BaseModel, tokenizer_q_key: str, logging_message_q_key: str + ): + self._reader_type = reader_type + self._reader_args = reader_args + self._reader = None + self._tokenizer_q_key = tokenizer_q_key + self._logging_message_q_key = logging_message_q_key + + def __enter__(self): + self._reader = self._reader_type(**self._reader_args.model_dump()) + return self + + def finalize(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self._reader.close() + + def process(self, item: ReadingJob) -> dict[str, list[Sample] | ProgressMessage]: + batch: list[Sample] = [self._reader[item.sample_id + i] for i in range(item.batch_size)] + progress_message = ProgressMessage(WorkerTypes.READER, len(batch)) + return {self._tokenizer_q_key: batch, self._logging_message_q_key: progress_message} + + +class TokenizingStrategy(ProcessingStrategyIF): + def __init__( + self, + ti_settings: ( + TokenizationInstantiationModel.TokenizerWorkerSettings.TokenizerSettings.TokenizerInstantitionSettings + ), + eod_token: str, + jq_pattern: str, + writer_q_key: str, + logging_message_q_key: str, + ): + self._tokenizer_instantiation_setings = ti_settings + self._eod_token = eod_token + self._jq_filter = jq.compile(jq_pattern) + self._writer_q_key = writer_q_key + self._logging_message_q_key = logging_message_q_key + + def __enter__(self): + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + self._tokenizer: TokenizerWrapper = component_factory.instantiate_component_config( + component_key=self._tokenizer_instantiation_setings.tokenizer_component_key, + variant_key=self._tokenizer_instantiation_setings.tokenizer_variant_key, + config_dict=self._tokenizer_instantiation_setings.config, + ) + encoded_eod_token = self._tokenizer.get_token_id(self._eod_token) + self._encoded_eos_token_as_bytes = self._encoded_token_to_bytes(encoded_eod_token) + self._token_size_in_bytes = get_required_num_of_bytes_to_repr(self._tokenizer.vocab_size) + return self + + def finalize(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def process(self, item: list[Sample]) -> dict[str, list[Sample] | ProgressMessage]: + batch_processed = 
[] + for sample in item: + processed_line = self._process_line(sample.content_raw) + sample.content_tokenized = processed_line + sample.token_size_in_bytes = self._token_size_in_bytes + batch_processed.append(sample) + progress_message = ProgressMessage(WorkerTypes.TOKENIZER, self.process_id, len(batch_processed)) + return {self._writer_q_key: batch_processed, self._logging_message_q_key: progress_message} + + def _process_line(self, line: str) -> bytes: + # extracts the text via the jq_filter and applies tokenization to the extract text + jq_retrieved_text = self._jq_filter.input_text(line).first() + if jq_retrieved_text is None: + raise ValueError(f"jq was not able extract the text using the expression: {self._jq_filter}") + tokens = self.tokenizer.tokenize(jq_retrieved_text) + if len(tokens) == 0: + raise EmptySampleError("Received empty sample...") + return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes + + def _encoded_token_to_bytes(self, encoded_token: int) -> bytes: + # Converts an encoded token to its bytes representaion. + return encoded_token.to_bytes(self._token_size_in_bytes, byteorder="little", signed=False) + + +class WritingStrategy(ProcessingStrategyIF): + def __init__(self, dst_path: Path, index_start: int, logging_message_q_key: str): + self._dst_path = dst_path + self._index_start = index_start + self._logging_message_q_key = logging_message_q_key + + if not self._dst_path.parent.exists(): + self._dst_path.parent.mkdir(parents=True, exist_ok=True) + + def __enter__(self): + self._dst_fd = self._dst_path.open("wb") + self.finalized = False + # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) + # not possible to prepend header after determining size of data section + self._dst_fd.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) + + # The offset only applies to the data section, not the header + # When we load the file, we add the header size to the offset + self._curr_offset = 0 + + self._prev_line_id = self._index_start - 1 + self._batch_dict = {} + self._index_list = [] + self._has_seen_first_batch = False + + return self + + def finalize(self): + # check that the index list IS NOT empty and the batch_dict IS empty + # i.e., all batches have been written to the file + if len(self._index_list) == 0 or len(self._batch_dict) >= 0: + raise ValueError( + f"Could not finalize writing strategy. Index list is empty or batch_dict is not empty. " + f"Index list: {len(self._index_list)}, batch_dict: {self._batch_dict.keys()}" + ) + else: + # write index + self._dst_fd.write(pickle.dumps(self._index_list)) + self._dst_fd.close() + self._update_data_length_in_pre_allocated_header(self._dst_path, self._index_list) + self.finalized = True + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.finalized: + self._dst_fd.close() + # if the process was stopped due to a stop event or the index list is empty, we remove the file + get_logger(name="main").warning( + f"Removing file {self._dst_path} due to non-finalized pbin file. The pbin file either is not " + "finalized as WritingStrategy.finalize() was not called or not all samples have been written " + f"to disc. 
index_list: {len(self._index_list)}, batch_dict: {self._batch_dict.keys()}" + ) + os.remove(self._dst_path) + + def process(self, item: list[Sample]) -> dict[str, ProgressMessage]: + if not self._has_seen_first_batch: + # write the token size descriptor to the file + # we receive this information from the tokenizer (based on the tokenizer's vocab size) + # and is always provided within the Sample object + self._has_seen_first_batch = True + self._dst_fd.write( + item[0].token_size_in_bytes.to_bytes( + EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little" + ) + ) + + line_id = item[0].incremental_line_id + self._batch_dict[line_id] = item + + num_samples_written = 0 + while self._prev_line_id + 1 in self._batch_dict: + batch = self._batch_dict.pop(self._prev_line_id + 1) + self._prev_line_id, self._curr_offset = WritingStrategy._write_batch( + batch, self._prev_line_id, self._curr_offset, self._index_list, self._dst_fd + ) + num_samples_written += len(batch) + progress_message = ProgressMessage(WorkerTypes.WRITER, self.process_id, num_samples_written) + return {self._logging_key: progress_message} + + # writes a batch received from the writer_q to the destination file + @staticmethod + def _write_batch( + batch: list[Sample], prev_line_id: int, curr_offset: int, index_list: list, f: BufferedWriter + ) -> tuple[int, int]: + # write the tokens for each document + for sample in batch: + if prev_line_id + 1 != sample.incremental_line_id: + raise ValueError( + f"Line IDs are not consecutive. Expected {prev_line_id + 1}, but got {sample.incremental_line_id}" + ) + f.write(sample.content_tokenized) + segment_length = len(sample.content_tokenized) + index_list.append((curr_offset, segment_length)) + curr_offset += segment_length + prev_line_id = sample.incremental_line_id + return prev_line_id, curr_offset + + @staticmethod + def _update_data_length_in_pre_allocated_header(dst_path: Path, index_list: list[tuple[int, int]]): + # Update the length of the data section in the pre-allocated header of the destination file. + # The data segment length is sum of the starting position and the length of the last document. 
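+        # Illustrative example: if the last index entry is (offset=1_000, length=24), the data section
+        # spans 1_024 bytes, and exactly this value is written into the pre-allocated header field below.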
+ length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] + data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes( + EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little" + ) + with dst_path.open("rb+") as fout: + fout.seek(0) + fout.write(data_section_length_in_bytes) + + +class ProgressLoggingStrategy(ProcessingStrategyIF): + def __init__( + self, + logging_interval: int, + total_num_samples: int, + q_dict: dict[str, mp.Queue], + ): + self._logging_interval = logging_interval + self._total_num_samples = total_num_samples + self._worker_to_pid_to_num_samples: dict[WorkerTypes, dict[int, int]] = {} + self._worker_type_to_processed_num_samples = {worker_type: 0 for worker_type in WorkerTypes} + self._q_dict = q_dict + + def __enter__(self): + self._last_logged = time.time() + + def finalize(self): + passed_time = time.time() - self._last_logged + self._log_and_reset(passed_time) + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def process(self, item: ProgressMessage) -> dict: + self._add_progress_message(item) + passed_time = time.time() - self._last_logged + if passed_time > self._logging_interval or self._last_step: + self._log_and_reset(passed_time) + self._last_logged = time.time() + + def _add_progress_message(self, progress_message: ProgressMessage): + if progress_message.worker_type not in self._worker_to_pid_to_num_samples: + self._worker_to_pid_to_num_samples[progress_message.worker_type] = {} + + if progress_message.process_id not in self._worker_to_pid_to_num_samples[progress_message.worker_type]: + self._worker_to_pid_to_num_samples[progress_message.worker_type][progress_message.process_id] = 0 + + self._worker_to_pid_to_num_samples[progress_message.worker_type][ + progress_message.process_id + ] += progress_message.num_samples + self._worker_type_to_processed_num_samples[progress_message.worker_type] += progress_message.num_samples + + def _log_and_reset(self, passed_time: int): + logging_message = f"\n==================Progress report (last {passed_time}s) ==================\n" + + logging_message += "Total progress: \n" + for worker_type, processed_num_samples in self._worker_type_to_processed_num_samples.items(): + m = ( + f"\t{worker_type.name}: {processed_num_samples}/{self._total_num_samples} samples " + f"({processed_num_samples/self._total_num_samples*100}%)\n" + ) + logging_message += m + + logging_message += "\n" + logging_message += "Aggregated Throughput: \n" + + for worker_type, pid_to_num_samples in self._worker_to_pid_to_num_samples.items(): + total_samples = sum(pid_to_num_samples.values()) + logging_message += f"\t{worker_type.name} workers: {total_samples/passed_time} samples/s.\n" + logging_message += "\n" + logging_message += "Worker Throughput: \n" + for worker_type, pid_to_num_samples in self._worker_to_pid_to_num_samples.items(): + logging_message += f"{worker_type.name} workers:\n" + for pid, num_samples in pid_to_num_samples.items(): + logging_message += f"\t{worker_type.name} {pid}: {num_samples/passed_time} samples/s.\n" + logging_message += "\n" + logging_message += "\n" + + logging_message += "Queues: \n" + logging_message += f"\tReader queue: {self._reader_q.qsize()} batches (approx.)\n" + logging_message += f"\tTokenizer queue: {self._tokenizer_q.qsize()} batches (approx.)\n" + logging_message += f"\tWriter queue: {self._writer_q.qsize()} batches (approx.)\n" + + get_logger().info(logging_message) + + # reset values + for worker_type in self._worker_to_pid_to_num_samples.keys(): 
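+            # keep the per-process keys but reset all counters to zero for the next logging window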
+ self._worker_to_pid_to_num_samples[worker_type] = { + pid: 0 for pid in self._worker_to_pid_to_num_samples[worker_type].keys() + } + + +class ProcessingStrategyFactory: + @staticmethod + def get_reader_strategy( + reader_settings: TokenizationInstantiationModel.ReaderWorkerSettings.ReaderSettings, + tokenizer_q_key: str, + logging_message_q_key: str, + ) -> ReadingStrategy: + reader_type = reader_settings.reader_type + if reader_type == LargeFileLinesReaderTypes.LOCAL: + return ReadingStrategy( + LargeFileLinesReaderFactory.get_local_reader, + reader_settings.reader_args, + tokenizer_q_key, + logging_message_q_key, + ) + elif reader_type == LargeFileLinesReaderTypes.GLOBAL: + return ReadingStrategy( + LargeFileLinesReaderFactory.get_global_reader, + reader_settings.reader_args, + tokenizer_q_key, + logging_message_q_key, + ) + else: + raise ValueError(f"Reader type {reader_type} is not supported.") + + def get_tokenizer_strategy( + tokenizer_settings: TokenizationInstantiationModel.TokenizerWorkerSettings.TokenizerSettings, + writer_q_key: str, + logging_message_q_key: str, + ) -> TokenizingStrategy: + tokenizing_strategy = TokenizingStrategy( + tokenizer_instantiation_setings=tokenizer_settings.tokenizer_instantiation_settings, + eod_token=tokenizer_settings.eod_token, + jq_pattern=tokenizer_settings.jq_pattern, + writer_q_key=writer_q_key, + logging_message_q_key=logging_message_q_key, + ) + return tokenizing_strategy + + def get_writing_strategy( + ww_settings: TokenizationInstantiationModel.WriterWorkerSettings, + logging_message_q_key: str, + ) -> WritingStrategy: + writing_strategy = WritingStrategy( + dst_path=ww_settings.dst_path, + index_start=ww_settings.index_start, + logging_message_q_key=logging_message_q_key, + ) + return writing_strategy + + @staticmethod + def get_process_queues(tokenizer_q_maxsize: int, writer_q_maxsize) -> tuple[mp.Queue, mp.Queue, mp.Queue]: + reader_q = mp.Queue() # containes line_ids to be read + tokenizer_q = mp.Queue(maxsize=tokenizer_q_maxsize) # contains (line_id, line) pairs to be tokenized + writer_q = mp.Queue(maxsize=writer_q_maxsize) # contains (line_id, tokenized_line) to be written to disc + logging_message_q = mp.Queue() + return reader_q, tokenizer_q, writer_q, logging_message_q From 3fe76905bf1b474907c090aa6d58c93dc2e6b72b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sun, 5 Jan 2025 11:27:31 +0100 Subject: [PATCH 04/25] feat: introduced process controller --- .../queued_processing/process_controller.py | 43 +++++++ .../tokenization/create_packed_data.py | 113 ------------------ .../tokenization/large_file_lines_reader.py | 19 +-- 3 files changed, 54 insertions(+), 121 deletions(-) create mode 100644 src/modalities/dataloader/preprocessing/queued_processing/process_controller.py delete mode 100644 src/modalities/dataloader/preprocessing/tokenization/create_packed_data.py diff --git a/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py new file mode 100644 index 000000000..de467fd39 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py @@ -0,0 +1,43 @@ +import multiprocessing as mp +from dataclasses import dataclass +from typing import Callable + +from modalities.dataloader.preprocessing.queued_processing.queued_processing import Processor +from modalities.utils.logging import get_logger + + +@dataclass +class PipelineStep: + name: str + input_queue: mp.Queue + processors: 
list[Processor] + + +class ProcessController: + def __init__(self, pipeline_steps: list[PipelineStep], populate_jobs: Callable): + """Initializes the ProcessController + Each pipeline step contains a list of processors that retrieve the data from the input queue, + process it and if necessary put it into the output queue of the next step. + """ + self._pipeline_steps = pipeline_steps + self._populate_jobs = populate_jobs + + def run(self): + # add the jobs to the input queues + get_logger().info("Populating jobs") + self._populate_jobs() + + # start the processors + for step in self._pipeline_steps: + get_logger().info(f"Starting processors for step {step.name}") + for processor in step.processors: + processor.start() + + # wait for the processors to finish + for step in self._pipeline_steps: + for _ in step.processors: + step.input_queue.put(None) + get_logger().info(f"Waiting for processors in step {step.name} to finish") + + for processor in step.processors: + processor.join() diff --git a/src/modalities/dataloader/preprocessing/tokenization/create_packed_data.py b/src/modalities/dataloader/preprocessing/tokenization/create_packed_data.py deleted file mode 100644 index 79c8f3401..000000000 --- a/src/modalities/dataloader/preprocessing/tokenization/create_packed_data.py +++ /dev/null @@ -1,113 +0,0 @@ -import multiprocessing as mp -import time - - -from modalities.dataloader.preprocessing.tokenization.tokenization_processes import ( - ProgressLoggingWorker, - ReaderWorker, - TokenizerWorker, - WriterWorker, -) -from modalities.utils.env_variables import temporary_env_var -from modalities.utils.logging import get_logger -import tqdm -import time - - - - -class PackedDataGenerator: - """Reads in a JSONL file and the corresponding index file and packs the dataset for LLM training.""" - - def __init__( - self, - reader_workers: list[ReaderWorker], - tokenizer_workers: list[TokenizerWorker], - writer_worker: WriterWorker, - progress_logging_worker: ProgressLoggingWorker, - reader_q: mp.Queue, - tokenizer_q: mp.Queue, - writer_q: mp.Queue, - logging_message_q: mp.Queue, - index_start: int, - num_samples: int, - batch_size: int, - ): - self.reader_workers = reader_workers - self.tokenizer_workers = tokenizer_workers - self.writer_worker = writer_worker - self.progress_logging_worker = progress_logging_worker - self.reader_q = reader_q - self.tokenizer_q = tokenizer_q - self.writer_q = writer_q - self.logging_message_q = logging_message_q - self._index_start = index_start - self._num_samples = num_samples - self.batch_size = batch_size - self._exception_buffer = [] - - if num_samples == -1: - # TODO accessing the reader directly is not nice, but we need to know the total number of samples - total_num_samples = len(self.reader_workers[0]._reader) - num_samples = total_num_samples - index_start - - def run(self): - # Not setting TOKENIZERS_PARALLELISM to false can cause deadlocks when using hf's "FastTokenizers". 
See also: - # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/67254879#67254879 - with temporary_env_var("TOKENIZERS_PARALLELISM", "false"): - start_time = time.time() - # populate the reader queue with the sample_ids that we want to tokenize - self._populate_reader_q( - index_start=self._index_start, - num_samples=self._num_samples, - num_reader_processes=len(self.reader_workers), - ) - - # start the progress logging worker - self.progress_logging_worker.start() - - # start the reader proceseses - for reader_worker in tqdm.tqdm(self.reader_workers, desc="Starting reader workers"): - reader_worker.start() - - # start the tokenizer processes - for tokenizer_worker in tqdm.tqdm(self.tokenizer_workers, desc="Starting tokenizer workers"): - tokenizer_worker.start() - - # start the writer process - self.writer_worker.start() - - # wait for all processes to finish - for reader_worker in tqdm.tqdm(self.reader_workers, desc="Stopping for reader workers"): - reader_worker.join() - - # stop the tokenizer processes - for _ in self.tokenizer_workers: - self.tokenizer_q.put(None) - for tokenizer_worker in tqdm.tqdm(self.tokenizer_workers, desc="Stopping tokenizer workers"): - tokenizer_worker.join() - - # stop the writer process - get_logger().info("Stopping writer worker.") - self.writer_q.put(None) - self.writer_worker.join() - - # stop the logging worker process - get_logger().info("Stopping progress logging worker.") - self.logging_message_q.put(None) - self.progress_logging_worker.join() - - end_time = time.time() - get_logger().info(f"Tokenization took {end_time - start_time} seconds.") - - if self._exception_buffer: - raise self._exception_buffer[0] - - - def _populate_reader_q(self, index_start: int, num_samples: int, num_reader_processes: int): - # populate the reader queue with the line_ids that we want to tokenize - - for i in tqdm.tqdm(range(index_start, index_start + num_samples, self.batch_size), desc="Filling up reader queue with line ids"): - self.reader_q.put((i, self.batch_size)) - for i in range(num_reader_processes): - self.reader_q.put(None) diff --git a/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py index 10bec2881..3ff40e135 100644 --- a/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py +++ b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py @@ -1,19 +1,22 @@ -from dataclasses import dataclass import mmap import pickle from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum from pathlib import Path from typing import Optional -from modalities.exceptions import ReaderIndexationError + import numpy as np -from enum import Enum + +from modalities.exceptions import ReaderIndexationError + @dataclass class Sample: # If the index is not shuffled, then the incrementeal_line_id # points to the position in the dataset - # If the index is shuffled, then the incremental_line_id - # points to the position in the shuffled index and the + # If the index is shuffled, then the incremental_line_id + # points to the position in the shuffled index and the # shuffled_line_id points to the position in the original index incremental_line_id: int raw_data_path: Path @@ -21,6 +24,7 @@ class Sample: sample_length_in_bytes: int content_raw: str | bytes content_tokenized: Optional[bytes] = None + token_size_in_bytes: Optional[int] = None 
shuffled_line_id: Optional[int] = None @@ -142,7 +146,7 @@ def __getitem__(self, key: int) -> Sample: return Sample( raw_data_path=self.raw_data_path, incremental_line_id=key, - shuffled_line_id=key, # TODO so far we don't support shuffling here! + shuffled_line_id=key, # TODO so far we don't support shuffling here! offset=offset, sample_length_in_bytes=sample_length_in_bytes, content_raw=content, @@ -229,7 +233,7 @@ def __getitem__(self, key: int) -> Sample: """ try: - if self.global_shuffle_index is not None: + if self.global_shuffle_index is not None: mapped_key = self.global_shuffle_index[key] else: mapped_key = key @@ -260,7 +264,6 @@ class LargeFileLinesReaderTypes(Enum): class LargeFileLinesReaderFactory: - @staticmethod def get_local_reader( raw_data_path: Path, From 3d79cc4ea4b14d9d77cfb076f352d756fd76a732 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sun, 5 Jan 2025 11:28:23 +0100 Subject: [PATCH 05/25] feat: added custom process implementation --- .../queued_processing/queued_processing.py | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py diff --git a/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py b/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py new file mode 100644 index 000000000..b43b40ec7 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py @@ -0,0 +1,98 @@ +import multiprocessing as mp +import queue +import traceback +from multiprocessing.synchronize import Event +from typing import Any, Optional + +from modalities.dataloader.preprocessing.queued_processing.processing_strategy_if import ProcessingStrategyIF +from modalities.exceptions import ProcessorStopEventException +from modalities.utils.logging import get_logger + + +class Processor(mp.Process): + class QueueConsumer: + def __init__(self, in_q: mp.Queue, in_q_timeout: int): + self._in_q = in_q + self._in_q_timeout = in_q_timeout + + def get_item(self, stop_event: Event) -> Any: + while not stop_event.is_set(): + try: + item = self._in_q.get(timeout=self._in_q_timeout) + except queue.Empty: + continue + return item + raise ProcessorStopEventException("Stop event was set") + + class QueueProducer: + def __init__(self, out_q: mp.Queue, out_q_timeout: int): + self._out_q = out_q + self._out_q_timeout = out_q_timeout + + def put_item(self, item: Any, stop_event: Event): + while not stop_event.is_set(): + try: + self._out_q.put(item, timeout=self._out_q_timeout) + except queue.Full: + continue + return + raise ProcessorStopEventException("Stop event was set") + + def __init__( + self, + in_q: mp.Queue, + out_qs: dict[str, mp.Queue], + in_q_timeout: int, + out_q_timeout: int, + strategy: ProcessingStrategyIF, + process_id: str, + process_type: str, + stop_event: Event, + logging_message_q_key: Optional[str] = None, + ): + super().__init__() + self._consumer = Processor.QueueConsumer(in_q, in_q_timeout) + self._producers: dict[str, Processor.QueueProducer] = { + q_key: Processor.QueueProducer(out_q, out_q_timeout) for q_key, out_q in out_qs.items() + } + + self._strategy = strategy + self._stop_event = stop_event + self._process_type = process_type + self._process_id = process_id + self._logging_message_q_key = logging_message_q_key + + def run(self): + with self._strategy: + while True: + try: + item = self._consumer.get_item(stop_event=self._stop_event) + except ProcessorStopEventException: + 
get_logger().info(f"{self._process_id} stopped due to forced stop event") + break + if item is None: + get_logger().info(f"{self._process_id} received regular poison pill, exiting...") + self._strategy.finalize() + break + try: + processed_sub_items: dict[str, Any] | None = self._strategy.process(item) + except Exception as e: + get_logger().error( + f"{self._process_type}:{self._process_id} failed to process item {item}. Error: {e}" + ) + stacktrace = traceback.format_exc() + get_logger().error(f"Stacktrace for {self._process_type}:{self._process_id} : {stacktrace}") + get_logger().error(f"{self._process_id} setting stop event and then exiting...") + self._stop_event.set() + break + + # if the strategy returns None, we don't have to put anything in any of the out_qs + if processed_sub_items is None: + continue + else: + # place the processed sub items in the correct out queues + for destination_q_key, processed_sub_item in processed_sub_items.items(): + if destination_q_key == self._logging_message_q_key: + processed_sub_item.process_id = self._process_id + processed_sub_item.process_type = self._process_type + self._producers[destination_q_key].put_item(processed_sub_item, stop_event=self._stop_event) From 8f41a74c8e5806991b3bd5902870eeee85e0366c Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sun, 5 Jan 2025 11:28:57 +0100 Subject: [PATCH 06/25] feat: added interface for the processing strategies --- .../preprocessing/queued_processing/__init__.py | 0 .../queued_processing/processing_strategy_if.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 src/modalities/dataloader/preprocessing/queued_processing/__init__.py create mode 100644 src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py diff --git a/src/modalities/dataloader/preprocessing/queued_processing/__init__.py b/src/modalities/dataloader/preprocessing/queued_processing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py b/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py new file mode 100644 index 000000000..4fe52219e --- /dev/null +++ b/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py @@ -0,0 +1,16 @@ +from abc import ABC +from typing import Any + + +class ProcessingStrategyIF(ABC): + def process(self, item: Any) -> dict[str, Any] | None: + raise NotImplementedError + + def __enter__(self): + raise NotImplementedError + + def finalize(self): + raise NotImplementedError + + def __exit__(self, exc_type, exc_value, traceback): + raise From a05bd9fc87f4861eb442a4579411ab7a581dd496 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sun, 5 Jan 2025 11:30:19 +0100 Subject: [PATCH 07/25] refactor: bug fixing in TokenizationInstantiationModel --- src/modalities/config/instantiation_models.py | 35 +++++++++---------- src/modalities/exceptions.py | 8 ++++- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index d7df1ba71..764a2d8c9 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -1,9 +1,7 @@ -import os from pathlib import Path from typing import Annotated, Any, Optional -from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LargeFileLinesReaderTypes -from pydantic import BaseModel, ConfigDict, Field, 
FilePath, field_validator, model_validator, root_validator, validator +from pydantic import BaseModel, ConfigDict, Field, FilePath, field_validator, model_validator, root_validator from modalities.config.pydanctic_if_types import ( PydanticCheckpointSavingIFType, @@ -17,10 +15,10 @@ PydanticPytorchDeviceType, PydanticPytorchModuleType, PydanticTextInferenceComponentType, - PydanticTokenizerIFType, ) from modalities.config.utils import parse_torch_device from modalities.dataloader.dataset import Dataset +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LargeFileLinesReaderTypes from modalities.util import warn_rank_0 @@ -192,8 +190,7 @@ def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationMode return self -class PackedDatasetComponentsInstantiationModel(BaseModel): - +class TokenizationInstantiationModel(BaseModel): class ReaderWorkerSettings(BaseModel): class ReaderSettings(BaseModel): class LocalReaderArgs(BaseModel): @@ -210,31 +207,34 @@ class GlobalReaderArgs(BaseModel): reader_type: LargeFileLinesReaderTypes reader_args: LocalReaderArgs | GlobalReaderArgs - - num_reader_processes: Annotated[int, Field(strict=True, ge=1)] + + num_workers: Annotated[int, Field(strict=True, ge=1)] reader_settings: ReaderSettings - + class TokenizerWorkerSettings(BaseModel): class TokenizerSettings(BaseModel): - tokenizer: PydanticTokenizerIFType + class TokenizerInstantitionSettings(BaseModel): + tokenizer_component_key: str + tokenizer_variant_key: str + config: dict[str, Any] + + tokenizer_instantiation_settings: TokenizerInstantitionSettings eod_token: str jq_pattern: str - - num_tokenizer_processes: Annotated[int, Field(strict=True, ge=1)] + + num_workers: Annotated[int, Field(strict=True, ge=1)] tokenizer_settings: TokenizerSettings - class WriterWorkerSettings(BaseModel): dst_path: Path index_start: Annotated[int, Field(strict=True, ge=0)] - @field_validator("dst_path") def ensure_path_does_not_exist(cls, value): path = Path(value) # Convert to Path object if it's a string if path.exists(): raise ValueError(f"The filepath '{path}' already exists.") - return path + return path paths: dict[str, Path] reader_worker_settings: ReaderWorkerSettings @@ -246,9 +246,8 @@ def ensure_path_does_not_exist(cls, value): num_samples: Annotated[int, Field(strict=True, ge=1)] batch_size: Annotated[int, Field(strict=True, ge=1)] logging_interval: Annotated[int, Field(strict=True, ge=1)] - - - + in_q_timeout: Annotated[int, Field(strict=True, ge=0)] + out_q_timeout: Annotated[int, Field(strict=True, ge=0)] class TextGenerationInstantiationModel(BaseModel): diff --git a/src/modalities/exceptions.py b/src/modalities/exceptions.py index 1bc34a255..a18c27d20 100644 --- a/src/modalities/exceptions.py +++ b/src/modalities/exceptions.py @@ -25,8 +25,14 @@ class OptimizerError(Exception): class ConfigError(Exception): pass + class EmptySampleError(RuntimeError): pass + class ReaderIndexationError(Exception): - pass \ No newline at end of file + pass + + +class ProcessorStopEventException(Exception): + pass From ac8f74e4bdcc0044c2577ecb03960f2ea79f5685 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sun, 5 Jan 2025 11:31:33 +0100 Subject: [PATCH 08/25] refactor: API uses now process controller --- src/modalities/api.py | 157 ++++++++++++--------- src/modalities/config/component_factory.py | 4 +- 2 files changed, 90 insertions(+), 71 deletions(-) diff --git a/src/modalities/api.py b/src/modalities/api.py index 867da0337..ac9347bab 100644 --- a/src/modalities/api.py 
+++ b/src/modalities/api.py @@ -1,35 +1,33 @@ #!/usr/bin/env python +import multiprocessing as mp import os +from enum import Enum from pathlib import Path -from typing import Optional -from modalities.dataloader.preprocessing.tokenization.create_packed_data import PackedDataGenerator -from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import ( - EmbeddedStreamData, - join_embedded_stream_data, -) -from modalities.dataloader.preprocessing.tokenization.tokenization_processes import ( - ProcessFactory, - ProgressLoggingWorker, - get_required_num_of_bytes_to_repr, -) -from modalities.utils.logging import get_logger from pydantic import FilePath import modalities.inference.inference as inference from modalities.checkpointing.checkpoint_conversion import CheckpointConversion from modalities.config.component_factory import ComponentFactory -from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel +from modalities.config.instantiation_models import TokenizationInstantiationModel from modalities.dataloader.preprocessing.indexation.create_index import IndexGenerator +from modalities.dataloader.preprocessing.queued_processing.process_controller import PipelineStep, ProcessController +from modalities.dataloader.preprocessing.queued_processing.queued_processing import Processor +from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import ( + EmbeddedStreamData, + join_embedded_stream_data, +) from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader +from modalities.dataloader.preprocessing.tokenization.tokenization_strategies import ( + ProcessingStrategyFactory, + WorkerTypes, + populate_reader_q, +) from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapter from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry -import multiprocessing as mp -import shutil - -from enum import Enum +from modalities.utils.logging import get_logger class FileExistencePolicy(Enum): @@ -122,70 +120,91 @@ def pack_encoded_data(config_dict: dict): # ResolverRegistry to work dynamically with any type-hinted config object from config.py. 
registry = Registry(COMPONENTS) component_factory = ComponentFactory(registry=registry) - instantion_model: PackedDatasetComponentsInstantiationModel = component_factory.build_components( - config_dict=config_dict, components_model_type=PackedDatasetComponentsInstantiationModel + instantion_model: TokenizationInstantiationModel = component_factory.build_components( + config_dict=config_dict, components_model_type=TokenizationInstantiationModel ) # build the queues - reader_q, tokenizer_q, writer_q, logging_message_q = ProcessFactory.get_process_queues( + reader_q, tokenizer_q, writer_q, logging_message_q = ProcessingStrategyFactory.get_process_queues( writer_q_maxsize=instantion_model.writer_q_maxsize, tokenizer_q_maxsize=instantion_model.tokenizer_q_maxsize ) # build the workers stop_event = mp.Event() - token_size_in_bytes = get_required_num_of_bytes_to_repr( - instantion_model.tokenizer_worker_settings.tokenizer_settings.tokenizer.vocab_size - ) - - reader_workers = ProcessFactory.get_reader_workers( - rw_settings=instantion_model.reader_worker_settings, - reader_q=reader_q, - tokenizer_q=tokenizer_q, - logging_message_q=logging_message_q, - stop_event=stop_event, - ) - tokenizer_workers = ProcessFactory.get_tokenizer_workers( - tw_settings=instantion_model.tokenizer_worker_settings, - tokenizer_q=tokenizer_q, - writer_q=writer_q, - logging_message_q=logging_message_q, - token_size_in_bytes=token_size_in_bytes, + tokenizer_q_key = "tokenizer_q" + writer_q_key = "writer_q" + logging_message_q_key = "logging_message_q" + + reader_settings = instantion_model.reader_worker_settings.reader_settings + + reader_workers = [ + Processor( + in_q=reader_q, + out_qs={tokenizer_q_key: tokenizer_q, logging_message_q_key: logging_message_q}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_reader_strategy( + reader_settings, tokenizer_q_key=tokenizer_q_key, logging_message_q_key=logging_message_q_key + ), + process_type=WorkerTypes.READER, + process_id=i, + logging_message_q_key=logging_message_q_key, + stop_event=stop_event, + ) + for i in range(instantion_model.reader_worker_settings.num_workers) + ] + + tokenizer_workers = [ + Processor( + in_q=tokenizer_q, + out_qs={writer_q_key: writer_q, logging_message_q_key: logging_message_q}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_tokenizer_strategy( + tokenizer_settings=instantion_model.tokenizer_worker_settings.tokenizer_settings, + writer_q_key=writer_q_key, + logging_message_q_key=logging_message_q_key, + ), + process_type=WorkerTypes.TOKENIZER, + process_id=i, + logging_message_q_key=logging_message_q_key, + stop_event=stop_event, + ) + for i in range(instantion_model.tokenizer_worker_settings.num_workers) + ] + + writer_worker = Processor( + in_q=writer_q, + out_qs={logging_message_q_key: logging_message_q}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_writing_strategy( + ww_settings=instantion_model.writer_worker_settings, logging_message_q_key=logging_message_q_key + ), + process_type=WorkerTypes.WRITER, + process_id=0, + logging_message_q_key=logging_message_q_key, stop_event=stop_event, ) - writer_worker = ProcessFactory.get_writer_worker( - writer_q=writer_q, - logging_message_q=logging_message_q, - token_size_in_bytes=token_size_in_bytes, - 
ww_settings=instantion_model.writer_worker_settings, - stop_event=stop_event, - ) - - progress_logging_worker = ProgressLoggingWorker( - logging_message_q=logging_message_q, - reader_q=reader_q, - tokenizer_q=tokenizer_q, - writer_q=writer_q, - total_num_samples=instantion_model.num_samples, - stop_event=stop_event, - logging_interval=instantion_model.logging_interval, - ) - - generator = PackedDataGenerator( - reader_workers=reader_workers, - tokenizer_workers=tokenizer_workers, - writer_worker=writer_worker, - progress_logging_worker=progress_logging_worker, - reader_q=reader_q, - tokenizer_q=tokenizer_q, - writer_q=writer_q, - logging_message_q=logging_message_q, - index_start=instantion_model.index_start, - num_samples=instantion_model.num_samples, - batch_size=instantion_model.batch_size, - ) - generator.run() + pipeline_steps = [ + PipelineStep(name="reading", input_queue=reader_q, processors=reader_workers), + PipelineStep(name="tokenizing", input_queue=tokenizer_q, processors=tokenizer_workers), + PipelineStep(name="writing", input_queue=writer_q, processors=[writer_worker]), + ] + + def populate(): + populate_reader_q( + reader_q=reader_q, + index_start=instantion_model.index_start, + num_samples=instantion_model.num_samples, + num_reader_processes=instantion_model.reader_worker_settings.num_workers, + batch_size=instantion_model.batch_size, + ) + + process_controller = ProcessController(pipeline_steps=pipeline_steps, populate_jobs=populate) + process_controller.run() def merge_packed_data_files(src_paths: list[Path], target_path: Path): diff --git a/src/modalities/config/component_factory.py b/src/modalities/config/component_factory.py index c8ff89896..e284a52e2 100644 --- a/src/modalities/config/component_factory.py +++ b/src/modalities/config/component_factory.py @@ -75,7 +75,7 @@ def _build_component( # instantiate component config component_key = current_component_config["component_key"] variant_key = current_component_config["variant_key"] - current_component_config = self._instantiate_component_config( + current_component_config = self.instantiate_component_config( component_key=component_key, variant_key=variant_key, config_dict=materialized_component_config["config"], @@ -139,7 +139,7 @@ def _is_reference_config(config_dict: dict) -> bool: # TODO instead of field checks, we should introduce an enum for the config type. 
return {"instance_key", "pass_type"} == config_dict.keys() - def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel: + def instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel: component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key) self._assert_valid_config_keys( component_key=component_key, From 95f30d453d9fe66108448affb38e803f8b19c09b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 8 Jan 2025 11:33:07 +0100 Subject: [PATCH 09/25] feat: added multiple queue destinations for strategy --- .../queued_processing/process_controller.py | 7 ++++-- .../queued_processing/queued_processing.py | 24 ++++++++++++------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py index de467fd39..496f9ee6e 100644 --- a/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py +++ b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from typing import Callable +import tqdm + from modalities.dataloader.preprocessing.queued_processing.queued_processing import Processor from modalities.utils.logging import get_logger @@ -35,9 +37,10 @@ def run(self): # wait for the processors to finish for step in self._pipeline_steps: - for _ in step.processors: + get_logger().info(f"Stopping {step.name} processes...") + for _ in tqdm.tqdm(step.processors, desc=f"Poisoning {step.name} processes"): step.input_queue.put(None) get_logger().info(f"Waiting for processors in step {step.name} to finish") - for processor in step.processors: + for processor in tqdm.tqdm(step.processors, desc=f"Joining {step.name} processes"): processor.join() diff --git a/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py b/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py index b43b40ec7..618627886 100644 --- a/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py +++ b/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py @@ -68,10 +68,10 @@ def run(self): try: item = self._consumer.get_item(stop_event=self._stop_event) except ProcessorStopEventException: - get_logger().info(f"{self._process_id} stopped due to forced stop event") + get_logger().info(f"{self._process_type}:{self._process_id} received forced stop event") break if item is None: - get_logger().info(f"{self._process_id} received regular poison pill, exiting...") + get_logger().info(f"{self._process_type}:{self._process_id} received regular poison pill") self._strategy.finalize() break try: @@ -90,9 +90,17 @@ def run(self): if processed_sub_items is None: continue else: - # place the processed sub items in the correct out queues - for destination_q_key, processed_sub_item in processed_sub_items.items(): - if destination_q_key == self._logging_message_q_key: - processed_sub_item.process_id = self._process_id - processed_sub_item.process_type = self._process_type - self._producers[destination_q_key].put_item(processed_sub_item, stop_event=self._stop_event) + try: + # place the processed sub items in the correct out queues + for destination_q_key, processed_sub_item in processed_sub_items.items(): + if destination_q_key == self._logging_message_q_key: + 
processed_sub_item.process_id = self._process_id + processed_sub_item.process_type = self._process_type + if destination_q_key == "writing_q_key": + continue + self._producers[destination_q_key].put_item(processed_sub_item, stop_event=self._stop_event) + + except ProcessorStopEventException: + get_logger().info(f"{self._process_type}:{self._process_id} received forced stop event") + break + get_logger().info(f"{self._process_type}:{self._process_id} exiting...") From d5d59a7b67aa454bf2fa274c4eb240a1328f3b1b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 8 Jan 2025 11:34:11 +0100 Subject: [PATCH 10/25] feat: temp env var decorator --- src/modalities/utils/env_variables.py | 38 +++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/modalities/utils/env_variables.py b/src/modalities/utils/env_variables.py index 15c4630d5..89d7150df 100644 --- a/src/modalities/utils/env_variables.py +++ b/src/modalities/utils/env_variables.py @@ -1,11 +1,14 @@ import os from contextlib import contextmanager +from functools import wraps +from typing import Any + @contextmanager def temporary_env_var(key, value): """ Temporarily set an environment variable. - + Args: key (str): The environment variable name. value (str): The temporary value to set. @@ -19,4 +22,35 @@ def temporary_env_var(key, value): if original_value is None: del os.environ[key] else: - os.environ[key] = original_value \ No newline at end of file + os.environ[key] = original_value + + +def temporary_env_vars_decorator(env_vars: dict[str, Any]): + """ + Decorator to temporarily set multiple environment variables for the duration of a function call. + + Args: + env_vars (dict): A dictionary of environment variable names and their temporary values. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + original_values = {} # Store original values of environment variables + try: + # Set the temporary environment variables + for key, value in env_vars.items(): + original_values[key] = os.environ.get(key) # Save original value + os.environ[key] = value # Set temporary value + return func(*args, **kwargs) # Execute the decorated function + finally: + # Restore original values or delete keys if not originally set + for key, original_value in original_values.items(): + if original_value is None: + del os.environ[key] + else: + os.environ[key] = original_value + + return wrapper + + return decorator From d242fd3db6fd2e4a08016baef23415f78eabe742 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Fri, 10 Jan 2025 00:48:18 +0100 Subject: [PATCH 11/25] refactor: moved WorkerTypes --- .../preprocessing/tokenization/worker_types.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 src/modalities/dataloader/preprocessing/tokenization/worker_types.py diff --git a/src/modalities/dataloader/preprocessing/tokenization/worker_types.py b/src/modalities/dataloader/preprocessing/tokenization/worker_types.py new file mode 100644 index 000000000..90ae24cf0 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/worker_types.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class WorkerTypes(Enum): + POPULATOR = "POPULATOR" + READER = "READER" + TOKENIZER = "TOKENIZER" + WRITER = "WRITER" + LOGGING = "LOGGING" From 3222469951738b6ec0e37ebd845e71d9e2b66473 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Fri, 10 Jan 2025 00:51:44 +0100 Subject: [PATCH 12/25] refactor: refactored tokenization strategies --- ...kenization_strategies.py => strategies.py} | 116 
++++++++++-------- 1 file changed, 67 insertions(+), 49 deletions(-) rename src/modalities/dataloader/preprocessing/tokenization/{tokenization_strategies.py => strategies.py} (82%) diff --git a/src/modalities/dataloader/preprocessing/tokenization/tokenization_strategies.py b/src/modalities/dataloader/preprocessing/tokenization/strategies.py similarity index 82% rename from src/modalities/dataloader/preprocessing/tokenization/tokenization_strategies.py rename to src/modalities/dataloader/preprocessing/tokenization/strategies.py index 389546853..3c4a73559 100644 --- a/src/modalities/dataloader/preprocessing/tokenization/tokenization_strategies.py +++ b/src/modalities/dataloader/preprocessing/tokenization/strategies.py @@ -3,27 +3,25 @@ import os import pickle import time -from dataclasses import dataclass -from enum import Enum from io import BufferedWriter from pathlib import Path -from typing import Optional, Type +from typing import Type import jq -import tqdm from data_quality_ablations.utils.logging import get_logger from pydantic import BaseModel -from modalities.config.component_factory import ComponentFactory from modalities.config.instantiation_models import TokenizationInstantiationModel from modalities.dataloader.preprocessing.queued_processing.processing_strategy_if import ProcessingStrategyIF +from modalities.dataloader.preprocessing.queued_processing.queue_items import ProgressMessage, ReadingJob from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import EmbeddedStreamData from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import ( BaseReader, LargeFileLinesReaderFactory, LargeFileLinesReaderTypes, - Sample, ) +from modalities.dataloader.preprocessing.tokenization.queue_items import Sample +from modalities.dataloader.preprocessing.tokenization.worker_types import WorkerTypes from modalities.exceptions import EmptySampleError from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry @@ -53,37 +51,29 @@ def get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: raise ValueError("Currently only support token byte sizes of 1, 2, and 4.") -def populate_reader_q( - reader_q: mp.Queue, index_start: int, num_samples: int, num_reader_processes: int, batch_size: int -): - # populate the reader queue with the line_ids that we want to tokenize - - for i in tqdm.tqdm( - range(index_start, index_start + num_samples, batch_size), desc="Filling up reader queue with line ids" +class PopulatingStrategy(ProcessingStrategyIF): + def __init__( + self, reader_q_key: str, logging_message_q_key: str, index_start: int, num_samples: int, batch_size: int ): - reader_q.put(ReadingJob(sample_id=i, batch_size=batch_size)) - for _ in range(num_reader_processes): - reader_q.put(None) - - -@dataclass -class ReadingJob: - sample_id: int - batch_size: int + self._reader_q_key = reader_q_key + self._logging_message_q_key = logging_message_q_key + self._batch_size = batch_size + self._reading_range = range(index_start, index_start + num_samples, batch_size) + def __enter__(self): + return self -class WorkerTypes(Enum): - READER = "READER" - TOKENIZER = "TOKENIZER" - WRITER = "WRITER" + def finalize(self): + pass + def __exit__(self, exc_type, exc_val, exc_tb): + pass -@dataclass -class ProgressMessage: - worker_type: WorkerTypes - num_samples: int - process_type: Optional[str] = None - process_id: Optional[str] = None + def process(self) -> dict[str, ReadingJob | ProgressMessage]: + sample_id = 
next(self._reading_range) + reading_job = ReadingJob(sample_id=sample_id, batch_size=self._batch_size) + progress_message = ProgressMessage(WorkerTypes.POPULATOR, num_samples=self._batch_size) + return {self._reader_q_key: reading_job, self._logging_message_q_key: progress_message} class ReadingStrategy(ProcessingStrategyIF): @@ -131,15 +121,17 @@ def __init__( def __enter__(self): registry = Registry(COMPONENTS) - component_factory = ComponentFactory(registry=registry) - self._tokenizer: TokenizerWrapper = component_factory.instantiate_component_config( + tokenizer_type: Type[TokenizerWrapper] = registry.get_component( component_key=self._tokenizer_instantiation_setings.tokenizer_component_key, variant_key=self._tokenizer_instantiation_setings.tokenizer_variant_key, - config_dict=self._tokenizer_instantiation_setings.config, ) + self._tokenizer: TokenizerWrapper = tokenizer_type(**self._tokenizer_instantiation_setings.config) + encoded_eod_token = self._tokenizer.get_token_id(self._eod_token) - self._encoded_eos_token_as_bytes = self._encoded_token_to_bytes(encoded_eod_token) self._token_size_in_bytes = get_required_num_of_bytes_to_repr(self._tokenizer.vocab_size) + self._encoded_eos_token_as_bytes = TokenizingStrategy._encoded_token_to_bytes( + token_size_in_bytes=self._token_size_in_bytes, encoded_token=encoded_eod_token + ) return self def finalize(self): @@ -155,7 +147,7 @@ def process(self, item: list[Sample]) -> dict[str, list[Sample] | ProgressMessag sample.content_tokenized = processed_line sample.token_size_in_bytes = self._token_size_in_bytes batch_processed.append(sample) - progress_message = ProgressMessage(WorkerTypes.TOKENIZER, self.process_id, len(batch_processed)) + progress_message = ProgressMessage(WorkerTypes.TOKENIZER, num_samples=len(batch_processed)) return {self._writer_q_key: batch_processed, self._logging_message_q_key: progress_message} def _process_line(self, line: str) -> bytes: @@ -163,14 +155,18 @@ def _process_line(self, line: str) -> bytes: jq_retrieved_text = self._jq_filter.input_text(line).first() if jq_retrieved_text is None: raise ValueError(f"jq was not able extract the text using the expression: {self._jq_filter}") - tokens = self.tokenizer.tokenize(jq_retrieved_text) + tokens = self._tokenizer.tokenize(jq_retrieved_text) if len(tokens) == 0: raise EmptySampleError("Received empty sample...") - return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes + return ( + b"".join(map(self._encoded_token_to_bytes, [self._token_size_in_bytes] * len(tokens), tokens)) + + self._encoded_eos_token_as_bytes + ) - def _encoded_token_to_bytes(self, encoded_token: int) -> bytes: + @staticmethod + def _encoded_token_to_bytes(token_size_in_bytes: int, encoded_token: int) -> bytes: # Converts an encoded token to its bytes representaion. - return encoded_token.to_bytes(self._token_size_in_bytes, byteorder="little", signed=False) + return encoded_token.to_bytes(token_size_in_bytes, byteorder="little", signed=False) class WritingStrategy(ProcessingStrategyIF): @@ -203,7 +199,7 @@ def __enter__(self): def finalize(self): # check that the index list IS NOT empty and the batch_dict IS empty # i.e., all batches have been written to the file - if len(self._index_list) == 0 or len(self._batch_dict) >= 0: + if len(self._index_list) == 0 or len(self._batch_dict) > 0: raise ValueError( f"Could not finalize writing strategy. Index list is empty or batch_dict is not empty. 
" f"Index list: {len(self._index_list)}, batch_dict: {self._batch_dict.keys()}" @@ -248,8 +244,8 @@ def process(self, item: list[Sample]) -> dict[str, ProgressMessage]: batch, self._prev_line_id, self._curr_offset, self._index_list, self._dst_fd ) num_samples_written += len(batch) - progress_message = ProgressMessage(WorkerTypes.WRITER, self.process_id, num_samples_written) - return {self._logging_key: progress_message} + progress_message = ProgressMessage(WorkerTypes.WRITER, num_samples=num_samples_written) + return {self._logging_message_q_key: progress_message} # writes a batch received from the writer_q to the destination file @staticmethod @@ -308,7 +304,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def process(self, item: ProgressMessage) -> dict: self._add_progress_message(item) passed_time = time.time() - self._last_logged - if passed_time > self._logging_interval or self._last_step: + if passed_time > self._logging_interval: self._log_and_reset(passed_time) self._last_logged = time.time() @@ -351,9 +347,8 @@ def _log_and_reset(self, passed_time: int): logging_message += "\n" logging_message += "Queues: \n" - logging_message += f"\tReader queue: {self._reader_q.qsize()} batches (approx.)\n" - logging_message += f"\tTokenizer queue: {self._tokenizer_q.qsize()} batches (approx.)\n" - logging_message += f"\tWriter queue: {self._writer_q.qsize()} batches (approx.)\n" + for q_key, q in self._q_dict.items(): + logging_message += f"\t{q_key}: {q.qsize()} batches (approx.)\n" get_logger().info(logging_message) @@ -365,6 +360,18 @@ def _log_and_reset(self, passed_time: int): class ProcessingStrategyFactory: + @staticmethod + def get_populating_strategy( + reader_q_key: str, logging_message_q_key: str, index_start: int, num_samples: int, batch_size: int + ) -> PopulatingStrategy: + return PopulatingStrategy( + reader_q_key=reader_q_key, + logging_message_q_key=logging_message_q_key, + index_start=index_start, + num_samples=num_samples, + batch_size=batch_size, + ) + @staticmethod def get_reader_strategy( reader_settings: TokenizationInstantiationModel.ReaderWorkerSettings.ReaderSettings, @@ -395,7 +402,7 @@ def get_tokenizer_strategy( logging_message_q_key: str, ) -> TokenizingStrategy: tokenizing_strategy = TokenizingStrategy( - tokenizer_instantiation_setings=tokenizer_settings.tokenizer_instantiation_settings, + ti_settings=tokenizer_settings.tokenizer_instantiation_settings, eod_token=tokenizer_settings.eod_token, jq_pattern=tokenizer_settings.jq_pattern, writer_q_key=writer_q_key, @@ -414,6 +421,17 @@ def get_writing_strategy( ) return writing_strategy + def get_progress_logging_strategy( + logging_interval: int, + total_num_samples: int, + q_dict: dict[str, mp.Queue], + ) -> ProgressLoggingStrategy: + return ProgressLoggingStrategy( + logging_interval=logging_interval, + total_num_samples=total_num_samples, + q_dict=q_dict, + ) + @staticmethod def get_process_queues(tokenizer_q_maxsize: int, writer_q_maxsize) -> tuple[mp.Queue, mp.Queue, mp.Queue]: reader_q = mp.Queue() # containes line_ids to be read From be14b0d3e48b304844b121ccf5e79c17146d989d Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Fri, 10 Jan 2025 00:52:28 +0100 Subject: [PATCH 13/25] feat: added ProcessorException and ProcessingStrategyDoneException --- src/modalities/exceptions.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/modalities/exceptions.py b/src/modalities/exceptions.py index a18c27d20..03d0abc17 100644 --- a/src/modalities/exceptions.py +++ b/src/modalities/exceptions.py @@ -36,3 
+36,11 @@ class ReaderIndexationError(Exception):
 
 class ProcessorStopEventException(Exception):
     pass
+
+
+class ProcessorException(Exception):
+    pass
+
+
+class ProcessingStrategyDoneException(Exception):
+    pass

From 6f542873ea6bf4a0f7d1f426a7c363dd48009a89 Mon Sep 17 00:00:00 2001
From: Max Luebbering
Date: Fri, 10 Jan 2025 01:02:34 +0100
Subject: [PATCH 14/25] refactor: generalized multiprocessing system

---
 .../queued_processing/process_controller.py   |  54 +++++--
 .../processing_strategy_if.py                 |   6 +-
 .../queued_processing/processors.py           | 135 ++++++++++++++++++
 .../queued_processing/queue_items.py          |  17 +++
 .../queued_processing/queued_processing.py    | 106 --------------
 .../preprocessing/tokenization/queue_items.py |  20 +++
 6 files changed, 215 insertions(+), 123 deletions(-)
 create mode 100644 src/modalities/dataloader/preprocessing/queued_processing/processors.py
 create mode 100644 src/modalities/dataloader/preprocessing/queued_processing/queue_items.py
 delete mode 100644 src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py
 create mode 100644 src/modalities/dataloader/preprocessing/tokenization/queue_items.py

diff --git a/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py
index 496f9ee6e..3f3541d36 100644
--- a/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py
+++ b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py
@@ -1,10 +1,10 @@
 import multiprocessing as mp
 from dataclasses import dataclass
-from typing import Callable
+from multiprocessing.synchronize import Event

 import tqdm

-from modalities.dataloader.preprocessing.queued_processing.queued_processing import Processor
+from modalities.dataloader.preprocessing.queued_processing.processors import Processor
 from modalities.utils.logging import get_logger

@@ -16,19 +16,50 @@ class ProcessController:
-    def __init__(self, pipeline_steps: list[PipelineStep], populate_jobs: Callable):
+    def __init__(self, pipeline_steps: list[PipelineStep], stop_event: Event, join_timeout: int = 5):
         """Initializes the ProcessController
         Each pipeline step contains a list of processors that retrieve the data from the input queue,
         process it and if necessary put it into the output queue of the next step.
         """
         self._pipeline_steps = pipeline_steps
-        self._populate_jobs = populate_jobs
+        self._stop_event = stop_event
+        self._join_timeout = join_timeout

-    def run(self):
-        # add the jobs to the input queues
-        get_logger().info("Populating jobs")
-        self._populate_jobs()
+    def join_processors_in_step(self, step: PipelineStep):
+        """Joins the processors of a pipeline step
+        If the stop_event is set, the processors are terminated
+        """
+        # poison the input queues of the processors
+        for _ in tqdm.tqdm(step.processors, desc=f"Poisoning {step.name} processes"):
+            if step.input_queue is not None:
+                step.input_queue.put(None)
+
+        # join the processors
+        num_exits = 0
+        while num_exits < len(step.processors):
+            processor = step.processors[num_exits]
+            if self._stop_event.is_set():
+                try:
+                    processor.terminate()
+                except Exception as e:
+                    # if we can't terminate the processor, we continue with the next one
+                    get_logger().error(
+                        f"Error while terminating processor {processor.full_name}: {e}. "
+                        "Continuing with the next processor."
+                    )
+                    num_exits += 1
+                    continue
+                get_logger().info(f"Terminated processor {processor.full_name}")
+                num_exits += 1
+            else:
+                try:
+                    processor.join(timeout=self._join_timeout)
+                except TimeoutError:
+                    continue
+                get_logger().info(f"Joined processor {processor.full_name}")
+                num_exits += 1
+
     def run(self):
         # start the processors
         for step in self._pipeline_steps:
             get_logger().info(f"Starting processors for step {step.name}")
@@ -38,9 +69,4 @@ def run(self):
         # wait for the processors to finish
         for step in self._pipeline_steps:
             get_logger().info(f"Stopping {step.name} processes...")
-            for _ in tqdm.tqdm(step.processors, desc=f"Poisoning {step.name} processes"):
-                step.input_queue.put(None)
-            get_logger().info(f"Waiting for processors in step {step.name} to finish")
-
-            for processor in tqdm.tqdm(step.processors, desc=f"Joining {step.name} processes"):
-                processor.join()
+            self.join_processors_in_step(step)
diff --git a/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py b/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py
index 4fe52219e..d243c20bc 100644
--- a/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py
+++ b/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py
@@ -1,9 +1,9 @@
 from abc import ABC
-from typing import Any
+from typing import Any, Optional


 class ProcessingStrategyIF(ABC):
-    def process(self, item: Any) -> dict[str, Any] | None:
+    def process(self, item: Optional[Any] = None) -> dict[str, Any] | None:
         raise NotImplementedError

     def __enter__(self):
@@ -13,4 +13,4 @@ def finalize(self):
         raise NotImplementedError

     def __exit__(self, exc_type, exc_value, traceback):
-        raise
+        raise NotImplementedError
diff --git a/src/modalities/dataloader/preprocessing/queued_processing/processors.py b/src/modalities/dataloader/preprocessing/queued_processing/processors.py
new file mode 100644
index 000000000..bb82e4f4b
--- /dev/null
+++ b/src/modalities/dataloader/preprocessing/queued_processing/processors.py
@@ -0,0 +1,135 @@
+import multiprocessing as mp
+import queue
+import traceback
+from multiprocessing.synchronize import Event
+from typing import Any, Optional
+
+from modalities.dataloader.preprocessing.queued_processing.processing_strategy_if import ProcessingStrategyIF
+from modalities.exceptions import ProcessingStrategyDoneException, ProcessorException, ProcessorStopEventException
+from modalities.utils.logging import get_logger
+
+
+class QueueConsumer:
+    def __init__(self, in_q: mp.Queue, in_q_timeout: int):
+        self._in_q = in_q
+        self._in_q_timeout = in_q_timeout
+
+    def get_item(self, stop_event: Event) -> Any:
+        while not stop_event.is_set():
+            try:
+                item = self._in_q.get(timeout=self._in_q_timeout)
+            except queue.Empty:
+                continue
+            return item
+        raise ProcessorStopEventException("Stop event was set")
+
+
+class QueueProducer:
+    def __init__(self, out_q: mp.Queue, out_q_timeout: int):
+        self._out_q = out_q
+        self._out_q_timeout = out_q_timeout
+
+    def put_item(self, item: Any, stop_event: Event):
+        while not stop_event.is_set():
+            try:
+                self._out_q.put(item, timeout=self._out_q_timeout)
+            except queue.Full:
+                continue
+            return
+        raise ProcessorStopEventException("Stop event was set")
+
+
+class Processor(mp.Process):
+    def __init__(
+        self,
+        out_qs: dict[str, mp.Queue],
+        in_q_timeout: int,
+        out_q_timeout: int,
+        strategy: ProcessingStrategyIF,
+        process_id: str,
+        process_type: str,
+        stop_event: Event,
+        set_stop_event_on_processing_error: bool,
+ in_q: mp.Queue = None, + logging_message_q_key: Optional[str] = None, + ): + super().__init__() + + self._consumer = QueueConsumer(in_q, in_q_timeout) if in_q is not None else None + self._producers: dict[str, QueueProducer] = { + q_key: QueueProducer(out_q, out_q_timeout) for q_key, out_q in out_qs.items() + } + self._strategy = strategy + self._stop_event = stop_event + self._process_type = process_type + self._process_id = process_id + self.exit_on_processing_error = set_stop_event_on_processing_error + self._logging_message_q_key = logging_message_q_key + + @property + def process_id(self) -> str: + return self._process_id + + @property + def process_type(self) -> str: + return self._process_type + + @property + def full_name(self) -> str: + return f"{self._process_type}:{self._process_id}" + + def _generate_item(self) -> dict[str, Any]: + processed_sub_items: dict[str, Any] = self._strategy.process() + return processed_sub_items + + def _process_item(self, item: Any) -> dict[str, Any] | None: + try: + if item is None: + get_logger().info(f"{self.full_name} received regular poison pill") + self._strategy.finalize() + processed_sub_items: dict[str, Any] | None = self._strategy.process(item) + except Exception as e: + get_logger().error(f"{self.full_name} failed to process item {item}. Error: {e}") + if self.exit_on_processing_error: + raise ProcessorException(f"{self.full_name} failed to process item {item}.") from e + else: + return None + return processed_sub_items + + def _forward_sub_items(self, processed_sub_items: dict[str, Any] | None): + if processed_sub_items is None: + return + # place the processed sub items in the correct out queues + for destination_q_key, processed_sub_item in processed_sub_items.items(): + if destination_q_key == self._logging_message_q_key: + processed_sub_item.process_id = self._process_id + processed_sub_item.process_type = self._process_type + self._producers[destination_q_key].put_item(processed_sub_item, stop_event=self._stop_event) + + def run(self): + try: + with self._strategy: + while True: + if self._consumer is None: + # if there is no consumer, we are the first processor and need to generate the items + try: + processed_sub_items: dict[str, Any] = self._generate_item() + except ProcessingStrategyDoneException: + get_logger().info(f"{self.full_name} received done. Exiting...") + break + else: + item = self._consumer.get_item(stop_event=self._stop_event) + processed_sub_items: dict[str, Any] = self._process_item(item) + self._forward_sub_items(processed_sub_items) + + except ProcessorStopEventException: + # if the stop event was set, some process in the pipeline failed and we need to exit + get_logger().info(f"{self.full_name} received forced stop event. 
Exiting...") + except Exception as e: + # in this block, every exception comes from this very process and we need to set the stop event + # to signal the other processes of the pipeline that something went wrong + stacktrace = traceback.format_exc() + get_logger().error(f"Stacktrace for {self.full_name} : {stacktrace}") + get_logger().error(f"{self.full_name} failed with error: {e}, setting stop event") + self._stop_event.set() + get_logger().error(f"{self.full_name} exiting...") diff --git a/src/modalities/dataloader/preprocessing/queued_processing/queue_items.py b/src/modalities/dataloader/preprocessing/queued_processing/queue_items.py new file mode 100644 index 000000000..3f99d39a4 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/queued_processing/queue_items.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + + +@dataclass +class ReadingJob: + sample_id: int + batch_size: int + + +@dataclass +class ProgressMessage: + worker_type: Enum + num_samples: int + process_type: Optional[str] = None + process_id: Optional[str] = None diff --git a/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py b/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py deleted file mode 100644 index 618627886..000000000 --- a/src/modalities/dataloader/preprocessing/queued_processing/queued_processing.py +++ /dev/null @@ -1,106 +0,0 @@ -import multiprocessing as mp -import queue -import traceback -from multiprocessing.synchronize import Event -from typing import Any, Optional - -from modalities.dataloader.preprocessing.queued_processing.processing_strategy_if import ProcessingStrategyIF -from modalities.exceptions import ProcessorStopEventException -from modalities.utils.logging import get_logger - - -class Processor(mp.Process): - class QueueConsumer: - def __init__(self, in_q: mp.Queue, in_q_timeout: int): - self._in_q = in_q - self._in_q_timeout = in_q_timeout - - def get_item(self, stop_event: Event) -> Any: - while not stop_event.is_set(): - try: - item = self._in_q.get(timeout=self._in_q_timeout) - except queue.Empty: - continue - return item - raise ProcessorStopEventException("Stop event was set") - - class QueueProducer: - def __init__(self, out_q: mp.Queue, out_q_timeout: int): - self._out_q = out_q - self._out_q_timeout = out_q_timeout - - def put_item(self, item: Any, stop_event: Event): - while not stop_event.is_set(): - try: - self._out_q.put(item, timeout=self._out_q_timeout) - except queue.Full: - continue - return - raise ProcessorStopEventException("Stop event was set") - - def __init__( - self, - in_q: mp.Queue, - out_qs: dict[str, mp.Queue], - in_q_timeout: int, - out_q_timeout: int, - strategy: ProcessingStrategyIF, - process_id: str, - process_type: str, - stop_event: Event, - logging_message_q_key: Optional[str] = None, - ): - super().__init__() - self._consumer = Processor.QueueConsumer(in_q, in_q_timeout) - self._producers: dict[str, Processor.QueueProducer] = { - q_key: Processor.QueueProducer(out_q, out_q_timeout) for q_key, out_q in out_qs.items() - } - - self._strategy = strategy - self._stop_event = stop_event - self._process_type = process_type - self._process_id = process_id - self._logging_message_q_key = logging_message_q_key - - def run(self): - with self._strategy: - while True: - try: - item = self._consumer.get_item(stop_event=self._stop_event) - except ProcessorStopEventException: - get_logger().info(f"{self._process_type}:{self._process_id} received forced 
stop event") - break - if item is None: - get_logger().info(f"{self._process_type}:{self._process_id} received regular poison pill") - self._strategy.finalize() - break - try: - processed_sub_items: dict[str, Any] | None = self._strategy.process(item) - except Exception as e: - get_logger().error( - f"{self._process_type}:{self._process_id} failed to process item {item}. Error: {e}" - ) - stacktrace = traceback.format_exc() - get_logger().error(f"Stacktrace for {self._process_type}:{self._process_id} : {stacktrace}") - get_logger().error(f"{self._process_id} setting stop event and then exiting...") - self._stop_event.set() - break - - # if the strategy returns None, we don't have to put anything in any of the out_qs - if processed_sub_items is None: - continue - else: - try: - # place the processed sub items in the correct out queues - for destination_q_key, processed_sub_item in processed_sub_items.items(): - if destination_q_key == self._logging_message_q_key: - processed_sub_item.process_id = self._process_id - processed_sub_item.process_type = self._process_type - if destination_q_key == "writing_q_key": - continue - self._producers[destination_q_key].put_item(processed_sub_item, stop_event=self._stop_event) - - except ProcessorStopEventException: - get_logger().info(f"{self._process_type}:{self._process_id} received forced stop event") - break - get_logger().info(f"{self._process_type}:{self._process_id} exiting...") diff --git a/src/modalities/dataloader/preprocessing/tokenization/queue_items.py b/src/modalities/dataloader/preprocessing/tokenization/queue_items.py new file mode 100644 index 000000000..a9754fce3 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/queue_items.py @@ -0,0 +1,20 @@ +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel + + +class Sample(BaseModel): + # If the index is not shuffled, then the incrementeal_line_id + # points to the position in the dataset + # If the index is shuffled, then the incremental_line_id + # points to the position in the shuffled index and the + # shuffled_line_id points to the position in the original index + incremental_line_id: int + raw_data_path: Path + offset: int + sample_length_in_bytes: int + content_raw: str | bytes + content_tokenized: Optional[bytes] = None + token_size_in_bytes: Optional[int] = None + shuffled_line_id: Optional[int] = None From ac225968103daf3bf4584e236541944bfd13ecb0 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Fri, 10 Jan 2025 01:04:04 +0100 Subject: [PATCH 15/25] refactor: adapted the create_packed_data endpoint to fit the previous refactoring --- src/modalities/api.py | 76 +++++++++++++------ src/modalities/config/component_factory.py | 4 +- .../tokenization/large_file_lines_reader.py | 23 +----- 3 files changed, 57 insertions(+), 46 deletions(-) diff --git a/src/modalities/api.py b/src/modalities/api.py index ac9347bab..0391439a7 100644 --- a/src/modalities/api.py +++ b/src/modalities/api.py @@ -13,20 +13,17 @@ from modalities.config.instantiation_models import TokenizationInstantiationModel from modalities.dataloader.preprocessing.indexation.create_index import IndexGenerator from modalities.dataloader.preprocessing.queued_processing.process_controller import PipelineStep, ProcessController -from modalities.dataloader.preprocessing.queued_processing.queued_processing import Processor +from modalities.dataloader.preprocessing.queued_processing.processors import Processor from 
modalities.dataloader.preprocessing.tokenization.embedded_stream_data import ( EmbeddedStreamData, join_embedded_stream_data, ) from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader -from modalities.dataloader.preprocessing.tokenization.tokenization_strategies import ( - ProcessingStrategyFactory, - WorkerTypes, - populate_reader_q, -) +from modalities.dataloader.preprocessing.tokenization.strategies import ProcessingStrategyFactory, WorkerTypes from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapter from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry +from modalities.utils.env_variables import temporary_env_vars_decorator from modalities.utils.logging import get_logger @@ -102,6 +99,9 @@ def convert_pytorch_to_hf_checkpoint( return hf_model +# not setting this can cause deadlocks when using hf's "FastTokenizers". See also: +# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/67254879#67254879 +@temporary_env_vars_decorator({"TOKENIZERS_PARALLELISM": "false"}) def pack_encoded_data(config_dict: dict): """Packs and encodes an indexed, large jsonl-file. (see also `create_index` for more information) @@ -111,13 +111,6 @@ def pack_encoded_data(config_dict: dict): Args: config_dict (dict): Dictionary containing the configuration for the packed data generation. """ - - # TODO: if we want to use alternative entrypoints together with the ResolverRegistry, - # we can currently not rely on the existing class resolver. - # This is based on its connection to the overall `AppConfig`. - # One would requires an object of it to instantiate the ResolverRegistry. - # This could get resolved by implementing on own ResolverRegistry for each entrypoint or adapting the existing - # ResolverRegistry to work dynamically with any type-hinted config object from config.py. 
registry = Registry(COMPONENTS) component_factory = ComponentFactory(registry=registry) instantion_model: TokenizationInstantiationModel = component_factory.build_components( @@ -131,13 +124,30 @@ def pack_encoded_data(config_dict: dict): # build the workers stop_event = mp.Event() - + reader_q_key = "reader_q" tokenizer_q_key = "tokenizer_q" writer_q_key = "writer_q" logging_message_q_key = "logging_message_q" - reader_settings = instantion_model.reader_worker_settings.reader_settings + populating_worker = Processor( + out_qs={reader_q_key: reader_q, logging_message_q_key: logging_message_q}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_populating_strategy( + reader_q_key=reader_q_key, + logging_message_q_key=logging_message_q_key, + index_start=instantion_model.index_start, + num_samples=instantion_model.num_samples, + batch_size=instantion_model.batch_size, + ), + process_type=WorkerTypes.POPULATOR, + process_id=0, + logging_message_q_key=logging_message_q_key, + set_stop_event_on_processing_error=True, + stop_event=stop_event, + ) + reader_settings = instantion_model.reader_worker_settings.reader_settings reader_workers = [ Processor( in_q=reader_q, @@ -150,6 +160,7 @@ def pack_encoded_data(config_dict: dict): process_type=WorkerTypes.READER, process_id=i, logging_message_q_key=logging_message_q_key, + set_stop_event_on_processing_error=False, stop_event=stop_event, ) for i in range(instantion_model.reader_worker_settings.num_workers) @@ -169,6 +180,7 @@ def pack_encoded_data(config_dict: dict): process_type=WorkerTypes.TOKENIZER, process_id=i, logging_message_q_key=logging_message_q_key, + set_stop_event_on_processing_error=False, stop_event=stop_event, ) for i in range(instantion_model.tokenizer_worker_settings.num_workers) @@ -185,25 +197,39 @@ def pack_encoded_data(config_dict: dict): process_type=WorkerTypes.WRITER, process_id=0, logging_message_q_key=logging_message_q_key, + set_stop_event_on_processing_error=True, + stop_event=stop_event, + ) + + logging_worker = Processor( + in_q=logging_message_q, + out_qs={}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_progress_logging_strategy( + logging_interval=instantion_model.logging_interval, + total_num_samples=instantion_model.num_samples, + q_dict={ + reader_q_key: reader_q, + tokenizer_q_key: tokenizer_q, + writer_q_key: writer_q, + logging_message_q_key: logging_message_q, + }, + ), + process_type=WorkerTypes.LOGGING, + process_id=0, stop_event=stop_event, ) pipeline_steps = [ + PipelineStep(name="populating", input_queue=None, processors=[populating_worker]), PipelineStep(name="reading", input_queue=reader_q, processors=reader_workers), PipelineStep(name="tokenizing", input_queue=tokenizer_q, processors=tokenizer_workers), PipelineStep(name="writing", input_queue=writer_q, processors=[writer_worker]), + PipelineStep(name="logging", input_queue=logging_message_q, processors=[logging_worker]), ] - def populate(): - populate_reader_q( - reader_q=reader_q, - index_start=instantion_model.index_start, - num_samples=instantion_model.num_samples, - num_reader_processes=instantion_model.reader_worker_settings.num_workers, - batch_size=instantion_model.batch_size, - ) - - process_controller = ProcessController(pipeline_steps=pipeline_steps, populate_jobs=populate) + process_controller = ProcessController(pipeline_steps=pipeline_steps) process_controller.run() 
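
The hunk above rewires `pack_encoded_data` into a queue-connected pipeline: a populating worker feeds reading jobs to a pool of reader workers, readers feed tokenizer workers, tokenizers feed a single writer, and every stage reports to a progress-logging worker, with poison pills (None) and a shared stop event handling shutdown. As a point of reference only, the self-contained sketch below shows the same producer / worker-pool / consumer pattern using nothing but the Python standard library; it illustrates the control flow, not the modalities `Processor`/`PipelineStep` API, and every name in it is invented for the example.

import multiprocessing as mp
import queue


def producer(out_q: mp.Queue, num_items: int, num_workers: int) -> None:
    # Emit work items, then one poison pill (None) per downstream worker.
    for item_id in range(num_items):
        out_q.put(item_id)
    for _ in range(num_workers):
        out_q.put(None)


def worker(in_q: mp.Queue, out_q: mp.Queue, stop_event) -> None:
    # Consume until the poison pill arrives or the shared stop event is set.
    while not stop_event.is_set():
        try:
            item = in_q.get(timeout=1)
        except queue.Empty:
            continue
        if item is None:
            out_q.put(None)  # forward the pill so the writer can count finished workers
            return
        out_q.put(item * 2)  # stand-in for reading / tokenizing a sample


def writer(in_q: mp.Queue, num_workers: int, stop_event) -> None:
    # Stop once every upstream worker has forwarded its poison pill.
    pills, written = 0, 0
    while pills < num_workers and not stop_event.is_set():
        try:
            item = in_q.get(timeout=1)
        except queue.Empty:
            continue
        if item is None:
            pills += 1
        else:
            written += 1
    print(f"wrote {written} items")


if __name__ == "__main__":
    stop_event = mp.Event()
    work_q, result_q = mp.Queue(), mp.Queue()
    processes = [mp.Process(target=producer, args=(work_q, 10, 2))]
    processes += [mp.Process(target=worker, args=(work_q, result_q, stop_event)) for _ in range(2)]
    processes += [mp.Process(target=writer, args=(result_q, 2, stop_event))]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
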
diff --git a/src/modalities/config/component_factory.py b/src/modalities/config/component_factory.py index e284a52e2..c8ff89896 100644 --- a/src/modalities/config/component_factory.py +++ b/src/modalities/config/component_factory.py @@ -75,7 +75,7 @@ def _build_component( # instantiate component config component_key = current_component_config["component_key"] variant_key = current_component_config["variant_key"] - current_component_config = self.instantiate_component_config( + current_component_config = self._instantiate_component_config( component_key=component_key, variant_key=variant_key, config_dict=materialized_component_config["config"], @@ -139,7 +139,7 @@ def _is_reference_config(config_dict: dict) -> bool: # TODO instead of field checks, we should introduce an enum for the config type. return {"instance_key", "pass_type"} == config_dict.keys() - def instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel: + def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel: component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key) self._assert_valid_config_keys( component_key=component_key, diff --git a/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py index 3ff40e135..219453be5 100644 --- a/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py +++ b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py @@ -1,33 +1,16 @@ import mmap import pickle from abc import ABC, abstractmethod -from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Optional import numpy as np +from modalities.dataloader.preprocessing.tokenization.queue_items import Sample from modalities.exceptions import ReaderIndexationError -@dataclass -class Sample: - # If the index is not shuffled, then the incrementeal_line_id - # points to the position in the dataset - # If the index is shuffled, then the incremental_line_id - # points to the position in the shuffled index and the - # shuffled_line_id points to the position in the original index - incremental_line_id: int - raw_data_path: Path - offset: int - sample_length_in_bytes: int - content_raw: str | bytes - content_tokenized: Optional[bytes] = None - token_size_in_bytes: Optional[int] = None - shuffled_line_id: Optional[int] = None - - class BaseReader(ABC): @abstractmethod def __len__(self) -> int: @@ -245,7 +228,9 @@ def __getitem__(self, key: int) -> Sample: with open(abs_raw_file_path, "rb") as fd: raw_data_mmap = mmap.mmap(fd.fileno(), 0, access=mmap.ACCESS_READ) - content = raw_data_mmap[offset : offset + sample_length_in_bytes] + content = bytes(raw_data_mmap[offset : offset + sample_length_in_bytes]) + raw_data_mmap.close() # Explicitly close mmap + if self.encoding is not None: content = content.decode(self.encoding) return Sample( From f2422a1bdbec6edf4b2e9bc69e5307b403d80629 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sat, 11 Jan 2025 00:49:02 +0100 Subject: [PATCH 16/25] refactor: fixed issues in joining processors --- src/modalities/api.py | 13 +++++----- .../queued_processing/process_controller.py | 25 +++++++++++++------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/modalities/api.py b/src/modalities/api.py index 0391439a7..f53dc332f 100644 --- a/src/modalities/api.py +++ 
b/src/modalities/api.py @@ -218,18 +218,19 @@ def pack_encoded_data(config_dict: dict): ), process_type=WorkerTypes.LOGGING, process_id=0, + set_stop_event_on_processing_error=False, stop_event=stop_event, ) pipeline_steps = [ - PipelineStep(name="populating", input_queue=None, processors=[populating_worker]), - PipelineStep(name="reading", input_queue=reader_q, processors=reader_workers), - PipelineStep(name="tokenizing", input_queue=tokenizer_q, processors=tokenizer_workers), - PipelineStep(name="writing", input_queue=writer_q, processors=[writer_worker]), - PipelineStep(name="logging", input_queue=logging_message_q, processors=[logging_worker]), + PipelineStep(name="populating", input_queue=None, processors=[populating_worker], poisonable=False), + PipelineStep(name="reading", input_queue=reader_q, processors=reader_workers, poisonable=True), + PipelineStep(name="tokenizing", input_queue=tokenizer_q, processors=tokenizer_workers, poisonable=True), + PipelineStep(name="writing", input_queue=writer_q, processors=[writer_worker], poisonable=True), + PipelineStep(name="logging", input_queue=logging_message_q, processors=[logging_worker], poisonable=True), ] - process_controller = ProcessController(pipeline_steps=pipeline_steps) + process_controller = ProcessController(pipeline_steps=pipeline_steps, stop_event=stop_event) process_controller.run() diff --git a/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py index 3f3541d36..f5ff38a1a 100644 --- a/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py +++ b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py @@ -11,6 +11,7 @@ @dataclass class PipelineStep: name: str + poisonable: bool input_queue: mp.Queue processors: list[Processor] @@ -30,14 +31,22 @@ def join_processors_in_step(self, step: PipelineStep): If the stop_event is set, the processors are terminated """ # poison the input queues of the processors - for _ in tqdm.tqdm(step.processors, desc=f"Poisoning {step.name} processes"): - if step.input_queue is not None: - step.input_queue.put(None) + if step.poisonable: + for _ in tqdm.tqdm(step.processors, desc=f"Poisoning {step.name} processes"): + if step.input_queue is not None: + step.input_queue.put(None) # join the processors num_exits = 0 while num_exits < len(step.processors): processor = step.processors[num_exits] + + # if the processor is not alive, we continue with the next one + if not processor.is_alive(): + get_logger().info(f"Processor {processor.full_name} is not alive. Continuing with the next processor.") + num_exits += 1 + continue + # if the stop event is set, we terminate the processor if self._stop_event.is_set(): try: processor.terminate() @@ -51,12 +60,14 @@ def join_processors_in_step(self, step: PipelineStep): continue get_logger().info(f"Terminated processor {processor.full_name}") num_exits += 1 + # if the stop event is not set, we join the processor else: - try: - processor.join(timeout=self.join_timeout) - except TimeoutError: + get_logger().info(f"Joining {processor.full_name} ...") + processor.join(timeout=self._join_timeout) + if processor.exitcode is None: + get_logger().info(f"Joining {processor.full_name} timed out. Exit code: {processor.exitcode} ...") continue - get_logger().info(f"Joined processor {processor.full_name}") + get_logger().info(f"Joined processor {processor.full_name}. 
Exit code: {processor.exitcode}") num_exits += 1 def run(self): From ad55dfb02076f75f491954ada2cd1926fea69df9 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sat, 11 Jan 2025 00:50:09 +0100 Subject: [PATCH 17/25] refactor: finalized processors and strategy setup --- .../processing_strategy_if.py | 2 +- .../queued_processing/processors.py | 54 ++++++++++--------- .../preprocessing/tokenization/strategies.py | 10 ++-- 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py b/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py index d243c20bc..e2b29aec0 100644 --- a/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py +++ b/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py @@ -3,7 +3,7 @@ class ProcessingStrategyIF(ABC): - def process(self, item: Optional[Any] = None) -> dict[str, Any] | None: + def process(self, item: Optional[Any] = None) -> dict[str, Any]: raise NotImplementedError def __enter__(self): diff --git a/src/modalities/dataloader/preprocessing/queued_processing/processors.py b/src/modalities/dataloader/preprocessing/queued_processing/processors.py index bb82e4f4b..39ed11b4a 100644 --- a/src/modalities/dataloader/preprocessing/queued_processing/processors.py +++ b/src/modalities/dataloader/preprocessing/queued_processing/processors.py @@ -13,6 +13,7 @@ class QueueConsumer: def __init__(self, in_q: mp.Queue, in_q_timeout: int): self._in_q = in_q self._in_q_timeout = in_q_timeout + self._consumed_items = 0 def get_item(self, stop_event: Event) -> Any: while not stop_event.is_set(): @@ -20,6 +21,9 @@ def get_item(self, stop_event: Event) -> Any: item = self._in_q.get(timeout=self._in_q_timeout) except queue.Empty: continue + if item is None: + pass + self._consumed_items += 1 return item raise ProcessorStopEventException("Stop event was set") @@ -65,6 +69,8 @@ def __init__( self._process_id = process_id self.exit_on_processing_error = set_stop_event_on_processing_error self._logging_message_q_key = logging_message_q_key + # if the consumer is None, we are the first processor in the pipeline and we need to generate the items + self._processing_fun = self._generate_item if self._consumer is None else self._process_item @property def process_id(self) -> str: @@ -78,27 +84,32 @@ def process_type(self) -> str: def full_name(self) -> str: return f"{self._process_type}:{self._process_id}" - def _generate_item(self) -> dict[str, Any]: - processed_sub_items: dict[str, Any] = self._strategy.process() - return processed_sub_items - - def _process_item(self, item: Any) -> dict[str, Any] | None: + def _generate_item(self): + try: + processed_sub_items: dict[str, Any] = self._strategy.process() + except ProcessingStrategyDoneException as e: + self._strategy.finalize() + get_logger().info(f"{self.full_name} received done (iterator exhausted). 
Exiting...") + raise e + self._forward_sub_items(processed_sub_items) + + def _process_item(self): + item = self._consumer.get_item(stop_event=self._stop_event) + if item is None: + self._strategy.finalize() + raise ProcessingStrategyDoneException(f"{self.full_name} received done (poison pill).") + # process the item try: - if item is None: - get_logger().info(f"{self.full_name} received regular poison pill") - self._strategy.finalize() processed_sub_items: dict[str, Any] | None = self._strategy.process(item) except Exception as e: get_logger().error(f"{self.full_name} failed to process item {item}. Error: {e}") if self.exit_on_processing_error: raise ProcessorException(f"{self.full_name} failed to process item {item}.") from e - else: - return None - return processed_sub_items + return # continue with the next item + # forward the processed sub items to the respective queues + self._forward_sub_items(processed_sub_items) - def _forward_sub_items(self, processed_sub_items: dict[str, Any] | None): - if processed_sub_items is None: - return + def _forward_sub_items(self, processed_sub_items: dict[str, Any]): # place the processed sub items in the correct out queues for destination_q_key, processed_sub_item in processed_sub_items.items(): if destination_q_key == self._logging_message_q_key: @@ -110,18 +121,9 @@ def run(self): try: with self._strategy: while True: - if self._consumer is None: - # if there is no consumer, we are the first processor and need to generate the items - try: - processed_sub_items: dict[str, Any] = self._generate_item() - except ProcessingStrategyDoneException: - get_logger().info(f"{self.full_name} received done. Exiting...") - break - else: - item = self._consumer.get_item(stop_event=self._stop_event) - processed_sub_items: dict[str, Any] = self._process_item(item) - self._forward_sub_items(processed_sub_items) - + self._processing_fun() + except ProcessingStrategyDoneException: + pass except ProcessorStopEventException: # if the stop event was set, some process in the pipeline failed and we need to exit get_logger().info(f"{self.full_name} received forced stop event. 
Exiting...") diff --git a/src/modalities/dataloader/preprocessing/tokenization/strategies.py b/src/modalities/dataloader/preprocessing/tokenization/strategies.py index 3c4a73559..f45e14d04 100644 --- a/src/modalities/dataloader/preprocessing/tokenization/strategies.py +++ b/src/modalities/dataloader/preprocessing/tokenization/strategies.py @@ -22,7 +22,7 @@ ) from modalities.dataloader.preprocessing.tokenization.queue_items import Sample from modalities.dataloader.preprocessing.tokenization.worker_types import WorkerTypes -from modalities.exceptions import EmptySampleError +from modalities.exceptions import EmptySampleError, ProcessingStrategyDoneException from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper @@ -58,7 +58,7 @@ def __init__( self._reader_q_key = reader_q_key self._logging_message_q_key = logging_message_q_key self._batch_size = batch_size - self._reading_range = range(index_start, index_start + num_samples, batch_size) + self._reading_iter = iter(range(index_start, index_start + num_samples, batch_size)) def __enter__(self): return self @@ -70,7 +70,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def process(self) -> dict[str, ReadingJob | ProgressMessage]: - sample_id = next(self._reading_range) + try: + sample_id = next(self._reading_iter) + except StopIteration: + raise ProcessingStrategyDoneException("PopulatingStrategy done.") reading_job = ReadingJob(sample_id=sample_id, batch_size=self._batch_size) progress_message = ProgressMessage(WorkerTypes.POPULATOR, num_samples=self._batch_size) return {self._reader_q_key: reading_job, self._logging_message_q_key: progress_message} @@ -307,6 +310,7 @@ def process(self, item: ProgressMessage) -> dict: if passed_time > self._logging_interval: self._log_and_reset(passed_time) self._last_logged = time.time() + return {} def _add_progress_message(self, progress_message: ProgressMessage): if progress_message.worker_type not in self._worker_to_pid_to_num_samples: From 4da0959f0f1b3fe8ad71734d27b9ce1d0bb096cb Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sun, 12 Jan 2025 22:40:18 +0100 Subject: [PATCH 18/25] feat: enhanced tokenization instantiation with pydantic model --- src/modalities/api.py | 14 ++-- src/modalities/config/instantiation_models.py | 16 +++- src/modalities/config/pydanctic_if_types.py | 2 + .../tokenization/large_file_lines_reader.py | 26 +++++- .../preprocessing/tokenization/strategies.py | 10 ++- src/modalities/registry/components.py | 80 ++++++++++++------- 6 files changed, 104 insertions(+), 44 deletions(-) diff --git a/src/modalities/api.py b/src/modalities/api.py index f53dc332f..74181c95b 100644 --- a/src/modalities/api.py +++ b/src/modalities/api.py @@ -119,7 +119,9 @@ def pack_encoded_data(config_dict: dict): # build the queues reader_q, tokenizer_q, writer_q, logging_message_q = ProcessingStrategyFactory.get_process_queues( - writer_q_maxsize=instantion_model.writer_q_maxsize, tokenizer_q_maxsize=instantion_model.tokenizer_q_maxsize + reader_q_maxsize=instantion_model.reader_q_maxsize, + writer_q_maxsize=instantion_model.writer_q_maxsize, + tokenizer_q_maxsize=instantion_model.tokenizer_q_maxsize, ) # build the workers @@ -136,9 +138,9 @@ def pack_encoded_data(config_dict: dict): strategy=ProcessingStrategyFactory.get_populating_strategy( reader_q_key=reader_q_key, logging_message_q_key=logging_message_q_key, - index_start=instantion_model.index_start, - 
num_samples=instantion_model.num_samples, - batch_size=instantion_model.batch_size, + index_start=instantion_model.populate_worker_settings.index_start, + num_samples=instantion_model.populate_worker_settings.num_samples, + batch_size=instantion_model.populate_worker_settings.batch_size, ), process_type=WorkerTypes.POPULATOR, process_id=0, @@ -207,8 +209,8 @@ def pack_encoded_data(config_dict: dict): in_q_timeout=instantion_model.in_q_timeout, out_q_timeout=instantion_model.out_q_timeout, strategy=ProcessingStrategyFactory.get_progress_logging_strategy( - logging_interval=instantion_model.logging_interval, - total_num_samples=instantion_model.num_samples, + logging_interval=instantion_model.logging_worker_settings.logging_interval, + total_num_samples=instantion_model.logging_worker_settings.num_samples, q_dict={ reader_q_key: reader_q, tokenizer_q_key: tokenizer_q, diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 764a2d8c9..16232d17e 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -191,6 +191,11 @@ def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationMode class TokenizationInstantiationModel(BaseModel): + class PopulateWorkerSettings(BaseModel): + num_samples: Annotated[int, Field(strict=True, ge=1)] + batch_size: Annotated[int, Field(strict=True, ge=1)] + index_start: Optional[Annotated[int, Field(strict=True, ge=0)]] = 0 + class ReaderWorkerSettings(BaseModel): class ReaderSettings(BaseModel): class LocalReaderArgs(BaseModel): @@ -236,16 +241,19 @@ def ensure_path_does_not_exist(cls, value): raise ValueError(f"The filepath '{path}' already exists.") return path + class LoggingWorkerSettings(BaseModel): + logging_interval: Annotated[int, Field(strict=True, ge=1)] + num_samples: Optional[Annotated[int, Field(strict=True, ge=1)]] = None + paths: dict[str, Path] + populate_worker_settings: PopulateWorkerSettings reader_worker_settings: ReaderWorkerSettings tokenizer_worker_settings: TokenizerWorkerSettings writer_worker_settings: WriterWorkerSettings + logging_worker_settings: LoggingWorkerSettings + reader_q_maxsize: Annotated[int, Field(strict=True, ge=1)] tokenizer_q_maxsize: Annotated[int, Field(strict=True, ge=1)] writer_q_maxsize: Annotated[int, Field(strict=True, ge=1)] - index_start: Annotated[int, Field(strict=True, ge=0)] - num_samples: Annotated[int, Field(strict=True, ge=1)] - batch_size: Annotated[int, Field(strict=True, ge=1)] - logging_interval: Annotated[int, Field(strict=True, ge=1)] in_q_timeout: Annotated[int, Field(strict=True, ge=0)] out_q_timeout: Annotated[int, Field(strict=True, ge=0)] diff --git a/src/modalities/config/pydanctic_if_types.py b/src/modalities/config/pydanctic_if_types.py index 3761eb8df..4256113ed 100644 --- a/src/modalities/config/pydanctic_if_types.py +++ b/src/modalities/config/pydanctic_if_types.py @@ -14,6 +14,7 @@ from modalities.checkpointing.checkpoint_saving import CheckpointSaving, CheckpointSavingExecutionABC from modalities.checkpointing.checkpoint_saving_strategies import CheckpointSavingStrategyIF from modalities.dataloader.dataloader import LLMDataLoader +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import BaseReaderIF from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss @@ -65,3 +66,4 @@ def 
__get_pydantic_core_schema__( PydanticTextInferenceComponentType = Annotated[TextInferenceComponent, PydanticThirdPartyTypeIF(TextInferenceComponent)] PydanticGradientClipperIFType = Annotated[GradientClipperIF, PydanticThirdPartyTypeIF(GradientClipperIF)] PydanticModelInitializationIFType = Annotated[ModelInitializationIF, PydanticThirdPartyTypeIF(ModelInitializationIF)] +PydanticBaseReaderIFType = Annotated[BaseReaderIF, PydanticThirdPartyTypeIF(BaseReaderIF)] diff --git a/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py index 219453be5..1162a4cf1 100644 --- a/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py +++ b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py @@ -6,12 +6,13 @@ from typing import Optional import numpy as np +from pydantic import BaseModel from modalities.dataloader.preprocessing.tokenization.queue_items import Sample from modalities.exceptions import ReaderIndexationError -class BaseReader(ABC): +class BaseReaderIF(ABC): @abstractmethod def __len__(self) -> int: raise NotImplementedError @@ -21,7 +22,7 @@ def __getitem__(self, key: int) -> Sample: raise NotImplementedError -class LocalLargeFileLinesReader(BaseReader): +class LocalLargeFileLinesReader(BaseReaderIF): """LargeFileLinesReader class that read lines from a large file efficiently.""" def __init__( @@ -144,7 +145,7 @@ def _read_from_raw_file(self, offset: int, sample_length_in_bytes: int) -> str | return data -class GlobalLargeFileLinesReader(BaseReader): +class GlobalLargeFileLinesReader(BaseReaderIF): """LargeFileLinesReader class that read lines from a large file efficiently.""" def __init__( @@ -248,6 +249,25 @@ class LargeFileLinesReaderTypes(Enum): GLOBAL = "GLOBAL" +class IndexTypes(Enum): + LOCAL = "LOCAL" + GLOBAL = "GLOBAL" + + +class LocalLargeFileLinesReaderConfig(BaseModel): + raw_data_path: Path + index_path: Optional[Path] = None + encoding: Optional[str] = "utf-8" + + +class GlobalLargeFileLinesReaderConfig(BaseModel): + global_inorder_index_path: Path + raw_data_file_list_path: Path + raw_data_root_path: Path + global_shuffle_index_path: Optional[Path] = None + encoding: Optional[str] = "utf-8" + + class LargeFileLinesReaderFactory: @staticmethod def get_local_reader( diff --git a/src/modalities/dataloader/preprocessing/tokenization/strategies.py b/src/modalities/dataloader/preprocessing/tokenization/strategies.py index f45e14d04..1d508498a 100644 --- a/src/modalities/dataloader/preprocessing/tokenization/strategies.py +++ b/src/modalities/dataloader/preprocessing/tokenization/strategies.py @@ -16,7 +16,7 @@ from modalities.dataloader.preprocessing.queued_processing.queue_items import ProgressMessage, ReadingJob from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import EmbeddedStreamData from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import ( - BaseReader, + BaseReaderIF, LargeFileLinesReaderFactory, LargeFileLinesReaderTypes, ) @@ -81,7 +81,7 @@ def process(self) -> dict[str, ReadingJob | ProgressMessage]: class ReadingStrategy(ProcessingStrategyIF): def __init__( - self, reader_type: Type[BaseReader], reader_args: BaseModel, tokenizer_q_key: str, logging_message_q_key: str + self, reader_type: Type[BaseReaderIF], reader_args: BaseModel, tokenizer_q_key: str, logging_message_q_key: str ): self._reader_type = reader_type self._reader_args = reader_args @@ -437,8 
+437,10 @@ def get_progress_logging_strategy( ) @staticmethod - def get_process_queues(tokenizer_q_maxsize: int, writer_q_maxsize) -> tuple[mp.Queue, mp.Queue, mp.Queue]: - reader_q = mp.Queue() # containes line_ids to be read + def get_process_queues( + reader_q_maxsize: int, tokenizer_q_maxsize: int, writer_q_maxsize + ) -> tuple[mp.Queue, mp.Queue, mp.Queue, mp.Queue]: + reader_q = mp.Queue(maxsize=reader_q_maxsize) # containes line_ids to be read tokenizer_q = mp.Queue(maxsize=tokenizer_q_maxsize) # contains (line_id, line) pairs to be tokenized writer_q = mp.Queue(maxsize=writer_q_maxsize) # contains (line_id, tokenized_line) to be written to disc logging_message_q = mp.Queue() diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index ecda86995..f712a6fe1 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -54,6 +54,11 @@ from modalities.dataloader.dataloader_factory import DataloaderFactory from modalities.dataloader.dataset import DummyDatasetConfig from modalities.dataloader.dataset_factory import DatasetFactory +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import ( + GlobalLargeFileLinesReaderConfig, + LargeFileLinesReaderFactory, + LocalLargeFileLinesReaderConfig, +) from modalities.dataloader.samplers import ResumableDistributedSampler from modalities.logging_broker.subscriber_impl.subscriber_factory import ( ProgressSubscriberFactory, @@ -87,14 +92,16 @@ from modalities.utils.number_conversion import ( LocalNumBatchesFromNumSamplesConfig, LocalNumBatchesFromNumTokensConfig, - NumberConversion, NumberConversionFromCheckpointPathConfig, NumSamplesFromNumTokensConfig, + NumSamplesFromReaderConfig, NumStepsFromNumSamplesConfig, NumStepsFromNumTokensConfig, NumStepsFromRawDatasetIndexConfig, NumTokensFromNumStepsConfig, NumTokensFromPackedMemMapDatasetContinuousConfig, + PreprocessingNumberConversion, + TrainingNumberConversion, ) @@ -246,83 +253,102 @@ class ComponentEntity: "gradient_clipper", "fsdp_logging_only", FSDPLoggingOnlyGradientClipper, FSDPDummyGradientClipperConfig ), ComponentEntity("gradient_clipper", "dummy", DummyGradientClipper, DummyGradientClipperConfig), + # large file lines reader + ComponentEntity( + "large_file_lines_reader", + "local", + LargeFileLinesReaderFactory.get_local_reader, + LocalLargeFileLinesReaderConfig, + ), + ComponentEntity( + "large_file_lines_reader", + "global", + LargeFileLinesReaderFactory.get_local_reader, + GlobalLargeFileLinesReaderConfig, + ), # Number conversion ComponentEntity( - "number_conversion", + "training_number_conversion", "local_num_batches_from_num_samples", - NumberConversion.get_local_num_batches_from_num_samples, + TrainingNumberConversion.get_local_num_batches_from_num_samples, LocalNumBatchesFromNumSamplesConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "local_num_batches_from_num_tokens", - NumberConversion.get_local_num_batches_from_num_tokens, + TrainingNumberConversion.get_local_num_batches_from_num_tokens, LocalNumBatchesFromNumTokensConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_samples_from_num_tokens", - NumberConversion.get_num_samples_from_num_tokens, + TrainingNumberConversion.get_num_samples_from_num_tokens, NumSamplesFromNumTokensConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_steps_from_num_samples", - NumberConversion.get_num_steps_from_num_samples, + 
TrainingNumberConversion.get_num_steps_from_num_samples, NumStepsFromNumSamplesConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_steps_from_num_tokens", - NumberConversion.get_num_steps_from_num_tokens, + TrainingNumberConversion.get_num_steps_from_num_tokens, NumStepsFromNumTokensConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_tokens_from_num_steps", - NumberConversion.get_num_tokens_from_num_steps, + TrainingNumberConversion.get_num_tokens_from_num_steps, NumTokensFromNumStepsConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "last_step_from_checkpoint_path", - NumberConversion.get_last_step_from_checkpoint_path, + TrainingNumberConversion.get_last_step_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_seen_steps_from_checkpoint_path", - NumberConversion.get_num_seen_steps_from_checkpoint_path, + TrainingNumberConversion.get_num_seen_steps_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "global_num_seen_tokens_from_checkpoint_path", - NumberConversion.get_global_num_seen_tokens_from_checkpoint_path, + TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_target_steps_from_checkpoint_path", - NumberConversion.get_num_target_steps_from_checkpoint_path, + TrainingNumberConversion.get_num_target_steps_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "global_num_target_tokens_from_checkpoint_path", - NumberConversion.get_global_num_target_tokens_from_checkpoint_path, + TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_tokens_from_packed_mem_map_dataset_continuous", - NumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous, + TrainingNumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous, NumTokensFromPackedMemMapDatasetContinuousConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_steps_from_raw_dataset_index", - NumberConversion.get_num_steps_from_raw_dataset_index, + TrainingNumberConversion.get_num_steps_from_raw_dataset_index, NumStepsFromRawDatasetIndexConfig, ), + ComponentEntity( + "preprocessing_number_conversion", + "num_samples", + PreprocessingNumberConversion.get_num_samples_from_reader, + NumSamplesFromReaderConfig, + ), ] From 004daeb077d8b5eee1de8ac5a09c869818124c17 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Sun, 12 Jan 2025 22:40:39 +0100 Subject: [PATCH 19/25] feat: added PreprocessingNumberConversion --- src/modalities/utils/number_conversion.py | 52 +++++++++++++++++------ 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/src/modalities/utils/number_conversion.py b/src/modalities/utils/number_conversion.py index 3dc56732b..201183b5f 100644 --- a/src/modalities/utils/number_conversion.py +++ b/src/modalities/utils/number_conversion.py @@ -1,10 +1,13 @@ import re +from functools import lru_cache from pathlib import Path from typing import Annotated from pydantic import BaseModel, Field +from modalities.config.pydanctic_if_types 
import PydanticBaseReaderIFType from modalities.dataloader.dataset_factory import DatasetFactory +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import BaseReaderIF class LocalNumBatchesFromNumSamplesConfig(BaseModel): @@ -67,7 +70,13 @@ class NumStepsFromRawDatasetIndexConfig(BaseModel): gradient_accumulation_steps: Annotated[int, Field(strict=True, gt=0)] -class NumberConversion: +class NumSamplesFromReaderConfig(BaseModel): + reader: PydanticBaseReaderIFType + index_start: Annotated[int, Field(strict=True, ge=0)] = 0 + num_samples: Annotated[int, Field(strict=True, ge=1)] = None + + +class TrainingNumberConversion: @staticmethod def _get_checkpoint_parameter_value(pattern: str, string: str) -> int: matches = re.findall(pattern, string) @@ -134,7 +143,7 @@ def get_local_num_batches_from_num_tokens( int: Number of local batches for single rank. """ global_num_samples = global_num_tokens // sequence_length - return NumberConversion.get_local_num_batches_from_num_samples( + return TrainingNumberConversion.get_local_num_batches_from_num_samples( num_ranks=num_ranks, global_num_samples=global_num_samples, local_micro_batch_size=local_micro_batch_size ) @@ -178,7 +187,7 @@ def get_num_steps_from_num_tokens( int: Number of steps. """ global_num_samples = global_num_tokens // sequence_length - return NumberConversion.get_num_steps_from_num_samples( + return TrainingNumberConversion.get_num_steps_from_num_samples( num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, global_num_samples=global_num_samples, @@ -221,7 +230,7 @@ def get_last_step_from_checkpoint_path(checkpoint_path: Path) -> int: """ # Regex pattern to match 'num_steps_' followed by digits pattern = r"seen_steps_(\d+)" - num_seen_steps = NumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) + num_seen_steps = TrainingNumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) return num_seen_steps - 1 @staticmethod @@ -236,7 +245,7 @@ def get_num_seen_steps_from_checkpoint_path(checkpoint_path: Path) -> int: """ # Regex pattern to match 'num_steps_' followed by digits pattern = r"seen_steps_(\d+)" - num_seen_steps = NumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) + num_seen_steps = TrainingNumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) return num_seen_steps @staticmethod @@ -251,7 +260,7 @@ def get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path: Path) -> in """ # Regex pattern to match 'num_steps_' followed by digits pattern = r"seen_tokens_(\d+)" - num_seen_tokens = NumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) + num_seen_tokens = TrainingNumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) return num_seen_tokens @staticmethod @@ -266,16 +275,18 @@ def get_global_num_target_tokens_from_checkpoint_path(checkpoint_path: Path) -> """ # Regex pattern to match 'num_steps_' followed by digits pattern = r"target_tokens_(\d+)" - num_target_tokens = NumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) + num_target_tokens = TrainingNumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) return num_target_tokens @staticmethod def get_num_target_steps_from_checkpoint_path(checkpoint_path: Path) -> int: - tokens_per_step = NumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path) / ( - 
NumberConversion.get_last_step_from_checkpoint_path(checkpoint_path) + 1 + tokens_per_step = TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path) / ( + TrainingNumberConversion.get_last_step_from_checkpoint_path(checkpoint_path) + 1 ) - global_num_target_tokens = NumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path) + global_num_target_tokens = TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path( + checkpoint_path + ) num_target_steps = global_num_target_tokens // tokens_per_step if isinstance(num_target_steps, float) and not num_target_steps.is_integer(): @@ -315,7 +326,7 @@ def get_num_tokens_from_packed_mem_map_dataset_continuous( raw_data_path=dataset_path, sequence_length=sequence_length, sample_key="text" ) global_num_tokens_dataset = len(dataset) * sequence_length - num_steps = NumberConversion.get_num_steps_from_num_tokens( + num_steps = TrainingNumberConversion.get_num_steps_from_num_tokens( num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, global_num_tokens=global_num_tokens_dataset, @@ -323,7 +334,7 @@ def get_num_tokens_from_packed_mem_map_dataset_continuous( gradient_accumulation_steps=gradient_accumulation_steps, ) - global_num_tokens_actual = NumberConversion.get_num_tokens_from_num_steps( + global_num_tokens_actual = TrainingNumberConversion.get_num_tokens_from_num_steps( num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, sequence_length=sequence_length, @@ -356,10 +367,25 @@ def get_num_steps_from_raw_dataset_index( """ index = DatasetFactory.get_raw_index(raw_index_path=raw_index_path) num_samples = len(index) - num_steps = NumberConversion.get_num_steps_from_num_samples( + num_steps = TrainingNumberConversion.get_num_steps_from_num_samples( num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, global_num_samples=num_samples, gradient_accumulation_steps=gradient_accumulation_steps, ) return num_steps + + +class PreprocessingNumberConversion: + @lru_cache(maxsize=128) + @staticmethod + def get_num_samples_from_reader(reader: BaseReaderIF, index_start: int = 0, num_samples: int = None): + max_num_samples = len(reader) - index_start + if num_samples is not None and num_samples > max_num_samples: + raise ValueError( + f"num_samples ({num_samples}) is greater than the maximum number of samples " + f"(len(large_file_lines_reader) - index_start = {max_num_samples})" + ) + if num_samples is None: + num_samples = max_num_samples + return num_samples From 8be1ad5dffa2988652203c7a850fb9c52368ca1e Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Mon, 13 Jan 2025 10:42:19 +0100 Subject: [PATCH 20/25] feat: integrated global index creation --- src/modalities/__main__.py | 28 ++++- src/modalities/api.py | 12 +- .../indexation/create_global_index.py | 112 ++++++++++++++++++ tests/utils/test_number_conversion.py | 44 ++++--- 4 files changed, 172 insertions(+), 24 deletions(-) create mode 100644 src/modalities/dataloader/preprocessing/indexation/create_global_index.py diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 164209dcc..0a88b1387 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -9,12 +9,13 @@ import click import click_pathlib -from modalities.utils.logging import get_logger from pydantic import BaseModel, FilePath from modalities.api import ( convert_pytorch_to_hf_checkpoint, - create_raw_data_index, + create_global_index, + create_local_index, + create_shuffled_global_index, generate_text, 
merge_packed_data_files, pack_encoded_data, @@ -35,6 +36,7 @@ from modalities.running_env.cuda_env import CudaEnv from modalities.trainer import Trainer from modalities.util import get_total_number_of_trainable_parameters, print_rank_0 +from modalities.utils.logging import get_logger @click.group() @@ -124,7 +126,7 @@ def data(): pass -@data.command(name="create_raw_index") +@data.command(name="create_local_index") @click.argument("src_path", type=Path) @click.option( "--index_path", @@ -132,7 +134,7 @@ def data(): default=None, help="output path for index. will use parent directory of src_path if none.", ) -def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path): +def CMD_entry_point_data_create_local_index(src_path: Path, index_path: Path): """Utility CMD for indexing the content of a large jsonl-file. Background is the ability to further process the respective file without loading it, while splitting its content line-based. This step is necessary in advance of further processing like tokenization. @@ -145,7 +147,23 @@ def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path): Raises: ValueError: If the index file already exists. """ - create_raw_data_index(src_path=src_path, index_path=index_path) + create_local_index(src_path=src_path, index_path=index_path) + + +@data.command(name="create_global_index") +@click.option("--file_list_path", type=Path, required=True) +@click.option("--root_index_path", type=Path, required=True) +@click.option("--global_index_root_path", type=Path, required=True) +def CMD_entry_point_create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path): + create_global_index( + file_list_path=file_list_path, root_index_path=root_index_path, global_index_root_path=global_index_root_path + ) + + +@data.command(name="create_shuffled_global_index") +@click.option("--global_index_file_path", type=Path, required=True) +def CMD_entry_point_create_shuffled_global_index(global_index_file_path: Path): + create_shuffled_global_index(global_index_file_path=global_index_file_path) @data.command(name="pack_encoded_data") diff --git a/src/modalities/api.py b/src/modalities/api.py index 74181c95b..bb2ec101a 100644 --- a/src/modalities/api.py +++ b/src/modalities/api.py @@ -33,7 +33,7 @@ class FileExistencePolicy(Enum): OVERRIDE = "override" -def create_raw_data_index( +def create_local_index( src_path: Path, index_path: Path, file_existence_policy: FileExistencePolicy = FileExistencePolicy.ERROR ): """Creates the index file for the content of a large jsonl-file. The index file @@ -71,6 +71,16 @@ def create_raw_data_index( generator.create_index(index_path) +def create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path) -> Path: + global_index_file_path = create_global_index(file_list_path, root_index_path, global_index_root_path) + return global_index_file_path + + +def create_shuffled_global_index(global_index_file_path: Path) -> Path: + global_shuffled_index_file_path = create_shuffled_global_index(global_index_file_path) + return global_shuffled_index_file_path + + def generate_text(config_file_path: FilePath): """Inference function to generate text with a given model. 
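
The file diff that follows adds the global-index module backing the `create_global_index` and `create_shuffled_global_index` commands registered above. As written there, the in-order global index is a flat int64 memmap whose row 0 is a header `(num_rows, num_cols, shuffled_flag)` and whose remaining rows appear to hold `(file_id, offset, sample_length_in_bytes)` per document, while the shuffled index is a plain int64 array of shuffled row ids. The snippet below is a minimal read-back sketch of that layout for orientation only; the path and helper name are placeholders, not part of this patch.

from pathlib import Path

import numpy as np


def load_global_index(global_index_file_path: Path) -> np.memmap:
    # Row 0 is the header written by create_global_index: (num_rows, num_cols, is_shuffled).
    num_rows, num_cols, _ = (int(v) for v in np.memmap(global_index_file_path, dtype="int64", mode="r")[:3])
    return np.memmap(global_index_file_path, dtype="int64", mode="r", shape=(num_rows, num_cols))


if __name__ == "__main__":
    index = load_global_index(Path("global_inorder.idx"))  # placeholder path
    file_id, offset, sample_length_in_bytes = index[1]  # first document; row 0 is the header
    print(file_id, offset, sample_length_in_bytes)
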
diff --git a/src/modalities/dataloader/preprocessing/indexation/create_global_index.py b/src/modalities/dataloader/preprocessing/indexation/create_global_index.py new file mode 100644 index 000000000..8add0f353 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/indexation/create_global_index.py @@ -0,0 +1,112 @@ +import pickle +from pathlib import Path + +import numpy as np +import tqdm + + +def _get_global_index_file_path(global_index_root_path: Path) -> Path: + global_index_file_path = global_index_root_path / f"{global_index_root_path.name}_inorder.idx" + return global_index_file_path + + +def _get_file_list(file_list_path: Path) -> list[Path]: + file_list: list[Path] = [] + with open(file_list_path, "r") as f: + for line in f: + file_list.append(Path(line.strip())) + return file_list + + +def _get_file_id_file_path_mappings(file_list: list[Path]) -> tuple[dict[Path, int], dict[int, Path]]: + file_path_to_id = {file_path.with_suffix(""): i for i, file_path in enumerate(file_list)} + id_to_file_path = {i: file_path.with_suffix("") for i, file_path in enumerate(file_list)} + return file_path_to_id, id_to_file_path + + +def _get_local_index_paths(file_list: list[Path], root_index_path: Path, global_index_root_path: Path) -> list[Path]: + local_index_paths = [ + path.with_suffix(".idx") + for path in file_list + if (root_index_path / path).is_relative_to(global_index_root_path) + ] + return local_index_paths + + +def _get_total_num_documents(local_index_paths: list[Path], root_index_path: Path) -> int: + num_documents = 0 + for local_index_path in tqdm.tqdm(local_index_paths, desc="Counting total number of documents"): + with open(root_index_path / local_index_path, "rb") as f: + index = pickle.load(f) + num_documents += len(index) + return num_documents + + +def _populate_global_index_array( + global_index_file_path: Path, + num_documents: int, + local_index_paths: list[Path], + root_index_path: Path, + file_path_to_id: dict[Path, int], +) -> np.memmap: + shape = (num_documents + 1, 3) + global_index_array = np.memmap(global_index_file_path, dtype="int64", mode="w+", shape=shape) + + # the first row is reserved for the shape of the array and whether rows are shuffled. 
+ # + global_index_array[0] = np.array([*shape, 0]) + start_index = 1 + for local_index_path in tqdm.tqdm(local_index_paths, desc="Populating global index array"): + with open(root_index_path / local_index_path, "rb") as f: + local_index = pickle.load(f) + + local_index_array = np.array(local_index) + # add the file id to the local index + file_id = file_path_to_id[local_index_path.with_suffix("")] + local_index_array = np.insert(local_index_array, 0, file_id, axis=1) + + global_index_array[start_index : start_index + len(local_index_array)] = local_index_array + start_index += len(local_index_array) + global_index_array.flush() + return global_index_array + + +def create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path) -> Path: + global_index_file_path = _get_global_index_file_path(global_index_root_path) + + file_list = _get_file_list(file_list_path) + + file_path_to_id, _ = _get_file_id_file_path_mappings(file_list) + local_index_paths = _get_local_index_paths(file_list, root_index_path, global_index_root_path) + num_documents = _get_total_num_documents(local_index_paths, root_index_path) + + _populate_global_index_array( + global_index_file_path, num_documents, local_index_paths, root_index_path, file_path_to_id + ) + return global_index_file_path + + +def create_shuffled_global_index(global_index_file_path: Path) -> Path: + global_shuffled_index_file_path = ( + global_index_file_path.parent / f"{global_index_file_path.stem.replace('inorder', 'shuffle_index')}.idx" + ) + print(global_shuffled_index_file_path) + + # global index array + num_rows, _, _ = np.memmap(global_index_file_path, dtype="int64", mode="r")[0:3] + + print(f"Shuffling {num_rows-1} global index indices") + # we count from 1 since the 0th row contains meta information (num_rows, num_cols, is_shuffled) + indices = np.arange(1, num_rows) + np.random.shuffle(indices) + + print(f"Writing out shuffled global index array with {num_rows} elements") + global_shuffled_index_array = np.memmap( + global_shuffled_index_file_path, dtype="int64", mode="w+", shape=(len(indices),) + ) + chunk_size = 10 + for i in tqdm.tqdm(range(0, len(indices), chunk_size)): + chunk_indices = indices[i : i + chunk_size] + global_shuffled_index_array[i : i + len(chunk_indices)] = chunk_indices + global_shuffled_index_array.flush() + return global_shuffled_index_file_path diff --git a/tests/utils/test_number_conversion.py b/tests/utils/test_number_conversion.py index f54d807ae..531aaad98 100644 --- a/tests/utils/test_number_conversion.py +++ b/tests/utils/test_number_conversion.py @@ -4,7 +4,7 @@ import pytest from modalities.dataloader.dataset_factory import DatasetFactory -from modalities.utils.number_conversion import NumberConversion +from modalities.utils.number_conversion import TrainingNumberConversion @pytest.mark.parametrize( @@ -15,7 +15,9 @@ def test_get_local_num_batches_from_num_samples( num_ranks: int, global_num_samples: int, local_micro_batch_size: int, expected: int ): assert ( - NumberConversion.get_local_num_batches_from_num_samples(num_ranks, global_num_samples, local_micro_batch_size) + TrainingNumberConversion.get_local_num_batches_from_num_samples( + num_ranks, global_num_samples, local_micro_batch_size + ) == expected ) @@ -28,7 +30,7 @@ def test_get_local_num_batches_from_num_tokens( num_ranks: int, global_num_tokens: int, sequence_length: int, local_micro_batch_size: int, expected: int ): assert ( - NumberConversion.get_local_num_batches_from_num_tokens( + 
TrainingNumberConversion.get_local_num_batches_from_num_tokens( num_ranks, global_num_tokens, sequence_length, local_micro_batch_size ) == expected @@ -47,7 +49,7 @@ def test_get_num_steps_from_num_samples( expected: int, ): assert ( - NumberConversion.get_num_steps_from_num_samples( + TrainingNumberConversion.get_num_steps_from_num_samples( num_ranks, local_micro_batch_size, global_num_samples, gradient_accumulation_steps ) == expected @@ -76,7 +78,7 @@ def test_get_num_steps_from_num_tokens( expected: int, ): assert ( - NumberConversion.get_num_steps_from_num_tokens( + TrainingNumberConversion.get_num_steps_from_num_tokens( num_ranks, local_micro_batch_size, global_num_tokens, sequence_length, gradient_accumulation_steps ) == expected @@ -101,7 +103,7 @@ def test_get_num_tokens_from_num_steps( expected: int, ): assert ( - NumberConversion.get_num_tokens_from_num_steps( + TrainingNumberConversion.get_num_tokens_from_num_steps( num_steps=num_steps, num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, @@ -141,9 +143,9 @@ def test_get_last_step_from_checkpoint_path(checkpoint_path: Path, expected: int if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) else: - assert NumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected + assert TrainingNumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected @pytest.mark.parametrize( @@ -175,9 +177,12 @@ def test_get_num_seen_steps_from_checkpoint_path(checkpoint_path: Path, expected if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) else: - assert NumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected + assert ( + TrainingNumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) + == expected + ) @pytest.mark.parametrize( @@ -211,10 +216,10 @@ def test_get_global_num_seen_tokens_from_checkpoint_path( if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) else: assert ( - NumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected ) @@ -250,10 +255,10 @@ def test_get_global_num_target_tokens_from_checkpoint_path( if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) else: assert ( - NumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) + 
TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected ) @@ -287,9 +292,12 @@ def test_get_num_target_steps_from_checkpoint_path(checkpoint_path: Path, expect if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) else: - assert NumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected + assert ( + TrainingNumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) + == expected + ) @pytest.mark.parametrize( @@ -336,7 +344,7 @@ def test_get_num_tokens_from_packed_mem_map_dataset_continuous( ) assert ( - NumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous( + TrainingNumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous( dataset_path=dataset_path, sequence_length=sequence_length, num_ranks=num_ranks, @@ -369,7 +377,7 @@ def test_num_steps_from_raw_dataset_index( with open(raw_index_path, "rb") as f: index_length = len(pickle.load(f)) - num_steps_from_number_conversion = NumberConversion.get_num_steps_from_raw_dataset_index( + num_steps_from_number_conversion = TrainingNumberConversion.get_num_steps_from_raw_dataset_index( raw_index_path=raw_index_path, num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, From 438858f494b8e81a5599d48192079d6d4d77150b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Mon, 13 Jan 2025 17:32:57 +0100 Subject: [PATCH 21/25] chore: minor renaming fix --- tests/dataloader/yaml_configs/skipped_dataloader.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dataloader/yaml_configs/skipped_dataloader.yaml b/tests/dataloader/yaml_configs/skipped_dataloader.yaml index ddd81bbe1..b7f5910e8 100644 --- a/tests/dataloader/yaml_configs/skipped_dataloader.yaml +++ b/tests/dataloader/yaml_configs/skipped_dataloader.yaml @@ -29,7 +29,7 @@ train_dataset: sample_key: ${settings.referencing_keys.sample_key} skip_num_samples: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_samples_from_num_tokens config: num_tokens: ${settings.training.global_num_seen_tokens} From b5273b28d3e5a1db3e70f90d70834ce944d06328 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Mon, 13 Jan 2025 19:38:33 +0100 Subject: [PATCH 22/25] chore: fixed imports --- config_files/training/config_example_coca.yaml | 4 ++-- config_files/training/config_lorem_ipsum.yaml | 4 ++-- src/modalities/api.py | 9 ++++++--- ...create_global_index.py => global_indexation.py} | 0 .../{create_index.py => local_indexation.py} | 14 +++++++++----- tests/conftest.py | 2 +- .../yaml_configs/skipped_dataloader.yaml | 2 +- tests/end2end_tests/gpt2_train_num_steps_8.yaml | 4 ++-- 8 files changed, 23 insertions(+), 16 deletions(-) rename src/modalities/dataloader/preprocessing/indexation/{create_global_index.py => global_indexation.py} (100%) rename src/modalities/dataloader/preprocessing/indexation/{create_index.py => local_indexation.py} (93%) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index be9060ee6..1f3aadb00 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 
256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_num_steps config: num_steps: ${settings.training_target.num_target_steps} @@ -36,7 +36,7 @@ settings: sequence_length: ${settings.step_profile.sequence_length} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_samples config: num_ranks: ${settings.cuda_env.world_size} diff --git a/config_files/training/config_lorem_ipsum.yaml b/config_files/training/config_lorem_ipsum.yaml index 670610da6..231f59dc6 100644 --- a/config_files/training/config_lorem_ipsum.yaml +++ b/config_files/training/config_lorem_ipsum.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -36,7 +36,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/src/modalities/api.py b/src/modalities/api.py index bb2ec101a..ab75a9f48 100644 --- a/src/modalities/api.py +++ b/src/modalities/api.py @@ -7,11 +7,12 @@ from pydantic import FilePath +import modalities.dataloader.preprocessing.indexation.global_indexation as global_indexation import modalities.inference.inference as inference from modalities.checkpointing.checkpoint_conversion import CheckpointConversion from modalities.config.component_factory import ComponentFactory from modalities.config.instantiation_models import TokenizationInstantiationModel -from modalities.dataloader.preprocessing.indexation.create_index import IndexGenerator +from modalities.dataloader.preprocessing.indexation.local_indexation import IndexGenerator from modalities.dataloader.preprocessing.queued_processing.process_controller import PipelineStep, ProcessController from modalities.dataloader.preprocessing.queued_processing.processors import Processor from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import ( @@ -72,12 +73,14 @@ def create_local_index( def create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path) -> Path: - global_index_file_path = create_global_index(file_list_path, root_index_path, global_index_root_path) + global_index_file_path = global_indexation.create_global_index( + file_list_path, root_index_path, global_index_root_path + ) return global_index_file_path def create_shuffled_global_index(global_index_file_path: Path) -> Path: - global_shuffled_index_file_path = create_shuffled_global_index(global_index_file_path) + global_shuffled_index_file_path = global_indexation.create_shuffled_global_index(global_index_file_path) return global_shuffled_index_file_path diff --git a/src/modalities/dataloader/preprocessing/indexation/create_global_index.py b/src/modalities/dataloader/preprocessing/indexation/global_indexation.py similarity index 100% rename from 
src/modalities/dataloader/preprocessing/indexation/create_global_index.py rename to src/modalities/dataloader/preprocessing/indexation/global_indexation.py diff --git a/src/modalities/dataloader/preprocessing/indexation/create_index.py b/src/modalities/dataloader/preprocessing/indexation/local_indexation.py similarity index 93% rename from src/modalities/dataloader/preprocessing/indexation/create_index.py rename to src/modalities/dataloader/preprocessing/indexation/local_indexation.py index 17266b91c..d4b5469ca 100644 --- a/src/modalities/dataloader/preprocessing/indexation/create_index.py +++ b/src/modalities/dataloader/preprocessing/indexation/local_indexation.py @@ -1,14 +1,14 @@ -import json import os import pickle as pkl import queue import threading import time from pathlib import Path + import jq +from tqdm import tqdm from modalities.utils.logging import get_logger -from tqdm import tqdm class IndexGenerator: @@ -104,12 +104,16 @@ def parse_line_as_json(line_id: int, line_start_byte_pos: int, line: bytes, jq_f if len(jq_retrieved_text) > 0: self._index_map.append((line_start_byte_pos, len(line))) else: - get_logger(name="main").warning(f'Faulty line {line_id} (no text) in {str(self.src_file)}, skipping...') + get_logger(name="main").warning( + f"Faulty line {line_id} (no text) in {str(self.src_file)}, skipping..." + ) else: if self.drop_faulty_entries: - get_logger(name="main").warning(f'Faulty line {line_id} (parsing error) in {str(self.src_file)}, skipping...') + get_logger(name="main").warning( + f"Faulty line {line_id} (parsing error) in {str(self.src_file)}, skipping..." + ) else: - get_logger(name="main").warning(f'Faulty line {line_id} (parsing error), stopping...') + get_logger(name="main").warning(f"Faulty line {line_id} (parsing error), stopping...") err = ValueError(f'Faulty line "{line} in {str(self.src_file)}') self._exception_buffer.append(err) diff --git a/tests/conftest.py b/tests/conftest.py index 7214dac84..25b5d3bce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,8 @@ from modalities.checkpointing.checkpoint_saving import CheckpointSaving from modalities.config.config import load_app_config_dict -from modalities.dataloader.preprocessing.indexation.create_index import IndexGenerator from modalities.dataloader.dataloader import LLMDataLoader +from modalities.dataloader.preprocessing.indexation.local_indexation import IndexGenerator from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader from modalities.evaluator import Evaluator from modalities.logging_broker.publisher import MessagePublisher diff --git a/tests/dataloader/yaml_configs/skipped_dataloader.yaml b/tests/dataloader/yaml_configs/skipped_dataloader.yaml index ddd81bbe1..30966e490 100644 --- a/tests/dataloader/yaml_configs/skipped_dataloader.yaml +++ b/tests/dataloader/yaml_configs/skipped_dataloader.yaml @@ -29,7 +29,7 @@ train_dataset: sample_key: ${settings.referencing_keys.sample_key} skip_num_samples: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_samples_from_num_tokens config: num_tokens: ${settings.training.global_num_seen_tokens} diff --git a/tests/end2end_tests/gpt2_train_num_steps_8.yaml b/tests/end2end_tests/gpt2_train_num_steps_8.yaml index 4954e6a92..cef1a88a0 100644 --- a/tests/end2end_tests/gpt2_train_num_steps_8.yaml +++ b/tests/end2end_tests/gpt2_train_num_steps_8.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - 
component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -36,7 +36,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} From ea694a179954350b9a5b93e36d76949e6a413e72 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Mon, 13 Jan 2025 19:39:20 +0100 Subject: [PATCH 23/25] chore: adapted configs after renamings --- tests/end2end_tests/gpt2_warm_start_from_step_4.yaml | 10 +++++----- tests/test_yaml_configs/config_lorem_ipsum.yaml | 4 ++-- tutorials/getting_started/example_config.yaml | 4 ++-- tutorials/library_usage/config_lorem_ipsum.yaml | 4 ++-- .../configs/pretraining_config.yaml | 4 ++-- tutorials/warmstart/configs/pre_training_config.yaml | 4 ++-- tutorials/warmstart/configs/warmstart_config.yaml | 12 ++++++------ 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml b/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml index 1a7c9da6b..5bf7b0669 100644 --- a/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml +++ b/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml @@ -27,28 +27,28 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: global_num_target_tokens_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_target_steps_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} training_progress: global_num_seen_tokens: # used below - component_key: number_conversion + component_key: training_number_conversion variant_key: global_num_seen_tokens_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_seen_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_seen_steps_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_seen_samples: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_samples_from_num_tokens config: num_tokens: ${settings.training_progress.global_num_seen_tokens} diff --git a/tests/test_yaml_configs/config_lorem_ipsum.yaml b/tests/test_yaml_configs/config_lorem_ipsum.yaml index e9552785b..f0582a89f 100644 --- a/tests/test_yaml_configs/config_lorem_ipsum.yaml +++ b/tests/test_yaml_configs/config_lorem_ipsum.yaml @@ -26,7 +26,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -35,7 +35,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: 
${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/tutorials/getting_started/example_config.yaml b/tutorials/getting_started/example_config.yaml index f1737f940..965302e36 100644 --- a/tutorials/getting_started/example_config.yaml +++ b/tutorials/getting_started/example_config.yaml @@ -28,7 +28,7 @@ settings: sequence_length: 512 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -37,7 +37,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/tutorials/library_usage/config_lorem_ipsum.yaml b/tutorials/library_usage/config_lorem_ipsum.yaml index 915e0ebd0..f8bdbb517 100644 --- a/tutorials/library_usage/config_lorem_ipsum.yaml +++ b/tutorials/library_usage/config_lorem_ipsum.yaml @@ -29,7 +29,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_num_steps config: num_steps: ${settings.training_target.num_target_steps} @@ -38,7 +38,7 @@ settings: sequence_length: ${settings.step_profile.sequence_length} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_raw_dataset_index config: raw_index_path: ${settings.paths.index_path} diff --git a/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml b/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml index 166d25fb5..bd91e965c 100644 --- a/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml +++ b/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -36,7 +36,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/tutorials/warmstart/configs/pre_training_config.yaml b/tutorials/warmstart/configs/pre_training_config.yaml index 30db4adf6..e3e83a64a 100644 --- a/tutorials/warmstart/configs/pre_training_config.yaml +++ b/tutorials/warmstart/configs/pre_training_config.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion 
variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -36,7 +36,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/tutorials/warmstart/configs/warmstart_config.yaml b/tutorials/warmstart/configs/warmstart_config.yaml index 1858d9a11..6fc68619a 100644 --- a/tutorials/warmstart/configs/warmstart_config.yaml +++ b/tutorials/warmstart/configs/warmstart_config.yaml @@ -27,28 +27,28 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: global_num_target_tokens_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_target_steps_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} training_progress: global_num_seen_tokens: # used below - component_key: number_conversion + component_key: training_number_conversion variant_key: global_num_seen_tokens_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_seen_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_seen_steps_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} local_num_seen_batches: # for the dataloader - component_key: number_conversion + component_key: training_number_conversion variant_key: local_num_batches_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} @@ -56,7 +56,7 @@ settings: sequence_length: ${settings.step_profile.sequence_length} local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} last_step: # for the scheduler - component_key: number_conversion + component_key: training_number_conversion variant_key: last_step_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} From 1a515baf561620c4af6634f00c655935da89aa00 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Mon, 13 Jan 2025 19:58:10 +0100 Subject: [PATCH 24/25] refactor: improved code quality in strategies.py --- .../preprocessing/tokenization/strategies.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/modalities/dataloader/preprocessing/tokenization/strategies.py b/src/modalities/dataloader/preprocessing/tokenization/strategies.py index 1d508498a..cb54d83ef 100644 --- a/src/modalities/dataloader/preprocessing/tokenization/strategies.py +++ b/src/modalities/dataloader/preprocessing/tokenization/strategies.py @@ -8,7 +8,6 @@ from typing import Type import jq -from data_quality_ablations.utils.logging import get_logger from pydantic import BaseModel from modalities.config.instantiation_models import TokenizationInstantiationModel @@ -26,6 +25,7 @@ from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry from 
modalities.tokenization.tokenizer_wrapper import TokenizerWrapper +from modalities.utils.logging import get_logger def get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: @@ -72,8 +72,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): def process(self) -> dict[str, ReadingJob | ProgressMessage]: try: sample_id = next(self._reading_iter) - except StopIteration: - raise ProcessingStrategyDoneException("PopulatingStrategy done.") + except StopIteration as e: + raise ProcessingStrategyDoneException("PopulatingStrategy done.") from e reading_job = ReadingJob(sample_id=sample_id, batch_size=self._batch_size) progress_message = ProgressMessage(WorkerTypes.POPULATOR, num_samples=self._batch_size) return {self._reader_q_key: reading_job, self._logging_message_q_key: progress_message} @@ -121,6 +121,9 @@ def __init__( self._jq_filter = jq.compile(jq_pattern) self._writer_q_key = writer_q_key self._logging_message_q_key = logging_message_q_key + self._tokenizer = None + self._token_size_in_bytes = None + self._encoded_eos_token_as_bytes = None def __enter__(self): registry = Registry(COMPONENTS) @@ -177,13 +180,20 @@ def __init__(self, dst_path: Path, index_start: int, logging_message_q_key: str) self._dst_path = dst_path self._index_start = index_start self._logging_message_q_key = logging_message_q_key + self._dst_fd = None + self._finalized = None + self._curr_offset = None + self._prev_line_id = None + self._batch_dict = None + self._index_list = None + self._has_seen_first_batch = None if not self._dst_path.parent.exists(): self._dst_path.parent.mkdir(parents=True, exist_ok=True) def __enter__(self): self._dst_fd = self._dst_path.open("wb") - self.finalized = False + self._finalized = False # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) # not possible to prepend header after determining size of data section self._dst_fd.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) @@ -212,10 +222,10 @@ def finalize(self): self._dst_fd.write(pickle.dumps(self._index_list)) self._dst_fd.close() self._update_data_length_in_pre_allocated_header(self._dst_path, self._index_list) - self.finalized = True + self._finalized = True def __exit__(self, exc_type, exc_val, exc_tb): - if not self.finalized: + if not self._finalized: self._dst_fd.close() # if the process was stopped due to a stop event or the index list is empty, we remove the file get_logger(name="main").warning( @@ -293,6 +303,7 @@ def __init__( self._worker_to_pid_to_num_samples: dict[WorkerTypes, dict[int, int]] = {} self._worker_type_to_processed_num_samples = {worker_type: 0 for worker_type in WorkerTypes} self._q_dict = q_dict + self._last_logged = None def __enter__(self): self._last_logged = time.time() @@ -354,7 +365,7 @@ def _log_and_reset(self, passed_time: int): for q_key, q in self._q_dict.items(): logging_message += f"\t{q_key}: {q.qsize()} batches (approx.)\n" - get_logger().info(logging_message) + get_logger().info("%s", logging_message) # reset values for worker_type in self._worker_to_pid_to_num_samples.keys(): @@ -400,6 +411,7 @@ def get_reader_strategy( else: raise ValueError(f"Reader type {reader_type} is not supported.") + @staticmethod def get_tokenizer_strategy( tokenizer_settings: TokenizationInstantiationModel.TokenizerWorkerSettings.TokenizerSettings, writer_q_key: str, @@ -414,6 +426,7 @@ def get_tokenizer_strategy( ) return tokenizing_strategy + @staticmethod def get_writing_strategy( ww_settings: 
TokenizationInstantiationModel.WriterWorkerSettings, logging_message_q_key: str, @@ -425,6 +438,7 @@ def get_writing_strategy( ) return writing_strategy + @staticmethod def get_progress_logging_strategy( logging_interval: int, total_num_samples: int, From 716a98c5ac81808c236ae1897bcf64b7e00c552b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Mon, 13 Jan 2025 20:35:17 +0100 Subject: [PATCH 25/25] refactor: integrated changes from PR #283 into tokenization strategy --- .../preprocessing/tokenization/strategies.py | 17 ++++++++++------- .../tokenization/tokenizer_wrapper.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/modalities/dataloader/preprocessing/tokenization/strategies.py b/src/modalities/dataloader/preprocessing/tokenization/strategies.py index cb54d83ef..86e01c239 100644 --- a/src/modalities/dataloader/preprocessing/tokenization/strategies.py +++ b/src/modalities/dataloader/preprocessing/tokenization/strategies.py @@ -123,7 +123,7 @@ def __init__( self._logging_message_q_key = logging_message_q_key self._tokenizer = None self._token_size_in_bytes = None - self._encoded_eos_token_as_bytes = None + self._encoded_eod_token_as_bytes = None def __enter__(self): registry = Registry(COMPONENTS) @@ -133,10 +133,10 @@ def __enter__(self): ) self._tokenizer: TokenizerWrapper = tokenizer_type(**self._tokenizer_instantiation_setings.config) - encoded_eod_token = self._tokenizer.get_token_id(self._eod_token) self._token_size_in_bytes = get_required_num_of_bytes_to_repr(self._tokenizer.vocab_size) - self._encoded_eos_token_as_bytes = TokenizingStrategy._encoded_token_to_bytes( - token_size_in_bytes=self._token_size_in_bytes, encoded_token=encoded_eod_token + eod_token_id = self._tokenizer.get_token_id(self._eod_token) + self._encoded_eod_token_as_bytes = TokenizingStrategy._encoded_token_to_bytes( + token_size_in_bytes=self._token_size_in_bytes, encoded_token=eod_token_id ) return self @@ -164,10 +164,13 @@ def _process_line(self, line: str) -> bytes: tokens = self._tokenizer.tokenize(jq_retrieved_text) if len(tokens) == 0: raise EmptySampleError("Received empty sample...") - return ( - b"".join(map(self._encoded_token_to_bytes, [self._token_size_in_bytes] * len(tokens), tokens)) - + self._encoded_eos_token_as_bytes + + token_byte_string = b"".join( + map(self._encoded_token_to_bytes, [self._token_size_in_bytes] * len(tokens), tokens) ) + if not token_byte_string.endswith(self._encoded_eod_token_as_bytes): + token_byte_string = token_byte_string + self._encoded_eod_token_as_bytes + return token_byte_string @staticmethod def _encoded_token_to_bytes(token_size_in_bytes: int, encoded_token: int) -> bytes: diff --git a/src/modalities/tokenization/tokenizer_wrapper.py b/src/modalities/tokenization/tokenizer_wrapper.py index 211f5801f..f02643c17 100644 --- a/src/modalities/tokenization/tokenizer_wrapper.py +++ b/src/modalities/tokenization/tokenizer_wrapper.py @@ -260,7 +260,7 @@ def get_token_id(self, token: str) -> int: if not isinstance(piece_id, int): raise ValueError("Token cannot be represented by a single token ID!") if piece_id == self.tokenizer.unk_id(): - raise ValueError("Token cannot be represented by a single token id!") + raise ValueError("Token cannot be represented by a single token id!") return piece_id def is_special_token_id(self, token_id: int) -> bool:
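
For reference, the eod handling added in the last patch reduces to two steps: tokens are serialized as fixed-width little-endian byte strings, and the encoded eod token is appended only when the sample does not already end with it. The sketch below reproduces that logic in isolation; the helper names, the 50k vocabulary size, and the eod token id are illustrative stand-ins, not the actual TokenizingStrategy or TokenizerWrapper API.

def _num_bytes_for_vocab(vocab_size: int) -> int:
    # illustrative stand-in for get_required_num_of_bytes_to_repr:
    # smallest byte width that can hold any token id in [0, vocab_size)
    return max(1, (max(vocab_size - 1, 1).bit_length() + 7) // 8)


def encode_tokens_as_bytes(token_ids: list[int], eod_token_id: int, vocab_size: int) -> bytes:
    token_size_in_bytes = _num_bytes_for_vocab(vocab_size)

    def to_bytes(token_id: int) -> bytes:
        return token_id.to_bytes(token_size_in_bytes, byteorder="little", signed=False)

    token_byte_string = b"".join(to_bytes(token_id) for token_id in token_ids)
    encoded_eod = to_bytes(eod_token_id)
    # append the eod token only if the tokenizer did not already emit it as the last token
    if not token_byte_string.endswith(encoded_eod):
        token_byte_string += encoded_eod
    return token_byte_string


if __name__ == "__main__":
    # 50k vocabulary -> 2 bytes per token; eod token id 3 is arbitrary for the example
    with_eod = encode_tokens_as_bytes([10, 42, 3], eod_token_id=3, vocab_size=50_000)
    without_eod = encode_tokens_as_bytes([10, 42], eod_token_id=3, vocab_size=50_000)
    assert with_eod == without_eod == b"\x0a\x00\x2a\x00\x03\x00"

Because every token occupies exactly token_size_in_bytes bytes, the endswith check on the byte string is equivalent to comparing the last token id against the eod id, so no decoding is needed before deciding whether to append.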