Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8f76d87
:art: Refactor `ModelABC` to Help Use Default Torch Models
shaneahmed Sep 20, 2024
76c7972
:white_check_mark: Fix test
shaneahmed Sep 20, 2024
cc6c1c5
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Nov 8, 2024
0ade741
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Nov 20, 2024
85c72bf
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Nov 21, 2024
e25a122
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Nov 22, 2024
405ad61
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Jan 24, 2025
1b80e9a
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Feb 7, 2025
219c17e
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Mar 7, 2025
274c16b
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Mar 14, 2025
e06d92c
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Mar 21, 2025
1a21b78
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Apr 4, 2025
453b78f
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Apr 11, 2025
52d7a4a
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed May 23, 2025
68ad398
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Jun 13, 2025
5e40533
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Jun 20, 2025
bd766cd
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Jul 4, 2025
9a3d11d
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Jul 11, 2025
a89e529
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Aug 8, 2025
82d6a7c
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Aug 15, 2025
239ca43
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Sep 5, 2025
a57281e
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Oct 3, 2025
16e8bc5
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Oct 10, 2025
a470afb
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-to…
shaneahmed Oct 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/models/test_arch_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch

from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel, infer_batch
from tiatoolbox.models.models_abc import model_to

ON_GPU = False
Expand Down Expand Up @@ -45,7 +45,7 @@ def test_functional() -> None:
for backbone in backbones:
model = CNNModel(backbone, num_classes=1)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=device)
infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand All @@ -72,7 +72,7 @@ def test_timm_functional() -> None:
for backbone in backbones:
model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=device)
infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand Down
34 changes: 34 additions & 0 deletions tiatoolbox/models/architecture/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,40 @@ def _get_architecture(
return model.features


def infer_batch(
model: nn.Module,
batch_data: torch.Tensor,
*,
device: str = "cpu",
) -> dict[str, np.ndarray]:
"""Run inference on an input batch.

Contains logic for forward operation as well as i/o aggregation.

Args:
model (nn.Module):
PyTorch defined model.
batch_data (torch.Tensor):
A batch of data generated by
`torch.utils.data.DataLoader`.
device (str):
Transfers model to the specified device. Default is "cpu".

"""
img_patches_device = batch_data.to(device).type(
torch.float32,
) # to NCHW
img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()

# Inference mode
model.eval()
# Do not compute the gradient (not training)
with torch.inference_mode():
output = model(img_patches_device)
# Output should be a single tensor or scalar
return {"probabilities": output.cpu().numpy()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the current develop branch, neither CNNModel, nor CNNBackbone returned dictionaries as output of their infer_batch() methods. Also, CNNModel currently returns an array, while CNNBackbone returns a list with the array. It might be fine, just wanted to highlight this.

CNNModel

return output.cpu().numpy()

CNNBackbone

return [output.cpu().numpy()]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. We are aware of this. Our preference is to use torch nn models but to generalise for multi modal output we may need dictionaries. This PR is to check if we can move to generic torch models or we will need a sub class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense.



def _get_timm_architecture(
arch_name: str,
*,
Expand Down
3 changes: 2 additions & 1 deletion tiatoolbox/models/engine/engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tiatoolbox import DuplicateFilter, logger, rcParam
from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.models.architecture.vanilla import infer_batch
from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset
from tiatoolbox.models.models_abc import load_torch_model
from tiatoolbox.utils.misc import (
Expand Down Expand Up @@ -560,7 +561,7 @@ def infer_patches(
zarr_group = zarr.open(save_path, mode="w")

for _, batch_data in enumerate(dataloader):
batch_output = self.model.infer_batch(
batch_output = infer_batch(
self.model,
batch_data["image"],
device=self.device,
Expand Down
24 changes: 0 additions & 24 deletions tiatoolbox/models/models_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,30 +101,6 @@ def forward(
"""Torch method, this contains logic for using layers defined in init."""
... # pragma: no cover

@staticmethod
@abstractmethod
def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dict:
"""Run inference on an input batch.

Contains logic for forward operation as well as I/O aggregation.

Args:
model (nn.Module):
PyTorch defined model.
batch_data (np.ndarray):
A batch of data generated by
`torch.utils.data.DataLoader`.
device (str):
Transfers model to the specified device. Default is "cpu".

Returns:
dict:
Returns a dictionary of predictions and other expected outputs
depending on the network architecture.

"""
... # pragma: no cover

@staticmethod
def preproc(image: np.ndarray) -> np.ndarray:
"""Define the pre-processing of this class of model."""
Expand Down
Loading