Description
Is your feature request related to a problem? Please describe.
I’m working with 3D volumetric data (medical OCT volumes) and would like to use timm as a unified model zoo across my projects. However, as far as I can tell from the documentation and model list, timm currently focuses on 2D image models, and I don’t see any officially supported 3D CNN backbones or APIs for volumetric data (e.g. tensors shaped [B, C, D, H, W] or [B, C, T, H, W]).
This leaves a gap for users who would like to use timm-style APIs and architectures (ResNet, ConvNeXt, EfficientNet, ViT, etc.) directly on 3D data (volumes or videos) instead of only 2D images. If I’m missing existing 3D support, I would really appreciate clarification on the current status and recommended usage patterns.
Describe the solution you'd like
There are two closely related things I’d like to ask for:
Clarification of current 3D support
Is there any existing or experimental 3D model support in timm (e.g. 3D variants of ResNet/ConvNeXt, or a recommended way to use timm backbones for 3D inputs)?
If yes, could this be documented more explicitly (which models, expected input shapes, example usage)?
Official 3D model support (if it does not exist yet)
Provide a small but representative set of 3D backbones with the same API style as timm, for example:
3D ResNet family (e.g. R3D-18 style),
3D ConvNeXt / EfficientNet style backbones,
Possibly a generic “dimension-agnostic” implementation where spatial_dims=2 or spatial_dims=3 could be chosen (see the sketch after this list).
Add helper functions or flags, e.g. create_model(model_name, spatial_dims=3, in_chans=1, num_classes=...) or clearly separate 2D vs 3D model names (e.g. "resnet50_3d", "convnext_tiny_3d").
Include at least one simple example in the docs or notebooks showing how to train a 3D classifier on [B, C, D, H, W] inputs.
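To make the “dimension-agnostic” idea concrete, here is a minimal plain-PyTorch sketch. Note that conv_nd is a hypothetical helper I made up for illustration, not an existing timm API, and the proposed create_model flag appears only as a comment:

```python
import torch
import torch.nn as nn

def conv_nd(spatial_dims: int, in_ch: int, out_ch: int, **kwargs) -> nn.Module:
    """Hypothetical factory: select a 2D or 3D convolution via spatial_dims."""
    conv_cls = {2: nn.Conv2d, 3: nn.Conv3d}[spatial_dims]
    return conv_cls(in_ch, out_ch, **kwargs)

stem2d = conv_nd(2, 1, 64, kernel_size=3, padding=1)  # expects [B, C, H, W]
stem3d = conv_nd(3, 1, 64, kernel_size=3, padding=1)  # expects [B, C, D, H, W]

x = torch.randn(2, 1, 16, 64, 64)  # [B, C, D, H, W]
print(stem3d(x).shape)             # torch.Size([2, 64, 16, 64, 64])

# Proposed (not yet existing) timm usage from this request:
# model = timm.create_model("resnet50", spatial_dims=3, in_chans=1, num_classes=2)
```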
This would allow users working on 3D medical imaging, video classification, and other volumetric tasks to rely on timm as a single, consistent model zoo.
Describe alternatives you've considered
Using 2D timm models on slices/frames:
Treat the depth/time dimension as extra slices, feed each slice/frame into a 2D timm backbone, and then aggregate the per-slice features (see the sketch after this list). This works, but it:
Loses native 3D spatial modeling,
Requires custom aggregation code and is less elegant than having proper 3D convolutional backbones,
Makes it harder to reuse pre-defined training scripts or configs.
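For concreteness, a minimal sketch of this slice-wise workaround using the existing 2D timm API. Mean pooling over depth is just one possible aggregation strategy, chosen here for brevity:

```python
import torch
import timm

# Toy volume: 2 single-channel volumes, 16 slices of 224x224 -> [B, C, D, H, W]
volume = torch.randn(2, 1, 16, 224, 224)

# num_classes=0 removes the classifier head so the model returns pooled features.
backbone = timm.create_model("resnet18", pretrained=False, in_chans=1, num_classes=0)

B, C, D, H, W = volume.shape
slices = volume.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W)  # fold depth into batch
feats = backbone(slices)                      # [B*D, F] per-slice features
feats = feats.reshape(B, D, -1).mean(dim=1)   # [B, F] after mean pooling over depth
```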
Using external libraries/frameworks for 3D models (e.g. medical imaging frameworks or community 3D forks built “on top of timm”):
These often have their own APIs and may not stay fully in sync with the latest timm architectures and weights.
This fragments the workflow: 2D tasks use timm while 3D tasks use a separate ecosystem, instead of one unified interface.
Because timm has become a de facto standard for image backbones in PyTorch, having first-class 3D support (or at least a clearly documented position on it) would be extremely valuable.
Additional context