From db75657d44764e17cc2b8b6cbc0b95f94ddb1e7b Mon Sep 17 00:00:00 2001
From: Arjun Dinesh Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Fri, 27 Jun 2025 21:48:19 +0530
Subject: [PATCH 01/11] Add columns option to the JSON loader

---
 src/datasets/packaged_modules/json/json.py | 23 ++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index 426083fc718..23e8d5c4f49 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -50,6 +50,7 @@ class JsonConfig(datasets.BuilderConfig):
     block_size: Optional[int] = None  # deprecated
     chunksize: int = 10 << 20  # 10MB
     newlines_in_values: Optional[bool] = None
+    columns: Optional[List[str]] = None
 
     def __post_init__(self):
         super().__post_init__()
@@ -107,14 +108,18 @@ def _generate_tables(self, files):
                 if df.columns.tolist() == [0]:
                     df.columns = list(self.config.features) if self.config.features else ["text"]
                 pa_table = pa.Table.from_pandas(df, preserve_index=False)
+
+                # Filter only selected columns if specified
+                if self.config.columns is not None:
+                    keep_cols = [col for col in self.config.columns if col in pa_table.column_names]
+                    pa_table = pa_table.select(keep_cols)
+
                 yield file_idx, self._cast_table(pa_table)
 
             # If the file has one json object per line
             else:
                 with open(file, "rb") as f:
                     batch_idx = 0
-                    # Use block_size equal to the chunk size divided by 32 to leverage multithreading
-                    # Set a default minimum value of 16kB if the chunk size is really small
                     block_size = max(self.config.chunksize // 32, 16 << 10)
                     encoding_errors = (
                         self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
@@ -123,12 +128,10 @@ def _generate_tables(self, files):
                         batch = f.read(self.config.chunksize)
                         if not batch:
                             break
-                        # Finish current line
                        try:
                             batch += f.readline()
                         except (AttributeError, io.UnsupportedOperation):
                             batch += readline(f)
-                        # PyArrow only accepts utf-8 encoded bytes
                         if self.config.encoding != "utf-8":
                             batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
                         try:
@@ -137,6 +140,10 @@ def _generate_tables(self, files):
                                     pa_table = paj.read_json(
                                         io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
                                     )
+                                    if self.config.columns is not None:
+                                        keep_cols = [col for col in self.config.columns if col in pa_table.column_names]
+                                        pa_table = pa_table.select(keep_cols)
+                                    yield (file_idx, batch_idx), self._cast_table(pa_table)
                                     break
                                 except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
                                     if (
@@ -146,8 +153,6 @@ def _generate_tables(self, files):
                                     ):
                                         raise
                                     else:
-                                        # Increase the block size in case it was too small.
-                                        # The block size will be reset for the next file.
                                         logger.debug(
                                             f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
                                         )
@@ -165,6 +170,10 @@ def _generate_tables(self, files):
                                 df.columns = list(self.config.features) if self.config.features else ["text"]
                             try:
                                 pa_table = pa.Table.from_pandas(df, preserve_index=False)
+                                if self.config.columns is not None:
+                                    keep_cols = [col for col in self.config.columns if col in pa_table.column_names]
+                                    pa_table = pa_table.select(keep_cols)
+                                yield (file_idx, batch_idx), self._cast_table(pa_table)
                             except pa.ArrowInvalid as e:
                                 logger.error(
                                     f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
@@ -172,7 +181,5 @@ def _generate_tables(self, files):
                                 raise ValueError(
                                     f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
                                 ) from None
-                            yield file_idx, self._cast_table(pa_table)
                             break
-                        yield (file_idx, batch_idx), self._cast_table(pa_table)
                         batch_idx += 1

From c7872cb3d28f7f56506672c8eb8381585aa66a37 Mon Sep 17 00:00:00 2001
From: Arjun Dinesh Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Fri, 27 Jun 2025 22:04:19 +0530
Subject: [PATCH 02/11] Update load.py

---
 src/datasets/load.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/datasets/load.py b/src/datasets/load.py
index bc2b0e679b6..cd4e4cb2637 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -1200,6 +1200,7 @@ def load_dataset(
     streaming: bool = False,
     num_proc: Optional[int] = None,
     storage_options: Optional[dict] = None,
+    columns: Optional[List[str]] = None,
     **config_kwargs,
 ) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]:
     """Load a dataset from the Hugging Face Hub, or a local dataset.
@@ -1388,6 +1389,8 @@ def load_dataset(
         (verification_mode or VerificationMode.BASIC_CHECKS) if not save_infos else VerificationMode.ALL_CHECKS
     )
 
+    if path == "json" and columns is not None:
+        config_kwargs["columns"] = columns
     # Create a dataset builder
     builder_instance = load_dataset_builder(
         path=path,

From a0fedf50360495578b8cd7c63cb363a0b12af379 Mon Sep 17 00:00:00 2001
From: Arjun Dinesh Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Fri, 27 Jun 2025 22:08:30 +0530
Subject: [PATCH 03/11] Update test_json.py

---
 tests/packaged_modules/test_json.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py
index 18f066b5e68..47efe35fa63 100644
--- a/tests/packaged_modules/test_json.py
+++ b/tests/packaged_modules/test_json.py
@@ -1,8 +1,14 @@
+# Standard library
+import json
+import tempfile
 import textwrap
 
-import pyarrow as pa
+# Third-party
 import pytest
+import pyarrow as pa
 
+# First-party (datasets)
+from datasets import load_dataset
 from datasets import Features, Value
 from datasets.builder import InvalidConfigName
 from datasets.data_files import DataFilesList
@@ -265,3 +271,14 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, r
     generator = builder._generate_tables([[request.getfixturevalue(file_fixture)]])
     pa_table = pa.concat_tables([table for _, table in generator])
     assert pa_table.column_names == ["ID", "Language", "Topic"]
+
+def test_load_dataset_json_with_columns_filtering():
+    sample = {"a": 1, "b": 2, "c": 3}
+
+    with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
+        f.write(json.dumps(sample) + "\n")
+        f.write(json.dumps(sample) + "\n")
+        path = f.name
+
+    dataset = load_dataset("json", data_files=path, columns=["a", "c"])
+    assert set(dataset["train"].column_names) == {"a", "c"}

From d23a48be0ded7f3d2c3a79cc51672c337b08ced8 Mon Sep 17 00:00:00 2001
From: Arjun Dinesh Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Thu, 3 Jul 2025 15:20:26 +0530
Subject: [PATCH 04/11] Update json.py

---
 src/datasets/packaged_modules/json/json.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index 23e8d5c4f49..623d10298d8 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -111,8 +111,10 @@ def _generate_tables(self, files):
 
                 # Filter only selected columns if specified
                 if self.config.columns is not None:
-                    keep_cols = [col for col in self.config.columns if col in pa_table.column_names]
-                    pa_table = pa_table.select(keep_cols)
+                    missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
+                    for col in missing_cols:
+                        pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
+                    pa_table = pa_table.select(self.config.columns)
 
                 yield file_idx, self._cast_table(pa_table)
 
@@ -141,8 +143,10 @@ def _generate_tables(self, files):
                                         io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
                                     )
                                     if self.config.columns is not None:
-                                        keep_cols = [col for col in self.config.columns if col in pa_table.column_names]
-                                        pa_table = pa_table.select(keep_cols)
+                                        missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
+                                        for col in missing_cols:
+                                            pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
+                                        pa_table = pa_table.select(self.config.columns)
                                     yield (file_idx, batch_idx), self._cast_table(pa_table)
                                     break
                                 except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
@@ -171,8 +175,10 @@ def _generate_tables(self, files):
                             try:
                                 pa_table = pa.Table.from_pandas(df, preserve_index=False)
                                 if self.config.columns is not None:
-                                    keep_cols = [col for col in self.config.columns if col in pa_table.column_names]
-                                    pa_table = pa_table.select(keep_cols)
+                                    missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
+                                    for col in missing_cols:
+                                        pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
+                                    pa_table = pa_table.select(self.config.columns)
                                 yield (file_idx, batch_idx), self._cast_table(pa_table)
                             except pa.ArrowInvalid as e:
                                 logger.error(

From 5d3cc12af5c7907c9557f99ddc487746020c42b3 Mon Sep 17 00:00:00 2001
From: Arjun Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Fri, 15 Aug 2025 00:18:36 +0530
Subject: [PATCH 05/11] Update src/datasets/load.py

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
---
 src/datasets/load.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/datasets/load.py b/src/datasets/load.py
index cd4e4cb2637..69149f612a7 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -1389,8 +1389,6 @@ def load_dataset(
         (verification_mode or VerificationMode.BASIC_CHECKS) if not save_infos else VerificationMode.ALL_CHECKS
     )
 
-    if path == "json" and columns is not None:
-        config_kwargs["columns"] = columns
     # Create a dataset builder
     builder_instance = load_dataset_builder(
         path=path,

From eec7df9330581db2a9cd3d68b1779165b9424242 Mon Sep 17 00:00:00 2001
From: Arjun Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Fri, 15 Aug 2025 00:18:46 +0530
Subject: [PATCH 06/11] Update src/datasets/load.py

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
---
 src/datasets/load.py | 1 -
 1 file changed, 1 deletion(-)
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 69149f612a7..bc2b0e679b6 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -1200,7 +1200,6 @@ def load_dataset(
     streaming: bool = False,
     num_proc: Optional[int] = None,
     storage_options: Optional[dict] = None,
-    columns: Optional[List[str]] = None,
     **config_kwargs,
 ) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]:
     """Load a dataset from the Hugging Face Hub, or a local dataset.

From 5e93f702ce6f47c371d2013e986a98830d9d77ed Mon Sep 17 00:00:00 2001
From: Arjun Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Fri, 15 Aug 2025 01:02:12 +0530
Subject: [PATCH 07/11] Update json.py

---
 src/datasets/packaged_modules/json/json.py | 103 +++++++++++----------
 1 file changed, 52 insertions(+), 51 deletions(-)

diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index 623d10298d8..14bea59c582 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -1,7 +1,7 @@
 import io
 import itertools
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, List
 
 import pandas as pd
 import pyarrow as pa
@@ -122,6 +122,8 @@ def _generate_tables(self, files):
             else:
                 with open(file, "rb") as f:
                     batch_idx = 0
+                    # Use block_size equal to the chunk size divided by 32 to leverage multithreading
+                    # Set a default minimum value of 16kB if the chunk size is really small
                     block_size = max(self.config.chunksize // 32, 16 << 10)
                     encoding_errors = (
                         self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
@@ -136,56 +138,55 @@ def _generate_tables(self, files):
                             batch += readline(f)
                         if self.config.encoding != "utf-8":
                             batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
-                        try:
-                            while True:
-                                try:
-                                    pa_table = paj.read_json(
-                                        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
-                                    )
-                                    if self.config.columns is not None:
-                                        missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
-                                        for col in missing_cols:
-                                            pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
-                                        pa_table = pa_table.select(self.config.columns)
-                                    yield (file_idx, batch_idx), self._cast_table(pa_table)
-                                    break
-                                except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
-                                    if (
-                                        isinstance(e, pa.ArrowInvalid)
-                                        and "straddling" not in str(e)
-                                        or block_size > len(batch)
-                                    ):
-                                        raise
-                                    else:
-                                        logger.debug(
-                                            f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
-                                        )
-                                        block_size *= 2
-                        except pa.ArrowInvalid as e:
-                            try:
-                                with open(
-                                    file, encoding=self.config.encoding, errors=self.config.encoding_errors
-                                ) as f:
-                                    df = pandas_read_json(f)
-                            except ValueError:
-                                logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
-                                raise e
-                            if df.columns.tolist() == [0]:
-                                df.columns = list(self.config.features) if self.config.features else ["text"]
+
+                        while True:
                             try:
-                                pa_table = pa.Table.from_pandas(df, preserve_index=False)
-                                if self.config.columns is not None:
-                                    missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
-                                    for col in missing_cols:
-                                        pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
-                                    pa_table = pa_table.select(self.config.columns)
-                                yield (file_idx, batch_idx), self._cast_table(pa_table)
-                            except pa.ArrowInvalid as e:
-                                logger.error(
-                                    f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
+                                pa_table = paj.read_json(
+                                    io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
                                 )
-                                raise ValueError(
-                                    f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
-                                ) from None
-                            break
+                                break
+                            except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
+                                if (
+                                    isinstance(e, pa.ArrowInvalid)
+                                    and "straddling" not in str(e)
+                                ) or block_size > len(batch):
+                                    raise
+                                logger.debug(
+                                    f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. "
+                                    f"Retrying with block_size={block_size * 2}."
+                                )
+                                block_size *= 2
+
+                        if self.config.columns is not None:
+                            missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
+                            for col in missing_cols:
+                                pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
+                            pa_table = pa_table.select(self.config.columns)
+
+                        yield (file_idx, batch_idx), self._cast_table(pa_table)
                         batch_idx += 1
+
+                # Pandas fallback in case of ArrowInvalid
+                try:
+                    with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
+                        df = pandas_read_json(f)
+                except ValueError as e:
+                    logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
+                    raise e
+                if df.columns.tolist() == [0]:
+                    df.columns = list(self.config.features) if self.config.features else ["text"]
+                try:
+                    pa_table = pa.Table.from_pandas(df, preserve_index=False)
+                    if self.config.columns is not None:
+                        missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
+                        for col in missing_cols:
+                            pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
+                        pa_table = pa_table.select(self.config.columns)
+                    yield (file_idx, batch_idx), self._cast_table(pa_table)
+                except pa.ArrowInvalid as e:
+                    logger.error(
+                        f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
+                    )
+                    raise ValueError(
+                        f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
+                    ) from None

From 608ed213d04a027bfec751d518d3fbbbb65cc94d Mon Sep 17 00:00:00 2001
From: Arjun Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Fri, 15 Aug 2025 01:04:37 +0530
Subject: [PATCH 08/11] Update json.py

---
 src/datasets/packaged_modules/json/json.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index 14bea59c582..cf78718b0a2 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -1,7 +1,7 @@
 import io
 import itertools
 from dataclasses import dataclass
-from typing import Optional, List
+from typing import Optional
 
 import pandas as pd
 import pyarrow as pa

From 428444dbeb51a0292b12035883390b5cd924d9fe Mon Sep 17 00:00:00 2001
From: Arjun Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Tue, 26 Aug 2025 23:37:10 +0530
Subject: [PATCH 09/11] Update json.py

---
 src/datasets/packaged_modules/json/json.py | 152 +++++++++------------
 1 file changed, 67 insertions(+), 85 deletions(-)

diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index a5a0c6d901d..f9b96fd74cd 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -110,97 +110,79 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
             pa_table = table_cast(pa_table, self.config.features.arrow_schema)
         return pa_table
 
-    def _generate_tables(self, files):
-        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
-            # If the file is one json object and if we need to look at the items in one specific field
+    def _generate_tables(self, files: List[str]) -> Generator:
+        for file_idx, file in enumerate(files):
             if self.config.field is not None:
-                with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
-                    dataset = ujson_loads(f.read())
-                # We keep only the field we are interested in
-                dataset = dataset[self.config.field]
-                df = pandas_read_json(io.StringIO(ujson_dumps(dataset)))
-                if df.columns.tolist() == [0]:
-                    df.columns = list(self.config.features) if self.config.features else ["text"]
-                pa_table = pa.Table.from_pandas(df, preserve_index=False)
-
-                # Filter only selected columns if specified
-                if self.config.columns is not None:
-                    missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
-                    for col in missing_cols:
-                        pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
-                    pa_table = pa_table.select(self.config.columns)
-
-                yield file_idx, self._cast_table(pa_table)
-
-            # If the file has one json object per line
-            else:
-                with open(file, "rb") as f:
-                    batch_idx = 0
-                    # Use block_size equal to the chunk size divided by 32 to leverage multithreading
-                    # Set a default minimum value of 16kB if the chunk size is really small
-                    block_size = max(self.config.chunksize // 32, 16 << 10)
-                    encoding_errors = (
-                        self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
-                    )
-                    while True:
-                        batch = f.read(self.config.chunksize)
-                        if not batch:
-                            break
-                        try:
-                            batch += f.readline()
-                        except (AttributeError, io.UnsupportedOperation):
-                            batch += readline(f)
-                        if self.config.encoding != "utf-8":
-                            batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
-
-                        while True:
-                            try:
-                                pa_table = paj.read_json(
-                                    io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
-                                )
-                                break
-                            except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
-                                if (
-                                    isinstance(e, pa.ArrowInvalid)
-                                    and "straddling" not in str(e)
-                                ) or block_size > len(batch):
-                                    raise
-                                logger.debug(
-                                    f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. "
-                                    f"Retrying with block_size={block_size * 2}."
-                                )
-                                block_size *= 2
+                # Load JSON with field selection
+                try:
+                    for batch_idx, json_obj in enumerate(
+                        ijson.items(
+                            open(
+                                file,
+                                encoding=self.config.encoding,
+                                errors=self.config.encoding_errors,
+                            ),
+                            self.config.field,
+                        )
+                    ):
+                        pa_table = pa.Table.from_pandas(pd.DataFrame(json_obj))
 
                         if self.config.columns is not None:
-                            missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
+                            missing_cols = set(self.config.columns) - set(
+                                pa_table.column_names
+                            )
                             for col in missing_cols:
-                                pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
+                                pa_table = pa_table.append_column(
+                                    col, pa.array([None] * pa_table.num_rows)
+                                )
                             pa_table = pa_table.select(self.config.columns)
                         yield (file_idx, batch_idx), self._cast_table(pa_table)
-                        batch_idx += 1
 
-                # Pandas fallback in case of ArrowInvalid
-                try:
-                    with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
-                        df = pandas_read_json(f)
-                except ValueError as e:
-                    logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
-                    raise e
-                if df.columns.tolist() == [0]:
-                    df.columns = list(self.config.features) if self.config.features else ["text"]
-                try:
-                    pa_table = pa.Table.from_pandas(df, preserve_index=False)
-                    if self.config.columns is not None:
-                        missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
-                        for col in missing_cols:
-                            pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
-                        pa_table = pa_table.select(self.config.columns)
-                    yield (file_idx, batch_idx), self._cast_table(pa_table)
-                except pa.ArrowInvalid as e:
-                    logger.error(
-                        f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
+                except Exception as e:
+                    raise DatasetGenerationError(
+                        f"Failed to parse JSON with field {self.config.field}: {e}"
+                    ) from e
+
+            else:
+                # Load JSON line by line
+                batch_idx = 0
+                while True:
+                    try:
+                        pa_table = paj.read_json(
+                            file,
+                            read_options=paj.ReadOptions(
+                                use_threads=True,
+                                block_size=1 << 20,
+                            ),
+                            parse_options=paj.ParseOptions(explicit_schema=None),
+                        )
+                        break
+
+                    except pa.ArrowInvalid:
+                        # Pandas fallback only if Arrow fails
+                        with open(
+                            file,
+                            encoding=self.config.encoding,
+                            errors=self.config.encoding_errors,
+                        ) as f:
+                            df = pandas_read_json(f)
+                        pa_table = pa.Table.from_pandas(df)
+                        break
+
+                    except StopIteration:
+                        # End of file
+                        return
+
+                # Apply columns selection after table is ready
+                if self.config.columns is not None:
+                    missing_cols = set(self.config.columns) - set(
+                        pa_table.column_names
                     )
-                    raise ValueError(
-                        f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
-                    ) from None
+                    for col in missing_cols:
+                        pa_table = pa_table.append_column(
+                            col, pa.array([None] * pa_table.num_rows)
+                        )
+                    pa_table = pa_table.select(self.config.columns)
+
+                yield (file_idx, batch_idx), self._cast_table(pa_table)

From ee86c9a4020e908c0b87facb23019977e749f68d Mon Sep 17 00:00:00 2001
From: Arjun Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Thu, 4 Sep 2025 22:51:54 +0530
Subject: [PATCH 10/11] Update json.py

---
 src/datasets/packaged_modules/json/json.py | 78 +++++++++------------
 1 file changed, 33 insertions(+), 45 deletions(-)

diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index f9b96fd74cd..c23068d2777 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -110,11 +110,12 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
             pa_table = table_cast(pa_table, self.config.features.arrow_schema)
         return pa_table
 
-    def _generate_tables(self, files: List[str]) -> Generator:
-        for file_idx, file in enumerate(files):
+    def _generate_tables(self, files):
+        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
             if self.config.field is not None:
                 # Load JSON with field selection
                 try:
+                    import ijson
                     for batch_idx, json_obj in enumerate(
                         ijson.items(
                             open(
@@ -126,63 +127,50 @@ def _generate_tables(self, files):
                         )
                     ):
                         pa_table = pa.Table.from_pandas(pd.DataFrame(json_obj))
-
+                        # Apply columns filtering if requested
                         if self.config.columns is not None:
-                            missing_cols = set(self.config.columns) - set(
-                                pa_table.column_names
-                            )
+                            missing_cols = set(self.config.columns) - set(pa_table.column_names)
                             for col in missing_cols:
-                                pa_table = pa_table.append_column(
-                                    col, pa.array([None] * pa_table.num_rows)
-                                )
+                                pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
                             pa_table = pa_table.select(self.config.columns)
                         yield (file_idx, batch_idx), self._cast_table(pa_table)
                 except Exception as e:
-                    raise DatasetGenerationError(
+                    raise datasets.DatasetGenerationError(
                         f"Failed to parse JSON with field {self.config.field}: {e}"
                     ) from e
 
             else:
                 # Load JSON line by line
                 batch_idx = 0
-                while True:
-                    try:
-                        pa_table = paj.read_json(
-                            file,
-                            read_options=paj.ReadOptions(
-                                use_threads=True,
-                                block_size=1 << 20,
-                            ),
-                            parse_options=paj.ParseOptions(explicit_schema=None),
-                        )
-                        break
-
-                    except pa.ArrowInvalid:
-                        # Pandas fallback only if Arrow fails
-                        with open(
-                            file,
-                            encoding=self.config.encoding,
-                            errors=self.config.encoding_errors,
-                        ) as f:
-                            df = pandas_read_json(f)
-                        pa_table = pa.Table.from_pandas(df)
-                        break
-
-                    except StopIteration:
-                        # End of file
-                        return
-
-                # Apply columns selection after table is ready
-                if self.config.columns is not None:
-                    missing_cols = set(self.config.columns) - set(
-                        pa_table.column_names
+                try:
+                    pa_table = paj.read_json(
+                        file,
+                        read_options=paj.ReadOptions(
+                            use_threads=True,
+                            block_size=1 << 20,  # 1MB default
+                        ),
+                        parse_options=paj.ParseOptions(explicit_schema=None),
                     )
+                except pa.ArrowInvalid:
+                    # Pandas fallback only if Arrow fails
+                    with open(
+                        file,
+                        encoding=self.config.encoding,
+                        errors=self.config.encoding_errors,
+                    ) as f:
+                        df = pandas_read_json(f)
+                    pa_table = pa.Table.from_pandas(df)
+                except StopIteration:
+                    # End of file
+                    return
+
+                # Apply columns filtering if requested
+                if self.config.columns is not None:
+                    missing_cols = set(self.config.columns) - set(pa_table.column_names)
                     for col in missing_cols:
-                        pa_table = pa_table.append_column(
-                            col, pa.array([None] * pa_table.num_rows)
-                        )
+                        pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
                     pa_table = pa_table.select(self.config.columns)
 
                 yield (file_idx, batch_idx), self._cast_table(pa_table)
+

From 760157c6f70f29ea368384feab659c41f9fd913b Mon Sep 17 00:00:00 2001
From: Arjun Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Thu, 4 Sep 2025 22:55:22 +0530
Subject: [PATCH 11/11] Update test_json.py

---
 tests/packaged_modules/test_json.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py
index 47efe35fa63..db7063c9169 100644
--- a/tests/packaged_modules/test_json.py
+++ b/tests/packaged_modules/test_json.py
@@ -282,3 +282,20 @@ def test_load_dataset_json_with_columns_filtering():
 
     dataset = load_dataset("json", data_files=path, columns=["a", "c"])
     assert set(dataset["train"].column_names) == {"a", "c"}
+
+def test_load_dataset_json_with_missing_columns():
+    sample = {"x": 1}
+    with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
+        f.write(json.dumps(sample) + "\n")
+        f.write(json.dumps(sample) + "\n")
+        path = f.name
+
+    dataset = load_dataset("json", data_files=path, columns=["x", "y"])
+    assert set(dataset["train"].column_names) == {"x", "y"}
+    # "y" should be filled with None
+    assert dataset["train"]["y"] == [None, None]
+
+def test_load_dataset_json_without_columns_filtering(jsonl_file):
+    dataset = load_dataset("json", data_files=jsonl_file)
+    # Original behavior: no columns filter → all keys are present
+    assert set(dataset["train"].column_names) == {"col_1", "col_2"}
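
Usage sketch (illustrative; not part of any patch above). After patches 05-06, `load_dataset` no longer has a dedicated `columns` parameter: the option reaches `JsonConfig` through `**config_kwargs`, which is why the tests in patches 03 and 11 can still pass `columns=...` directly. A minimal sketch assuming this branch of `datasets` is installed; the temporary file and its keys are made up for the example:

    # Illustrative only: exercises the `columns` filtering added by this series.
    # Assumes the patched `datasets` is installed; the data below is hypothetical.
    import json
    import tempfile

    from datasets import load_dataset

    with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
        f.write(json.dumps({"a": 1, "b": 2}) + "\n")
        f.write(json.dumps({"a": 3, "b": 4}) + "\n")
        path = f.name

    # "a" is kept, "b" is dropped; the absent "c" is appended as an all-null
    # column before selection, per the missing_cols handling in the patches.
    ds = load_dataset("json", data_files=path, columns=["a", "c"])
    print(ds["train"].column_names)  # ['a', 'c']
    print(ds["train"]["c"])          # [None, None]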