Skip to content

Conversation

@BirkhoffG
Copy link
Owner

@BirkhoffG BirkhoffG commented May 26, 2025

This fixes #38

Summary

  • Add collate_fn parameter to DataLoader for custom batch collation functionality
  • Implement support across all three backends (JAX, PyTorch, TensorFlow)
  • Add comprehensive test coverage and documentation

Changes

  • Core DataLoader: Added collate_fn parameter with proper type annotations
  • JAX Backend: Applies collate_fn to batched data from dataset indexing
  • PyTorch Backend: Uses custom collate_fn or defaults to _numpy_collate
  • TensorFlow Backend: Uses tf.data.Dataset.map() for collate_fn application
  • Type System: Added Callable import to typing imports
  • Testing: Added test_collate_fn function in tests.ipynb with cross-backend validation
  • Documentation: Added usage example in core.ipynb

Behavior

The collate_fn parameter behaves consistently with PyTorch's DataLoader, allowing users to customize how individual samples are combined into batches. When collate_fn=None, each backend uses its default collation behavior.

Test plan

  • Test collate_fn functionality with JAX backend
  • Test collate_fn functionality with PyTorch backend
  • Test collate_fn functionality with TensorFlow backend
  • Test that collate_fn=None uses default behavior
  • Test custom transformation functions (e.g., adding constants to features)
  • Verify backward compatibility with existing code
  • Run nbdev_test to ensure all tests pass

- Add collate_fn parameter to core DataLoader class for custom batch collation
- Implement collate_fn support across all backends (JAX, PyTorch, TensorFlow)
- JAX backend applies collate_fn to batched data from dataset indexing
- PyTorch backend uses custom collate_fn or defaults to _numpy_collate
- TensorFlow backend uses tf.data.Dataset.map() for collate_fn application
- Add Callable import to typing imports for proper type annotations
- Add comprehensive test_collate_fn function in tests.ipynb
- Update BaseDataLoader interface to include collate_fn parameter
- Add documentation example in core.ipynb demonstrating usage

The collate_fn parameter behaves consistently with PyTorch's DataLoader,
allowing users to customize how individual samples are combined into batches.
@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

BirkhoffG added 3 commits May 26, 2025 16:29
- Fix test_collate_fn to properly handle PyTorch list format vs JAX batched format
- Add support for HuggingFace datasets with PyTorch backend (list of dicts)
- Ensure collate_fn works correctly across all backends (JAX, PyTorch, TensorFlow)
- All collate_fn tests now pass for available backends
- Add tf_collate_wrapper to handle TensorFlow's argument unpacking behavior
- TensorFlow map() calls functions with unpacked args (features, labels)
- Wrapper packs them back into tuple format expected by collate_fn
- Ensures result is properly unpacked for TensorFlow consumption
- Fixes TypeError in CI when TensorFlow is installed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support collate_fn for jdl.DataLoader

1 participant