JSC

JSC is a Python-based project that extends the nnUNet framework with joint segmentation and classification capabilities. It combines nnUNet's segmentation architecture with binary and multi-class classification, making it well suited for medical imaging tasks that require both pixel-level segmentation and image-level classification predictions.

Features

  • Joint Architecture: Simultaneous segmentation and classification in a single model
  • Stratified Data Splitting: Advanced data splitting with demographic stratification (age, gender)
  • Multi-Modal Input: Support for multi-channel medical images (CT, PET, MRI)
  • Flexible Classification: Support for both binary and multi-class classification tasks
  • Comprehensive Inference: Batch processing with sliding window prediction and test-time augmentation

Installation

To set up JSC, first install the required dependencies:

pip install wandb
pip install torchmetrics
pip install -e .

Configure nnUNet Paths

Before using JSC, you need to configure the nnUNet environment paths. Modify the paths in nnunetv2/paths.py to point to your desired directories:

# Edit nnunetv2/paths.py
nnUNet_raw = "/path/to/your/nnUNet_raw"
nnUNet_preprocessed = "/path/to/your/nnUNet_preprocessed" 
nnUNet_results = "/path/to/your/nnUNet_results"
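If you prefer not to edit the source, nnUNet also reads these locations from environment variables of the same names (nnunetv2/paths.py falls back to them when set). A shell sketch with placeholder paths:

```shell
# Alternative to editing nnunetv2/paths.py: export the three nnUNet
# locations as environment variables (paths below are placeholders).
export nnUNet_raw="/path/to/your/nnUNet_raw"
export nnUNet_preprocessed="/path/to/your/nnUNet_preprocessed"
export nnUNet_results="/path/to/your/nnUNet_results"
```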

Usage

1. Prepare Data

Preprocessing with nnUNet

JSC relies on nnUNet’s preprocessing pipeline to standardize image spacing, intensity normalization, and patch extraction. Preprocessing must be completed before training or inference.

1. Required Data Structure

Your dataset must follow the nnUNet folder convention:

nnUNet_raw/
└── Dataset<ID>_<NAME>/
    ├── dataset.json                # Dataset configuration file
    ├── imagesTr/                   # Training images
    │   ├── PatientID_0000.nii.gz   # Modality 1
    │   ├── PatientID_0001.nii.gz   # Modality 2
    │   └── ...
    ├── labelsTr/                   # Training labels
    │   ├── PatientID.nii.gz        # Matches PatientID (no _0000)
    │   └── ...
    ├── imagesTs/                   # Test images (Inference)
    │   ├── TestID_0000.nii.gz
    │   └── ...
    └── labelsTs/                   # Test labels (Optional/Evaluation)
        ├── TestID.nii.gz           # Matches TestID (no _0000)
        └── ...

Notes:

  • Each modality is indexed as _0000, _0001, etc.
  • Segmentation labels must have the same base name as the training images (without modality suffix).
  • dataset.json defines modalities, labels, and dataset splits.
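The naming rules above can be sanity-checked before preprocessing. `check_nnunet_layout` is a hypothetical helper for illustration, not part of JSC or nnUNet:

```python
import os

def check_nnunet_layout(dataset_dir):
    """Sanity-check the nnUNet naming convention: every label file
    PatientID.nii.gz in labelsTr must have at least one matching image
    PatientID_0000.nii.gz in imagesTr. Returns the list of case IDs
    that are missing their first-modality image."""
    labels = os.listdir(os.path.join(dataset_dir, "labelsTr"))
    images = set(os.listdir(os.path.join(dataset_dir, "imagesTr")))
    problems = []
    for label in labels:
        case_id = label[: -len(".nii.gz")]  # strip the extension, no modality suffix
        if f"{case_id}_0000.nii.gz" not in images:
            problems.append(case_id)
    return problems
```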

2. Default Preprocessing

Run the standard nnUNet preprocessing:

nnUNetv2_plan_and_preprocess -d <DATASET_ID> -c 3d_fullres --verify_dataset_integrity

For the residual encoder (ResEnc) presets:

nnUNetv2_plan_experiment -d <DATASET_ID> -pl nnUNetPlannerResEncM  # or nnUNetPlannerResEncL / nnUNetPlannerResEncXL

3. Generate Classification Splits

Use generate_cls_data.py to create stratified train/validation/test splits from your clinical dataset:

python generate_cls_data.py \
    --input_path /path/to/clinical_data.csv \
    --output_path /path/to/output/folder \
    --identifier_column PatientID \
    --label_column diagnosis

Arguments:

  • --input_path, -i: Path to CSV/Excel file containing clinical and imaging information
  • --output_path, -o: Directory to save classification data and splits
  • --identifier_column, -id: Column name for patient identifiers (default: 'patient_id')
  • --label_column, -label: Column name for classification labels (default: 'label')

Required CSV columns:

  • Patient identifiers (e.g., 'PatientID')
  • Classification labels
  • Age_at_StudyDate: For age-based stratification
  • Gender: For gender-based stratification
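For reference, a minimal clinical table with these columns could look like the following. All IDs and values are made up, and the `PatientID`/`diagnosis` column names simply mirror the usage example above:

```python
import csv

# Hypothetical minimal clinical table with the columns JSC expects.
rows = [
    {"PatientID": "Case001", "diagnosis": 1, "Age_at_StudyDate": 63, "Gender": "F"},
    {"PatientID": "Case002", "diagnosis": 0, "Age_at_StudyDate": 71, "Gender": "M"},
]
fieldnames = ["PatientID", "diagnosis", "Age_at_StudyDate", "Gender"]
with open("clinical_data.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)
```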

Outputs:

  • cls_data.csv: Classification dataset
  • test_data.csv: Held-out test set (20% of data)
  • splits_final.json: 5-fold cross-validation splits with stratification
  • Automatic filtering of cases without segmentation data

Note: make sure cls_data.csv and splits_final.json are placed under nnUNet_preprocessed before training.
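The stratification behind splits_final.json can be sketched as follows. `stratified_folds` is an illustrative stand-in, not the script's actual code: it groups cases by a composite key and deals each group out round-robin so every fold sees a similar mix (coarse age buckets stand in here for the quartiles generate_cls_data.py uses):

```python
from collections import defaultdict

def stratified_folds(cases, n_folds=5):
    """Sketch of stratified fold assignment. `cases` is a list of dicts
    with hypothetical 'id', 'label', 'gender', and 'age' keys. Cases are
    grouped by (label, gender, age bucket) and distributed round-robin
    across folds, so each stratum is spread evenly."""
    groups = defaultdict(list)
    for c in cases:
        key = (c["label"], c["gender"], c["age"] // 20)  # coarse age bucket
        groups[key].append(c["id"])
    folds = [[] for _ in range(n_folds)]
    i = 0
    for members in groups.values():
        for case_id in members:
            folds[i % n_folds].append(case_id)
            i += 1
    return folds
```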

2. Training

Train with the standard nnUNetv2 entry point, passing the joint trainer via -tr (or a ResEnc plan via -p):

nnUNetv2_train 161 3d_fullres 0 -tr <TrainerName>
nnUNetv2_train 714 3d_fullres all -p nnUNetResEncUNetMPlans

3. Inference

Run joint segmentation and classification inference on NIfTI images:

python segcls_ensemble_infer.py \
    --input_path /path/to/input/images/ \
    --output_path /path/to/output/ \
    --model_path /path/to/trained/model \
    --fold 0 \
    --checkpoint checkpoint_best.pth \
    --device cuda \
    --cls_mode mean

For a 5-fold ensemble, set --fold to (0,1,2,3,4).

Arguments:

  • --input_path, -i: Directory containing input NIfTI images (expects *_000X.nii.gz naming convention)
  • --output_path, -o: Directory to save segmentation masks and classification results
  • --model_path: Path to trained nnUNet model directory
  • --fold: Fold number or 'all' for ensemble prediction (default: 'all')
  • --checkpoint: Checkpoint filename (default: 'checkpoint_best.pth')
  • --use_softmax: Apply softmax to segmentation output (default: False)
  • --device: Computing device ('cuda' or 'cpu', default: 'cuda')
  • --cls_mode: Classification aggregation mode ('mean' or 'weighted', default: 'mean')

Input Format: Images should follow nnUNet naming convention:

  • PatientID_0000.nii.gz (first modality)
  • PatientID_0001.nii.gz (second modality)
  • etc.

Outputs:

  • {PatientID}.nii.gz: Segmentation masks for each case
  • results.csv: Classification probabilities for all cases
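A small post-processing sketch for results.csv. The 'case' column name and the per-class probability columns are assumptions, since the actual header is not documented here:

```python
import csv

def top_class(results_csv):
    """Read per-case classification probabilities and return the argmax
    class name per case. Assumes (hypothetically) a 'case' column
    followed by one probability column per class."""
    out = {}
    with open(results_csv, newline="") as f:
        for row in csv.DictReader(f):
            case = row.pop("case")
            out[case] = max(row, key=lambda k: float(row[k]))
    return out
```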

4. Key Features

Stratified Cross-Validation:

  • Creates balanced splits based on age quartiles, gender, and target labels
  • Ensures representative distribution across all folds
  • 80/20 train-test split with 5-fold cross-validation on training data

Advanced Inference:

  • Sliding window prediction with Gaussian weighting
  • Test-time augmentation with mirroring
  • Multi-fold model ensembling
  • Memory-efficient processing for large images
  • Automatic batch processing of multiple cases

Classification Modes:

  • mean: Average classification scores across all patches
  • weighted: Weight classification by segmentation confidence
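The two modes can be sketched in plain Python. `aggregate_cls` is illustrative, not JSC's API; in 'weighted' mode the weights might be, for example, each patch's mean foreground segmentation confidence:

```python
def aggregate_cls(patch_scores, patch_weights=None):
    """Sketch of the two aggregation modes. `patch_scores` is a list of
    per-patch class-probability lists. With no weights, every patch
    counts equally ('mean'); with weights, patches with higher
    segmentation confidence dominate ('weighted')."""
    n_classes = len(patch_scores[0])
    if patch_weights is None:          # 'mean' mode: uniform weights
        weights = [1.0] * len(patch_scores)
    else:                              # 'weighted' mode
        weights = patch_weights
    total = sum(weights)
    return [
        sum(w * scores[c] for w, scores in zip(weights, patch_scores)) / total
        for c in range(n_classes)
    ]
```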

Model Architecture

The framework extends nnUNet with:

  • Shared encoder for both segmentation and classification
  • Dual output heads (segmentation + classification)
  • Feature aggregation from the final encoder stage
  • Support for both binary and multi-class classification
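A toy PyTorch sketch of the dual-head layout. The class and layer sizes are invented for illustration; the real model reuses nnUNet's encoder and plans-driven configuration:

```python
import torch
import torch.nn as nn

class JointSegCls(nn.Module):
    """Toy sketch of the shared-encoder, dual-head idea (not JSC's
    actual classes): a segmentation head on the full feature map and a
    classification head on globally pooled encoder features."""
    def __init__(self, in_ch=2, n_seg_classes=2, n_cls_classes=2, width=8):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(in_ch, width, 3, padding=1), nn.ReLU(),
            nn.Conv3d(width, width, 3, padding=1), nn.ReLU(),
        )
        self.seg_head = nn.Conv3d(width, n_seg_classes, 1)  # pixel-level output
        self.cls_head = nn.Linear(width, n_cls_classes)     # image-level output

    def forward(self, x):
        feats = self.encoder(x)
        seg = self.seg_head(feats)
        pooled = feats.mean(dim=(2, 3, 4))  # global average pool over D, H, W
        return seg, self.cls_head(pooled)
```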

License

This project is licensed under the Apache License 2.0.

Citation

If you use JSC in your research, please cite the original nnUNet paper and this extension.


Note: Ensure your input data follows the nnUNet preprocessing requirements and naming conventions for optimal performance.
