152 changes: 74 additions & 78 deletions src/datasets/packaged_modules/json/json.py
@@ -50,6 +50,7 @@ class JsonConfig(datasets.BuilderConfig):
block_size: Optional[int] = None # deprecated
chunksize: int = 10 << 20 # 10MB
newlines_in_values: Optional[bool] = None
columns: Optional[List[str]] = None

def __post_init__(self):
super().__post_init__()
@@ -109,84 +110,79 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
pa_table = table_cast(pa_table, self.config.features.arrow_schema)
return pa_table

def _generate_tables(self, files):
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
# If the file is one json object and if we need to look at the items in one specific field
def _generate_tables(self, files: List[str]) -> Generator:
for file_idx, file in enumerate(files):
Member:

why assuming that files is a list of strings? it isn't

and overall this PR is pretty hard to review since you're doing a lot of small and unnecessary changes, I'd suggest opening a new PR and try to do minimal changes instead

Contributor Author:

sure
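
For context on the reviewer's point: the builder receives `files` as an iterable of lists (one list per data-files origin), not a flat `List[str]`, which is why the original code flattens it with `itertools.chain.from_iterable` and why the existing test passes `[[fixture]]`. A minimal standalone sketch, with made-up file names:

import itertools

# `files` as the builder receives it: an iterable of lists of paths/file objects,
# not a flat list of strings. The paths below are illustrative only.
files = [["train-00000.jsonl", "train-00001.jsonl"], ["extra.jsonl"]]

for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
    print(file_idx, file)
# 0 train-00000.jsonl
# 1 train-00001.jsonl
# 2 extra.jsonl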

if self.config.field is not None:
with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
dataset = ujson_loads(f.read())
# We keep only the field we are interested in
dataset = dataset[self.config.field]
df = pandas_read_json(io.StringIO(ujson_dumps(dataset)))
if df.columns.tolist() == [0]:
df.columns = list(self.config.features) if self.config.features else ["text"]
pa_table = pa.Table.from_pandas(df, preserve_index=False)
yield file_idx, self._cast_table(pa_table)

# If the file has one json object per line
else:
with open(file, "rb") as f:
batch_idx = 0
# Use block_size equal to the chunk size divided by 32 to leverage multithreading
# Set a default minimum value of 16kB if the chunk size is really small
Comment on lines -130 to -131
Member:

revert this comment deletion and the 2 others

Contributor Author:

> revert this comment deletion and the 2 others

Wanted clarification on "the 2 others" to ensure no comment restorations were missed. Actually I have restored the two missing comments above - are they at the right place? :)

block_size = max(self.config.chunksize // 32, 16 << 10)
encoding_errors = (
self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
)
while True:
batch = f.read(self.config.chunksize)
if not batch:
break
# Finish current line
try:
batch += f.readline()
except (AttributeError, io.UnsupportedOperation):
batch += readline(f)
# PyArrow only accepts utf-8 encoded bytes
if self.config.encoding != "utf-8":
batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
try:
while True:
try:
pa_table = paj.read_json(
io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
)
break
except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
if (
isinstance(e, pa.ArrowInvalid)
and "straddling" not in str(e)
or block_size > len(batch)
):
raise
else:
# Increase the block size in case it was too small.
# The block size will be reset for the next file.
logger.debug(
f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
)
block_size *= 2
except pa.ArrowInvalid as e:
try:
with open(
file, encoding=self.config.encoding, errors=self.config.encoding_errors
) as f:
df = pandas_read_json(f)
except ValueError:
logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
raise e
if df.columns.tolist() == [0]:
df.columns = list(self.config.features) if self.config.features else ["text"]
try:
pa_table = pa.Table.from_pandas(df, preserve_index=False)
except pa.ArrowInvalid as e:
logger.error(
f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
)
raise ValueError(
f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
) from None
yield file_idx, self._cast_table(pa_table)
break
# Load JSON with field selection
try:
for batch_idx, json_obj in enumerate(
ijson.items(
open(
file,
encoding=self.config.encoding,
errors=self.config.encoding_errors,
),
self.config.field,
)
):
pa_table = pa.Table.from_pandas(pd.DataFrame(json_obj))

if self.config.columns is not None:
missing_cols = set(self.config.columns) - set(
pa_table.column_names
)
for col in missing_cols:
pa_table = pa_table.append_column(
col, pa.array([None] * pa_table.num_rows)
)
pa_table = pa_table.select(self.config.columns)

yield (file_idx, batch_idx), self._cast_table(pa_table)
batch_idx += 1

except Exception as e:
raise DatasetGenerationError(
f"Failed to parse JSON with field {self.config.field}: {e}"
) from e

else:
# Load JSON line by line
batch_idx = 0
while True:
try:
pa_table = paj.read_json(
file,
read_options=paj.ReadOptions(
use_threads=True,
block_size=1 << 20,
),
parse_options=paj.ParseOptions(explicit_schema=None),
)
break

except pa.ArrowInvalid:
# Pandas fallback only if Arrow fails
with open(
file,
encoding=self.config.encoding,
errors=self.config.encoding_errors,
) as f:
df = pandas_read_json(f)
pa_table = pa.Table.from_pandas(df)
break

except StopIteration:
# End of file
return

# Apply columns selection after table is ready
if self.config.columns is not None:
missing_cols = set(self.config.columns) - set(
pa_table.column_names
)
for col in missing_cols:
pa_table = pa_table.append_column(
col, pa.array([None] * pa_table.num_rows)
)
pa_table = pa_table.select(self.config.columns)

yield (file_idx, batch_idx), self._cast_table(pa_table)
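
For reference, a minimal standalone sketch (not part of the diff) of how ijson.items streams the content under a configured field, which is what the new field branch relies on; the sample JSON and field name are made up:

import io

import ijson
import pandas as pd

raw = b'{"metadata": {"version": 1}, "data": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]}'

# With the prefix "data", ijson.items yields the value under that key
# (here the whole list) without materializing the rest of the document
# as Python objects.
for obj in ijson.items(io.BytesIO(raw), "data"):
    df = pd.DataFrame(obj)
    print(df)  # two rows, columns a and b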
19 changes: 18 additions & 1 deletion tests/packaged_modules/test_json.py
@@ -1,8 +1,14 @@
# Standard library
import json
import tempfile
import textwrap

import pyarrow as pa
# Third-party
import pytest
import pyarrow as pa

# First-party (datasets)
from datasets import load_dataset
from datasets import Features, Value
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
@@ -265,3 +271,14 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, request):
generator = builder._generate_tables([[request.getfixturevalue(file_fixture)]])
pa_table = pa.concat_tables([table for _, table in generator])
assert pa_table.column_names == ["ID", "Language", "Topic"]

def test_load_dataset_json_with_columns_filtering():
sample = {"a": 1, "b": 2, "c": 3}

with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
f.write(json.dumps(sample) + "\n")
f.write(json.dumps(sample) + "\n")
path = f.name

dataset = load_dataset("json", data_files=path, columns=["a", "c"])
assert set(dataset["train"].column_names) == {"a", "c"}
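
The test above only selects columns that already exist in the data; a minimal pyarrow-only sketch (standalone, not part of the PR) of the null-fill-then-select behavior the builder applies when a requested column is missing:

import pyarrow as pa

pa_table = pa.table({"a": [1, 2], "b": [3, 4]})
requested = ["a", "c"]  # "c" is not present in the table

# Append each missing requested column as an all-null array, then keep only
# the requested columns, in the requested order.
missing_cols = set(requested) - set(pa_table.column_names)
for col in missing_cols:
    pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
pa_table = pa_table.select(requested)

print(pa_table.column_names)  # ['a', 'c']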