Add collate_fn parameter support to DataLoader #48
This fixes #38
Summary
- Add `collate_fn` parameter to `DataLoader` for custom batch collation functionality

Changes
- Add `collate_fn` parameter with proper type annotations
- Apply `collate_fn` to batched data from dataset indexing
- Use `collate_fn` or default to `_numpy_collate`
- Use `tf.data.Dataset.map()` for `collate_fn` application
- Add `Callable` import to typing imports
- Add `test_collate_fn` function in `tests.ipynb` with cross-backend validation
- `core.ipynb`

Behavior
The `collate_fn` parameter behaves consistently with PyTorch's DataLoader, allowing users to customize how individual samples are combined into batches. When `collate_fn=None`, each backend uses its default collation behavior.

Test plan
- `collate_fn` functionality with JAX backend
- `collate_fn` functionality with PyTorch backend
- `collate_fn` functionality with TensorFlow backend
- `collate_fn=None` uses default behavior
- Run `nbdev_test` to ensure all tests pass
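
To illustrate the behavior described above, here is a minimal sketch of what a custom collate function and the `collate_fn=None` fallback might look like. It is not this PR's implementation: `pad_collate`, `default_collate`, and `iterate_batches` are hypothetical names, the loop is a stand-in for the real `DataLoader`, and plain Python lists stand in for arrays so the sketch has no dependencies. It assumes PyTorch-style semantics, where the collate function receives a list of individual samples.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Sample = Tuple[List[float], int]          # (features, label); an assumed sample layout
Batch = Tuple[List[List[float]], List[int]]


def default_collate(samples: Sequence[Sample]) -> Batch:
    """Hypothetical default: group features and labels without modification."""
    features = [list(x) for x, _ in samples]
    labels = [y for _, y in samples]
    return features, labels


def pad_collate(samples: Sequence[Sample]) -> Batch:
    """Hypothetical custom collate: right-pad features with 0.0 to a common length."""
    max_len = max(len(x) for x, _ in samples)
    features = [list(x) + [0.0] * (max_len - len(x)) for x, _ in samples]
    labels = [y for _, y in samples]
    return features, labels


def iterate_batches(
    dataset: Sequence[Sample],
    batch_size: int,
    collate_fn: Optional[Callable[[Sequence[Sample]], Batch]] = None,
):
    """Stand-in loader loop: falls back to default_collate when collate_fn is None."""
    collate = collate_fn if collate_fn is not None else default_collate
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield collate([dataset[i] for i in range(start, stop)])


if __name__ == "__main__":
    data = [([1.0, 2.0], 0), ([3.0], 1), ([4.0, 5.0, 6.0], 0)]
    for features, labels in iterate_batches(data, batch_size=2, collate_fn=pad_collate):
        print(features, labels)
```

Passing `collate_fn=None` in this sketch leaves the variable-length features unpadded, which mirrors the idea that each backend keeps its own default when no custom function is supplied.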