Add columns support to JSON loader for selective key filtering #7652
@@ -50,6 +50,7 @@ class JsonConfig(datasets.BuilderConfig):
     block_size: Optional[int] = None  # deprecated
     chunksize: int = 10 << 20  # 10MB
     newlines_in_values: Optional[bool] = None
+    columns: Optional[List[str]] = None

     def __post_init__(self):
         super().__post_init__()
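The new `columns` field on `JsonConfig` is what enables the selective key filtering in the PR title. A minimal usage sketch, assuming the option is forwarded from `load_dataset` keyword arguments to `JsonConfig` like the other config fields (the file name and column names here are hypothetical):

```python
from datasets import load_dataset

# data.jsonl (hypothetical), one object per line:
#   {"id": 1, "text": "hello", "meta": {"source": "web"}}
# Keep only "id" and "text"; all other keys are dropped at read time.
ds = load_dataset("json", data_files="data.jsonl", columns=["id", "text"])
```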
@@ -109,84 +110,79 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
         pa_table = table_cast(pa_table, self.config.features.arrow_schema)
         return pa_table

-    def _generate_tables(self, files):
-        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
-            # If the file is one json object and if we need to look at the items in one specific field
+    def _generate_tables(self, files: List[str]) -> Generator:
+        for file_idx, file in enumerate(files):
             if self.config.field is not None:
-                with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
-                    dataset = ujson_loads(f.read())
-                # We keep only the field we are interested in
-                dataset = dataset[self.config.field]
-                df = pandas_read_json(io.StringIO(ujson_dumps(dataset)))
-                if df.columns.tolist() == [0]:
-                    df.columns = list(self.config.features) if self.config.features else ["text"]
-                pa_table = pa.Table.from_pandas(df, preserve_index=False)
-                yield file_idx, self._cast_table(pa_table)
-
-            # If the file has one json object per line
-            else:
-                with open(file, "rb") as f:
-                    batch_idx = 0
-                    # Use block_size equal to the chunk size divided by 32 to leverage multithreading
-                    # Set a default minimum value of 16kB if the chunk size is really small
Review thread on lines -130 to -131:

Reviewer: revert this comment deletion and the 2 others

Author (seeking clarification on "the 2 others" to ensure no restorations were missed): Actually, I have restored the two missing comments above - are they at the right place? :)
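For context, the two comments being restored describe the heuristic used by the deleted reader: a parse block of `chunksize // 32` (but at least 16 kB) lets PyArrow multithread across blocks, and the loop below doubles the block size whenever a JSON object straddles a block boundary. A standalone sketch of that retry pattern, distilled from the deleted code (the function name is ours, and the `ArrowNotImplementedError` handling is omitted for brevity):

```python
import io

import pyarrow as pa
import pyarrow.json as paj

def read_batch_with_retry(batch: bytes, block_size: int) -> pa.Table:
    # Mirrors the deleted loop: double block_size until the batch parses,
    # but re-raise errors unrelated to objects straddling block boundaries.
    while True:
        try:
            return paj.read_json(
                io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
            )
        except pa.ArrowInvalid as e:
            if "straddling" not in str(e) or block_size > len(batch):
                raise
            block_size *= 2  # reset for the next file in the real loader
```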
-                    block_size = max(self.config.chunksize // 32, 16 << 10)
-                    encoding_errors = (
-                        self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
-                    )
-                    while True:
-                        batch = f.read(self.config.chunksize)
-                        if not batch:
-                            break
-                        # Finish current line
-                        try:
-                            batch += f.readline()
-                        except (AttributeError, io.UnsupportedOperation):
-                            batch += readline(f)
-                        # PyArrow only accepts utf-8 encoded bytes
-                        if self.config.encoding != "utf-8":
-                            batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
-                        try:
-                            while True:
-                                try:
-                                    pa_table = paj.read_json(
-                                        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
-                                    )
-                                    break
-                                except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
-                                    if (
-                                        isinstance(e, pa.ArrowInvalid)
-                                        and "straddling" not in str(e)
-                                        or block_size > len(batch)
-                                    ):
-                                        raise
-                                    else:
-                                        # Increase the block size in case it was too small.
-                                        # The block size will be reset for the next file.
-                                        logger.debug(
-                                            f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
-                                        )
-                                        block_size *= 2
-                        except pa.ArrowInvalid as e:
-                            try:
-                                with open(
-                                    file, encoding=self.config.encoding, errors=self.config.encoding_errors
-                                ) as f:
-                                    df = pandas_read_json(f)
-                            except ValueError:
-                                logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
-                                raise e
-                            if df.columns.tolist() == [0]:
-                                df.columns = list(self.config.features) if self.config.features else ["text"]
-                            try:
-                                pa_table = pa.Table.from_pandas(df, preserve_index=False)
-                            except pa.ArrowInvalid as e:
-                                logger.error(
-                                    f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
-                                )
-                                raise ValueError(
-                                    f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
-                                ) from None
-                            yield file_idx, self._cast_table(pa_table)
-                            break
-                        yield (file_idx, batch_idx), self._cast_table(pa_table)
-                        batch_idx += 1
+                # Load JSON with field selection
+                try:
+                    for batch_idx, json_obj in enumerate(
+                        ijson.items(
+                            open(
+                                file,
+                                encoding=self.config.encoding,
+                                errors=self.config.encoding_errors,
+                            ),
+                            self.config.field,
+                        )
+                    ):
+                        pa_table = pa.Table.from_pandas(pd.DataFrame(json_obj))
+
+                        if self.config.columns is not None:
+                            missing_cols = set(self.config.columns) - set(
+                                pa_table.column_names
+                            )
+                            for col in missing_cols:
+                                pa_table = pa_table.append_column(
+                                    col, pa.array([None] * pa_table.num_rows)
+                                )
+                            pa_table = pa_table.select(self.config.columns)
+
+                        yield (file_idx, batch_idx), self._cast_table(pa_table)
+
+                except Exception as e:
+                    raise DatasetGenerationError(
+                        f"Failed to parse JSON with field {self.config.field}: {e}"
+                    ) from e
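One subtlety in the new `ijson` path above: `ijson.items` takes a prefix, and the prefix `field` yields the field's entire value as a single item (so each file produces one batch), while `field.item` would stream the elements of an array-valued field one by one. A small self-contained demonstration of the prefix semantics (hypothetical data):

```python
import io

import ijson

data = io.BytesIO(b'{"data": [{"id": 1}, {"id": 2}]}')

# Prefix "data" yields the whole list as one item.
print(list(ijson.items(data, "data")))       # [[{'id': 1}, {'id': 2}]]

data.seek(0)
# Prefix "data.item" streams the list elements individually.
print(list(ijson.items(data, "data.item")))  # [{'id': 1}, {'id': 2}]
```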
+
+            else:
+                # Load JSON line by line
+                batch_idx = 0
+                while True:
+                    try:
+                        pa_table = paj.read_json(
+                            file,
+                            read_options=paj.ReadOptions(
+                                use_threads=True,
+                                block_size=1 << 20,
+                            ),
+                            parse_options=paj.ParseOptions(explicit_schema=None),
+                        )
+                        break
+
+                    except pa.ArrowInvalid:
+                        # Pandas fallback only if Arrow fails
+                        with open(
+                            file,
+                            encoding=self.config.encoding,
+                            errors=self.config.encoding_errors,
+                        ) as f:
+                            df = pandas_read_json(f)
+                        pa_table = pa.Table.from_pandas(df)
+                        break
+
+                    except StopIteration:
+                        # End of file
+                        return
+
+                # Apply columns selection after table is ready
+                if self.config.columns is not None:
+                    missing_cols = set(self.config.columns) - set(
+                        pa_table.column_names
+                    )
+                    for col in missing_cols:
+                        pa_table = pa_table.append_column(
+                            col, pa.array([None] * pa_table.num_rows)
+                        )
+                    pa_table = pa_table.select(self.config.columns)
+
+                yield (file_idx, batch_idx), self._cast_table(pa_table)
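Both branches apply the same post-processing for `columns`: any requested column missing from the table is backfilled with nulls, then `select()` keeps only the requested columns in the requested order. A self-contained sketch of that behavior (column names are hypothetical):

```python
import pyarrow as pa

requested = ["id", "text", "score"]  # e.g. the columns config value
table = pa.table({"id": [1, 2], "text": ["a", "b"]})

# Backfill requested columns that are absent with all-null columns.
for col in set(requested) - set(table.column_names):
    table = table.append_column(col, pa.array([None] * table.num_rows))

# Keep only the requested columns, in the requested order.
table = table.select(requested)
print(table.column_names)  # ['id', 'text', 'score']
```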
Review comment:

Reviewer: why assuming that files is a list of strings? It isn't. And overall this PR is pretty hard to review since you're doing a lot of small and unnecessary changes; I'd suggest opening a new PR and trying to do minimal changes instead.

Author: sure