diff --git a/.github/workflows/prepare_test_data.yaml b/.github/workflows/prepare_test_data.yaml index 29a7522c..e694d2de 100644 --- a/.github/workflows/prepare_test_data.yaml +++ b/.github/workflows/prepare_test_data.yaml @@ -49,7 +49,7 @@ jobs: # the Visium HD dataset is licensed as CC BY 4.0, as shown here # https://www.10xgenomics.com/support/software/space-ranger/latest/resources/visium-hd-example-data - # 10x Genomics Visium HD 4.0.1 3' Mouse Brain + # 10x Genomics Visium HD 4.0.1 3' Mouse Brain Chunk curl -O https://cf.10xgenomics.com/samples/spatial-exp/4.0.1/Visium_HD_Tiny_3prime_Dataset/Visium_HD_Tiny_3prime_Dataset_outs.zip # ------- diff --git a/pyproject.toml b/pyproject.toml index f93e46b1..19ca5a0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "joblib", "imagecodecs", "dask-image", - "pyarrow", + "pyarrow<22.0.0", # https://github.com/scverse/spatialdata-io/issues/334 "readfcs", "tifffile>=2023.8.12", "ome-types", diff --git a/src/spatialdata_io/__main__.py b/src/spatialdata_io/__main__.py index c65e46f4..40c45518 100644 --- a/src/spatialdata_io/__main__.py +++ b/src/spatialdata_io/__main__.py @@ -412,11 +412,24 @@ def visium_wrapper( default=False, help="If true, annotates the table by labels. [default: False]", ) +@click.option( + "--load-segmentations-only", + default=None, + help="If `True`, only the segmented cell boundaries and their associated counts will be loaded. All binned data will be skipped. [default: None, which will fall back to `False` with a deprecation warning]", +) +@click.option( + "--load-nucleus-segmentations", + type=bool, + default=False, + help="If `True` and nucleus segmentation files are present, load nucleus segmentation polygons and the corresponding nucleus-filtered count table. 
[default: False]", +) def visium_hd_wrapper( input: str, output: str, dataset_id: str | None = None, filtered_counts_file: bool = True, + load_segmentations_only: bool | None = None, + load_nucleus_segmentations: bool = False, bin_size: int | list[int] | None = None, bins_as_squares: bool = True, fullres_image_file: str | Path | None = None, @@ -428,6 +441,8 @@ def visium_hd_wrapper( path=input, dataset_id=dataset_id, filtered_counts_file=filtered_counts_file, + load_segmentations_only=load_segmentations_only, + load_nucleus_segmentations=load_nucleus_segmentations, bin_size=bin_size, bins_as_squares=bins_as_squares, fullres_image_file=fullres_image_file, diff --git a/src/spatialdata_io/_constants/_constants.py b/src/spatialdata_io/_constants/_constants.py index c2f6c21f..f9e6f8b8 100644 --- a/src/spatialdata_io/_constants/_constants.py +++ b/src/spatialdata_io/_constants/_constants.py @@ -356,11 +356,16 @@ class VisiumHDKeys(ModeEnum): BIN_PREFIX = "square_" MICROSCOPE_IMAGE = "microscope_image" BINNED_OUTPUTS = "binned_outputs" + SEGMENTATION_OUTPUTS = "segmented_outputs" # counts and locations files FILTERED_COUNTS_FILE = "filtered_feature_bc_matrix.h5" RAW_COUNTS_FILE = "raw_feature_bc_matrix.h5" TISSUE_POSITIONS_FILE = "tissue_positions.parquet" + BARCODE_MAPPINGS_FILE = "barcode_mappings.parquet" + FILTERED_CELL_COUNTS_FILE = "filtered_feature_cell_matrix.h5" + CELL_SEGMENTATION_GEOJSON_PATH = "cell_segmentations.geojson" + NUCLEUS_SEGMENTATION_GEOJSON_PATH = "nucleus_segmentations.geojson" # images IMAGE_HIRES_FILE = "tissue_hires_image.png" @@ -402,3 +407,7 @@ class VisiumHDKeys(ModeEnum): MICROSCOPE_COLROW_TO_SPOT_COLROW = ("microscope_colrow_to_spot_colrow",) SPOT_COLROW_TO_MICROSCOPE_COLROW = ("spot_colrow_to_microscope_colrow",) FILE_FORMAT = "file_format" + + # Cell Segmentation keys + CELL_SEG_KEY_HD = "cell_segmentations" + NUCLEUS_SEG_KEY_HD = "nucleus_segmentations" diff --git a/src/spatialdata_io/readers/visium_hd.py b/src/spatialdata_io/readers/visium_hd.py index 6a5f3bf6..27852d2c 100644 --- a/src/spatialdata_io/readers/visium_hd.py +++ b/src/spatialdata_io/readers/visium_hd.py @@ -10,11 +10,14 @@ import h5py import numpy as np import pandas as pd +import pyarrow.parquet as pq import scanpy as sc from dask_image.imread import imread from geopandas import GeoDataFrame from imageio import imread as imread2 from numpy.random import default_rng +from scipy.sparse import csc_matrix +from shapely.geometry import Polygon from skimage.transform import ProjectiveTransform, warp from spatialdata import ( SpatialData, @@ -32,6 +35,7 @@ if TYPE_CHECKING: from collections.abc import Mapping + from anndata import AnnData from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage from spatialdata._types import ArrayLike @@ -44,6 +48,8 @@ def visium_hd( path: str | Path, dataset_id: str | None = None, filtered_counts_file: bool = True, + load_segmentations_only: bool | None = None, + load_nucleus_segmentations: bool = False, bin_size: int | list[int] | None = None, bins_as_squares: bool = True, annotate_table_by_labels: bool = False, @@ -56,10 +62,6 @@ def visium_hd( ) -> SpatialData: """Read *10x Genomics* Visium HD formatted dataset. - .. seealso:: - - - `Space Ranger output `_. - Parameters ---------- path @@ -70,6 +72,17 @@ def visium_hd( filtered_counts_file It sets the value of `counts_file` to ``{vx.FILTERED_COUNTS_FILE!r}`` (when `True`) or to ``{vx.RAW_COUNTS_FILE!r}`` (when `False`). 
+ load_segmentations_only + If `True`, only the segmented cell boundaries and their associated counts will be loaded. All binned data + will be skipped. If `False`, only the binned data will be loaded (which is consistent with legacy behavior). + If `None` (default), it will be equivalent to `False`, but a deprecation warning will be raised to inform users that + in future releases the default value will be changed to `True`. To avoid the warning, explicitly set this parameter to + `False` or `True`. + load_nucleus_segmentations + If `True` and nucleus segmentation files are present, load nucleus segmentation polygons and the corresponding + nucleus-filtered count table. The counts are aggregated from the 2 µm binned matrix using the provided + barcode mappings so that only bins under segmented nuclei contribute to each cell’s counts. Requires all of: + nucleus segmentation GeoJSON, barcode_mappings.parquet, and the 2 µm filtered_feature_bc_matrix.h5. bin_size When specified, load the data of a specific bin size, or a list of bin sizes. By default, it loads all the available bin sizes. @@ -105,6 +118,37 @@ def visium_hd( images: dict[str, Any] = {} labels: dict[str, Any] = {} + # Deprecation warning for load_segmentations_only default value + if not load_segmentations_only: + warnings.warn( + "`load_segmentations_only` default value will change to `True` in future releases. Please set it " + "explicitly to `True` or `False` to avoid this warning.", + FutureWarning, + stacklevel=2, + ) + + # Check for segmentation files + SEGMENTED_OUTPUTS_PATH = path / VisiumHDKeys.SEGMENTATION_OUTPUTS + COUNT_MATRIX_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.FILTERED_CELL_COUNTS_FILE + CELL_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.CELL_SEGMENTATION_GEOJSON_PATH + NUCLEUS_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.NUCLEUS_SEGMENTATION_GEOJSON_PATH + SCALE_FACTORS_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.SPATIAL / VisiumHDKeys.SCALEFACTORS_FILE + BARCODE_MAPPINGS_PATH = next( + (file for file in path.rglob("*") if file.name.endswith(VisiumHDKeys.BARCODE_MAPPINGS_FILE)), + None, + ) + FILTERED_MATRIX_2U_PATH = ( + path / VisiumHDKeys.BINNED_OUTPUTS / f"{VisiumHDKeys.BIN_PREFIX}002um" / VisiumHDKeys.FILTERED_COUNTS_FILE + ) + cell_segmentation_files_exist = ( + COUNT_MATRIX_PATH.exists() and CELL_GEOJSON_PATH.exists() and SCALE_FACTORS_PATH.exists() + ) + nucleus_segmentation_files_exist = ( + NUCLEUS_GEOJSON_PATH.exists() + and (BARCODE_MAPPINGS_PATH is not None and BARCODE_MAPPINGS_PATH.exists()) + and FILTERED_MATRIX_2U_PATH.exists() + ) + if dataset_id is None: dataset_id = _infer_dataset_id(path) @@ -131,154 +175,227 @@ def load_image(path: Path, suffix: str, scale_factors: list[int] | None = None) stacklevel=2, ) - def _get_bins(path_bins: Path) -> list[str]: - return sorted( + with open(SCALE_FACTORS_PATH) as file: + scalefactors = json.load(file) + + transform_lowres = Scale( + np.array( [ - bin_size.name - for bin_size in path_bins.iterdir() - if bin_size.is_dir() and bin_size.name.startswith(VisiumHDKeys.BIN_PREFIX) + scalefactors[VisiumHDKeys.SCALEFACTORS_LOWRES], + scalefactors[VisiumHDKeys.SCALEFACTORS_LOWRES], ] - ) + ), + axes=("x", "y"), + ) + transform_hires = Scale( + np.array( + [ + scalefactors[VisiumHDKeys.SCALEFACTORS_HIRES], + scalefactors[VisiumHDKeys.SCALEFACTORS_HIRES], + ] + ), + axes=("x", "y"), + ) + # scaling + transform_original = Identity() + transformations = { + dataset_id: transform_original, + f"{dataset_id}_downscaled_hires": transform_hires, + 
f"{dataset_id}_downscaled_lowres": transform_lowres, + } + + # Load Binned Data + if not load_segmentations_only: + + def _get_bins(path_bins: Path) -> list[str]: + return sorted( + [ + bin_size.name + for bin_size in path_bins.iterdir() + if bin_size.is_dir() and bin_size.name.startswith(VisiumHDKeys.BIN_PREFIX) + ] + ) - all_path_bins = [path_bin for path_bin in all_files if VisiumHDKeys.BINNED_OUTPUTS in str(path_bin)] - if len(all_path_bins) != 0: - path_bins_parts = all_path_bins[ - -1 - ].parts # just choosing last one here as users might have tar file which would be first - path_bins = Path(*path_bins_parts[: path_bins_parts.index(VisiumHDKeys.BINNED_OUTPUTS) + 1]) - else: - path_bins = path - all_bin_sizes = _get_bins(path_bins) - - bin_sizes = [] - if bin_size is not None: - if not isinstance(bin_size, list): - bin_size = [bin_size] - bin_sizes = [f"square_{bs:03}um" for bs in bin_size if f"square_{bs:03}um" in all_bin_sizes] - if len(bin_sizes) < len(bin_size): - warnings.warn( - f"Requested bin size {bin_size} (available {all_bin_sizes}); ignoring the bin sizes that are not " - "found.", - UserWarning, - stacklevel=2, + all_path_bins = [path_bin for path_bin in all_files if VisiumHDKeys.BINNED_OUTPUTS in str(path_bin)] + if len(all_path_bins) != 0: + path_bins_parts = all_path_bins[ + -1 + ].parts # just choosing last one here as users might have tar file which would be first + path_bins = Path(*path_bins_parts[: path_bins_parts.index(VisiumHDKeys.BINNED_OUTPUTS) + 1]) + else: + path_bins = path + all_bin_sizes = _get_bins(path_bins) + + bin_sizes = [] + if bin_size is not None and (isinstance(bin_size, int) or len(bin_size) > 0): + if not isinstance(bin_size, list): + bin_size = [bin_size] + bin_sizes = [f"square_{bs:03}um" for bs in bin_size if f"square_{bs:03}um" in all_bin_sizes] + if len(bin_sizes) < len(bin_size): + warnings.warn( + f"Requested bin size {bin_size} (available {all_bin_sizes}); ignoring the bin sizes that are not " + "found.", + UserWarning, + stacklevel=2, + ) + if bin_size is None or bin_sizes == []: + bin_sizes = all_bin_sizes + + # iterate over the given bins and load the data + for bin_size_str in bin_sizes: + path_bin = path_bins / bin_size_str + counts_file = VisiumHDKeys.FILTERED_COUNTS_FILE if filtered_counts_file else VisiumHDKeys.RAW_COUNTS_FILE + adata = sc.read_10x_h5( + path_bin / counts_file, + gex_only=False, + **anndata_kwargs, ) - if bin_size is None or bin_sizes == []: - bin_sizes = all_bin_sizes - - # iterate over the given bins and load the data - for bin_size_str in bin_sizes: - path_bin = path_bins / bin_size_str - counts_file = VisiumHDKeys.FILTERED_COUNTS_FILE if filtered_counts_file else VisiumHDKeys.RAW_COUNTS_FILE - adata = sc.read_10x_h5( - path_bin / counts_file, - gex_only=False, - **anndata_kwargs, - ) - path_bin_spatial = path_bin / VisiumHDKeys.SPATIAL + path_bin_spatial = path_bin / VisiumHDKeys.SPATIAL + + # The scale factors of binned data are consistent to the global ones + # (which are already loaded in "scalefactors"), but the json file in the + # binned spatial folder contains also the bin size information + with open(path_bin_spatial / VisiumHDKeys.SCALEFACTORS_FILE) as file: + scalefactors_bins = json.load(file) + + # consistency check + found_bin_size = re.search(r"\d{3}", bin_size_str) + assert found_bin_size is not None + assert float(found_bin_size.group()) == scalefactors_bins[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM] + assert np.isclose( + scalefactors_bins[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM] + / 
scalefactors_bins[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES], + scalefactors_bins[VisiumHDKeys.SCALEFACTORS_MICRONS_PER_PIXEL], + ) - with open(path_bin_spatial / VisiumHDKeys.SCALEFACTORS_FILE) as file: - scalefactors = json.load(file) + tissue_positions_file = path_bin_spatial / VisiumHDKeys.TISSUE_POSITIONS_FILE + + # read coordinates and set up adata.obs and adata.obsm + coords = pd.read_parquet(tissue_positions_file) + assert all( + coords.columns.values + == [ + VisiumHDKeys.BARCODE, + VisiumHDKeys.IN_TISSUE, + VisiumHDKeys.ARRAY_ROW, + VisiumHDKeys.ARRAY_COL, + VisiumHDKeys.LOCATIONS_Y, + VisiumHDKeys.LOCATIONS_X, + ] + ) + coords.set_index(VisiumHDKeys.BARCODE, inplace=True, drop=True) + coords_filtered = coords.loc[adata.obs.index] + adata.obs = pd.merge( + adata.obs, + coords_filtered, + how="left", + left_index=True, + right_index=True, + ) + # compatibility to legacy squidpy + adata.obsm["spatial"] = adata.obs[[VisiumHDKeys.LOCATIONS_X, VisiumHDKeys.LOCATIONS_Y]].values + # dropping the spatial coordinates (will be stored in shapes) + adata.obs.drop( + columns=[ + VisiumHDKeys.LOCATIONS_X, + VisiumHDKeys.LOCATIONS_Y, + ], + inplace=True, + ) + adata.obs[VisiumHDKeys.INSTANCE_KEY] = np.arange(len(adata)) - # consistency check - found_bin_size = re.search(r"\d{3}", bin_size_str) - assert found_bin_size is not None - assert float(found_bin_size.group()) == scalefactors[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM] - assert np.isclose( - scalefactors[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM] - / scalefactors[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES], - scalefactors[VisiumHDKeys.SCALEFACTORS_MICRONS_PER_PIXEL], + # parse shapes + shapes_name = dataset_id + "_" + bin_size_str + radius = scalefactors_bins[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES] / 2.0 + circles = ShapesModel.parse( + adata.obsm["spatial"], + geometry=0, + radius=radius, + index=adata.obs[VisiumHDKeys.INSTANCE_KEY].copy(), + transformations=transformations, + ) + if not bins_as_squares: + shapes[shapes_name] = circles + else: + squares_series = circles.buffer(radius, cap_style=3) + shapes[shapes_name] = ShapesModel.parse( + GeoDataFrame(geometry=squares_series), + transformations=transformations, + ) + + # parse table + adata.obs[VisiumHDKeys.REGION_KEY] = shapes_name + adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category") + + tables[bin_size_str] = TableModel.parse( + adata, + region=shapes_name, + region_key=str(VisiumHDKeys.REGION_KEY), + instance_key=str(VisiumHDKeys.INSTANCE_KEY), + ) + if var_names_make_unique: + tables[bin_size_str].var_names_make_unique() + + # Integrate the segmentation data (skipped if segmentation files are not found) + if cell_segmentation_files_exist: + print("Found segmentation data. 
Incorporating cell_segmentations.") + cell_adata_hd = sc.read_10x_h5(COUNT_MATRIX_PATH) + cell_adata_hd.var_names_make_unique() + + cell_shapes_gdf = _extract_geometries_from_geojson( + cell_adata_hd, + geojson_path=CELL_GEOJSON_PATH, ) - tissue_positions_file = path_bin_spatial / VisiumHDKeys.TISSUE_POSITIONS_FILE - - # read coordinates and set up adata.obs and adata.obsm - coords = pd.read_parquet(tissue_positions_file) - assert all( - coords.columns.values - == [ - VisiumHDKeys.BARCODE, - VisiumHDKeys.IN_TISSUE, - VisiumHDKeys.ARRAY_ROW, - VisiumHDKeys.ARRAY_COL, - VisiumHDKeys.LOCATIONS_Y, - VisiumHDKeys.LOCATIONS_X, - ] + SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.CELL_SEG_KEY_HD}" + cell_adata_hd.obs["cell_id"] = cell_adata_hd.obs.index + cell_adata_hd.obs["region"] = SHAPES_KEY_HD + cell_adata_hd.obs["region"] = cell_adata_hd.obs["region"].astype("category") + cell_adata_hd = cell_adata_hd[cell_shapes_gdf.index].copy() + + shapes[SHAPES_KEY_HD] = ShapesModel.parse(cell_shapes_gdf, transformations=transformations) + tables[VisiumHDKeys.CELL_SEG_KEY_HD] = TableModel.parse( + cell_adata_hd, + region=SHAPES_KEY_HD, + region_key="region", + instance_key="cell_id", ) - coords.set_index(VisiumHDKeys.BARCODE, inplace=True, drop=True) - coords_filtered = coords.loc[adata.obs.index] - adata.obs = pd.merge(adata.obs, coords_filtered, how="left", left_index=True, right_index=True) - # compatibility to legacy squidpy - adata.obsm["spatial"] = adata.obs[[VisiumHDKeys.LOCATIONS_X, VisiumHDKeys.LOCATIONS_Y]].values - # dropping the spatial coordinates (will be stored in shapes) - adata.obs.drop( - columns=[ - VisiumHDKeys.LOCATIONS_X, - VisiumHDKeys.LOCATIONS_Y, - ], - inplace=True, - ) - adata.obs[VisiumHDKeys.INSTANCE_KEY] = np.arange(len(adata)) - # scaling - transform_original = Identity() - transform_lowres = Scale( - np.array( - [ - scalefactors[VisiumHDKeys.SCALEFACTORS_LOWRES], - scalefactors[VisiumHDKeys.SCALEFACTORS_LOWRES], - ] - ), - axes=("x", "y"), - ) - transform_hires = Scale( - np.array( - [ - scalefactors[VisiumHDKeys.SCALEFACTORS_HIRES], - scalefactors[VisiumHDKeys.SCALEFACTORS_HIRES], - ] - ), - axes=("x", "y"), - ) - # parse shapes - shapes_name = dataset_id + "_" + bin_size_str - radius = scalefactors[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES] / 2.0 - transformations = { - dataset_id: transform_original, - f"{dataset_id}_downscaled_hires": transform_hires, - f"{dataset_id}_downscaled_lowres": transform_lowres, - } - circles = ShapesModel.parse( - adata.obsm["spatial"], - geometry=0, - radius=radius, - index=adata.obs[VisiumHDKeys.INSTANCE_KEY].copy(), - transformations=transformations, - ) - if not bins_as_squares: - shapes[shapes_name] = circles - else: - squares_series = circles.buffer(radius, cap_style=3) - shapes[shapes_name] = ShapesModel.parse( - GeoDataFrame(geometry=squares_series), transformations=transformations - ) + # load nucleus segmentations if available + if nucleus_segmentation_files_exist and load_nucleus_segmentations: + print("Found nucleus segmentation data. 
Incorporating nucleus_segmentations.") - # parse table - adata.obs[VisiumHDKeys.REGION_KEY] = shapes_name - adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category") + # we already ensure this by having nucleus_segmentation_files_exist True, but + # mypy is not able to infer that + assert BARCODE_MAPPINGS_PATH is not None - tables[bin_size_str] = TableModel.parse( - adata, - region=shapes_name, - region_key=str(VisiumHDKeys.REGION_KEY), - instance_key=str(VisiumHDKeys.INSTANCE_KEY), - ) - if var_names_make_unique: - tables[bin_size_str].var_names_make_unique() + nucleus_adata_hd = _make_filtered_nucleus_adata( + filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH, + barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH, + ) + nucleus_shapes_gdf = _extract_geometries_from_geojson( + adata=nucleus_adata_hd, geojson_path=NUCLEUS_GEOJSON_PATH + ) - # read full resolution image + SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}" + nucleus_adata_hd.obs["cell_id"] = nucleus_adata_hd.obs.index + nucleus_adata_hd.obs["region"] = SHAPES_KEY_HD + nucleus_adata_hd.obs["region"] = nucleus_adata_hd.obs["region"].astype("category") + nucleus_adata_hd = nucleus_adata_hd[nucleus_shapes_gdf.index].copy() + + shapes[SHAPES_KEY_HD] = ShapesModel.parse(nucleus_shapes_gdf, transformations=transformations) + tables[VisiumHDKeys.NUCLEUS_SEG_KEY_HD] = TableModel.parse( + nucleus_adata_hd, + region=SHAPES_KEY_HD, + region_key="region", + instance_key="cell_id", + ) + + # Read all images and add transformations + fullres_image_file_paths = [] if fullres_image_file is not None: - fullres_image_file = Path(fullres_image_file) + fullres_image_file_paths.append(Path(fullres_image_file)) else: path_fullres = path / VisiumHDKeys.MICROSCOPE_IMAGE if path_fullres.exists(): @@ -305,62 +422,65 @@ def _get_bins(path_bins: Path) -> list[str]: if fullres_image_file is not None: load_image( - path=fullres_image_file, + path=fullres_image_file_paths[0], suffix="_full_image", scale_factors=[2, 2, 2, 2], ) + else: + warnings.warn( + "No full resolution image found. 
If incorrect, please specify the path in the " + "`fullres_image_file` parameter when calling the `visium_hd` reader function.", + UserWarning, + stacklevel=2, + ) # hires image hires_image_path = [path for path in all_files if VisiumHDKeys.IMAGE_HIRES_FILE in str(path)] - if len(hires_image_path) == 0: + if len(hires_image_path) > 0: + load_image( + path=hires_image_path[0], + suffix="_hires_image", + ) + set_transformation( + images[dataset_id + "_hires_image"], + { + f"{dataset_id}_downscaled_hires": Identity(), + dataset_id: transform_hires.inverse(), + }, + set_all=True, + ) + else: warnings.warn( f"No image path found containing the hires image: {VisiumHDKeys.IMAGE_HIRES_FILE}", UserWarning, stacklevel=2, ) - load_image( - path=hires_image_path[0], - suffix="_hires_image", - ) - set_transformation( - images[dataset_id + "_hires_image"], - { - f"{dataset_id}_downscaled_hires": Identity(), - dataset_id: transform_hires.inverse(), - }, - set_all=True, - ) # lowres image lowres_image_path = [path for path in all_files if VisiumHDKeys.IMAGE_LOWRES_FILE in str(path)] - if len(lowres_image_path) == 0: + if len(lowres_image_path) > 0: + load_image( + path=lowres_image_path[0], + suffix="_lowres_image", + ) + set_transformation( + images[dataset_id + "_lowres_image"], + { + f"{dataset_id}_downscaled_lowres": Identity(), + dataset_id: transform_lowres.inverse(), + }, + set_all=True, + ) + else: warnings.warn( f"No image path found containing the lowres image: {VisiumHDKeys.IMAGE_LOWRES_FILE}", UserWarning, stacklevel=2, ) - load_image( - path=lowres_image_path[0], - suffix="_lowres_image", - ) - set_transformation( - images[dataset_id + "_lowres_image"], - { - f"{dataset_id}_downscaled_lowres": Identity(), - dataset_id: transform_lowres.inverse(), - }, - set_all=True, - ) # cytassist image cytassist_path = [path for path in all_files if VisiumHDKeys.IMAGE_CYTASSIST in str(path)] - if len(cytassist_path) == 0: - warnings.warn( - f"No image path found containing the cytassist image: {VisiumHDKeys.IMAGE_CYTASSIST}", - UserWarning, - stacklevel=2, - ) - if load_all_images: + if load_all_images and len(cytassist_path) > 0: load_image( path=cytassist_path[0], suffix="_cytassist_image", @@ -401,7 +521,10 @@ def _get_bins(path_bins: Path) -> list[str]: ) # the first two components are <= 0, we just discard them since the cytassist image has a lot of padding # and therefore we can safely discard pixels with negative coordinates - transformed_shape = (np.ceil(transformed_bounds[2]), np.ceil(transformed_bounds[3])) + transformed_shape = ( + np.ceil(transformed_bounds[2]), + np.ceil(transformed_bounds[3]), + ) # flip xy transformed_shape = (transformed_shape[1], transformed_shape[0]) @@ -409,23 +532,34 @@ def _get_bins(path_bins: Path) -> list[str]: # the cytassist image is a small, single-scale image, so we can compute it in memory numpy_data = image.transpose("y", "x", "c").data.compute() warped = warp( - numpy_data, ProjectiveTransform(projective_shift).inverse, output_shape=transformed_shape, order=1 + numpy_data, + ProjectiveTransform(projective_shift).inverse, + output_shape=transformed_shape, + order=1, ) warped = np.round(warped * 255).astype(np.uint8) - warped = Image2DModel.parse(warped, dims=("y", "x", "c"), transformations={dataset_id: affine}, rgb=True) - + warped = Image2DModel.parse( + warped, + dims=("y", "x", "c"), + transformations={dataset_id: affine}, + rgb=True, + ) # we replace the cytassist image with the warped image images[dataset_id + "_cytassist_image"] = warped + elif 
load_all_images: + warnings.warn( + f"No image path found containing the cytassist image: {VisiumHDKeys.IMAGE_CYTASSIST}", + UserWarning, + stacklevel=2, + ) sdata = SpatialData(tables=tables, images=images, shapes=shapes, labels=labels) if annotate_table_by_labels: for bin_size_str in bin_sizes: shapes_name = dataset_id + "_" + bin_size_str - # add labels layer (rasterized bins). labels_name = f"{dataset_id}_{bin_size_str}_labels" - labels_element = rasterize_bins( sdata, bins=shapes_name, @@ -435,7 +569,6 @@ def _get_bins(path_bins: Path) -> list[str]: value_key=None, return_region_as_labels=True, ) - sdata[labels_name] = labels_element rasterize_bins_link_table_to_labels( sdata=sdata, table_name=bin_size_str, rasterized_labels_name=labels_name @@ -511,7 +644,9 @@ def _projective_matrix_is_affine(projective_matrix: ArrayLike) -> bool: return np.allclose(projective_matrix[2, :2], [0, 0]) -def _decompose_projective_matrix(projective_matrix: ArrayLike) -> tuple[ArrayLike, ArrayLike]: +def _decompose_projective_matrix( + projective_matrix: ArrayLike, +) -> tuple[ArrayLike, ArrayLike]: """Decompose a projective transformation matrix into an affine transformation and a projective shift. Parameters @@ -584,3 +719,116 @@ def _get_transform_matrices(metadata: dict[str, Any], hd_layout: dict[str, Any]) transform_matrices[key] = np.array(coefficients).reshape(3, 3) return transform_matrices + + +def _make_filtered_nucleus_adata( + filtered_matrix_h5_path: Path, + barcode_mappings_parquet_path: Path, + bin_col_name: str = "square_002um", + aggregate_col_name: str = "cell_id", +) -> AnnData: + """Generate a filtered AnnData object by aggregating 2um binned data based on nucleus segmentation. + + Uses filtered_feature_bc_matrix.h5 file and a barcode_mappings.parquet file containing + barcode mappings, filters the data to include only valid nucleus mappings, + and aggregates the data based on specified bin into cell IDs which only contain + the 2um square data under segmented nuclei. + + Parameters: + ----------- + filtered_matrix_h5_path + Path to the 10x Genomics HDF5 matrix file. + barcode_mappings_parquet_path + Path to the Parquet file containing barcode mappings. + bin_col_name + Column name in the barcode mappings that specifies the spatial bin (default is 'square_002um'). + aggregate_col_name + Column name in the barcode mappings that specifies the aggregate cell ID (default is 'cell_id'). + + Returns: + -------- + AnnData + An AnnData object where the observations correspond to filtered cell IDs + and the variables correspond to the original features from the input data. 
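+
+    Examples
+    --------
+    The aggregation is a single sparse product: with ``X`` the (bins x genes)
+    filtered 2um matrix and ``A`` a 0/1 (bins x cells) assignment matrix built from
+    the barcode mappings, the per-cell counts are ``A.T @ X``. A minimal sketch of
+    the same idea on made-up toy arrays (illustrative only):
+
+    >>> import numpy as np
+    >>> from scipy.sparse import csc_matrix
+    >>> X = csc_matrix(np.array([[1, 0], [2, 1], [0, 3]]))  # 3 bins x 2 genes
+    >>> A = csc_matrix(np.array([[1, 0], [1, 0], [0, 1]]))  # bins 0, 1 -> cell 0; bin 2 -> cell 1
+    >>> (A.T @ X).toarray()  # 2 cells x 2 genes
+    array([[3, 1],
+           [0, 3]])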
+ """ + # Read in the necessary files + adata_2um = sc.read_10x_h5(filtered_matrix_h5_path) + barcode_mappings = pq.read_table(barcode_mappings_parquet_path) + + # Filter to only include valid cell IDs that are in both nucleus and cell + barcode_mappings = barcode_mappings.filter( + (barcode_mappings["cell_id"].is_valid()) and barcode_mappings["in_nucleus"] + ) + + # Filter the 2um adata to only include squares present in the barcode mappings + valid_squares = barcode_mappings[bin_col_name].unique() + squares_to_keep = np.intersect1d(adata_2um.obs_names, valid_squares) + adata_filtered = adata_2um[squares_to_keep, :].copy() + + # Map each square to its corresponding cell ID + square_to_cell_map = dict( + zip( + barcode_mappings[bin_col_name].to_pylist(), + barcode_mappings[aggregate_col_name].to_pylist(), + strict=True, + ) + ) + ordered_cell_ids = [square_to_cell_map[square] for square in adata_filtered.obs_names] + unique_cells = list(dict.fromkeys(ordered_cell_ids).keys()) + cell_to_idx = {cell: i for i, cell in enumerate(unique_cells)} + + # Make the aggregation matrix + col_indices = [cell_to_idx[cell] for cell in ordered_cell_ids] + row_indices = np.arange(len(ordered_cell_ids)) + data = np.ones_like(row_indices) + + aggregation_matrix = csc_matrix( + (data, (row_indices, col_indices)), + shape=(adata_filtered.n_obs, len(unique_cells)), + ) + + # Make the final AnnData object where cell IDs are filtered + # to the data under the segmented nuclei + nucleus_matrix_sparse = adata_filtered.X.T.dot(aggregation_matrix) + adata_nucleus = sc.AnnData(nucleus_matrix_sparse.T) + adata_nucleus.obs_names = unique_cells + adata_nucleus.var = adata_filtered.var + + return adata_nucleus + + +def _extract_geometries_from_geojson( + adata: AnnData, + geojson_path: Path, +) -> GeoDataFrame: + """Extract geometries and create a GeoDataFrame from a GeoJSON features map. + + Parameters + ---------- + cell_adata + AnnData object containing cell data. + geojson_path + Path to the GeoJSON file containing cell segmentation geometries. + + Returns + ------- + GeoDataFrame + A GeoDataFrame containing cell IDs and their corresponding geometries. 
+ """ + with open(geojson_path) as f: + geojson_data = json.load(f) + geojson_features_map: dict[str, Any] = { + f"cellid_{feature['properties']['cell_id']:09d}-1": feature for feature in geojson_data["features"] + } + + geometries = [] + cell_ids_ordered = [] + + for obs_index_str in adata.obs.index: + feature = geojson_features_map.get(obs_index_str) + if feature: + polygon_coords = np.array(feature["geometry"]["coordinates"][0]) + geometries.append(Polygon(polygon_coords)) + cell_ids_ordered.append(obs_index_str) + + return GeoDataFrame({"cell_id": cell_ids_ordered, "geometry": geometries}, index=cell_ids_ordered) diff --git a/tests/test_visium_hd.py b/tests/test_visium_hd.py new file mode 100644 index 00000000..41dfd890 --- /dev/null +++ b/tests/test_visium_hd.py @@ -0,0 +1,215 @@ +import math +from pathlib import Path +from tempfile import TemporaryDirectory + +import numpy as np +import pytest +from click.testing import CliRunner +from spatialdata import get_extent, read_zarr +from spatialdata.models import get_table_keys + +from spatialdata_io.__main__ import visium_hd_wrapper +from spatialdata_io._constants._constants import VisiumHDKeys +from spatialdata_io.readers.visium_hd import ( + _decompose_projective_matrix, + _projective_matrix_is_affine, + visium_hd, +) +from tests._utils import skip_if_below_python_version + +# --- UNIT TESTS FOR HELPER FUNCTIONS --- + + +def test_projective_matrix_is_affine() -> None: + """Test the affine matrix check function.""" + # An affine matrix should have [0, 0, 1] as its last row + affine_matrix = np.array([[2, 0.5, 10], [0.5, 2, 20], [0, 0, 1]]) + assert _projective_matrix_is_affine(affine_matrix) + + # A projective matrix is not affine if the last row is different + projective_matrix = np.array([[2, 0.5, 10], [0.5, 2, 20], [0.01, 0.02, 1]]) + assert not _projective_matrix_is_affine(projective_matrix) + + +def test_decompose_projective_matrix() -> None: + """Test the decomposition of a projective matrix into affine and shift components.""" + projective_matrix = np.array([[1, 2, 3], [4, 5, 6], [0.1, 0.2, 1]]) + affine, shift = _decompose_projective_matrix(projective_matrix) + + expected_affine = np.array([[1, 2, 3], [4, 5, 6], [0, 0, 1]]) + + # The affine component should be correctly extracted + assert np.allclose(affine, expected_affine) + # Recomposing the affine and shift matrices should yield the original projective matrix + assert np.allclose(affine @ shift, projective_matrix) + + +# --- END-TO-END TESTS ON EXAMPLE DATA --- +# This dataset name is used to locate the test data in the './data/' directory. +# See https://github.com/scverse/spatialdata-io/blob/main/.github/workflows/prepare_test_data.yaml +# for instructions on how to download and place the data on disk. +DATASET_FOLDER = "Visium_HD_Tiny_3prime_Dataset_outs" +DATASET_ID = "visium_hd_tiny" + + +@skip_if_below_python_version() +def test_visium_hd_data_extent() -> None: + """Check the spatial extent of the loaded Visium HD data.""" + f = Path("./data") / DATASET_FOLDER + if not f.is_dir(): + pytest.skip(f"Test data not found at '{f}'. 
Skipping extent test.") + + sdata = visium_hd(f, dataset_id=DATASET_ID) + extent = get_extent(sdata, exact=False, coordinate_system="visium_hd_tiny_downscaled_lowres") + extent = {ax: (math.floor(extent[ax][0]), math.ceil(extent[ax][1])) for ax in extent} + + # TODO: Replace with the actual expected extent of your test data + expected_extent = "{'y': (-31, 540), 'x': (0, 652)}" + assert str(extent) == expected_extent + + +@skip_if_below_python_version() +@pytest.mark.parametrize( + "params", + [ + # Test case 1: Default binned data loading (squares) + { + "load_segmentations_only": False, + "load_nucleus_segmentations": False, + "bins_as_squares": True, + "annotate_table_by_labels": False, + "load_all_images": False, + }, + # Test case 2: Binned data as circles + { + "load_segmentations_only": False, + "load_nucleus_segmentations": False, + "bins_as_squares": False, + "annotate_table_by_labels": False, + "load_all_images": False, + }, + # Test case 3: Binned data with tables annotating labels instead of shapes + { + "load_segmentations_only": False, + "load_nucleus_segmentations": False, + "bins_as_squares": True, + "annotate_table_by_labels": True, + "load_all_images": False, + }, + # Test case 4: Load binned data AND all segmentations (cell + nucleus) + { + "load_segmentations_only": False, + "load_nucleus_segmentations": True, + "bins_as_squares": True, + "annotate_table_by_labels": False, + "load_all_images": False, + }, + # Test case 5: Load cell segmentations only + { + "load_segmentations_only": True, + "load_nucleus_segmentations": False, + "bins_as_squares": True, + "annotate_table_by_labels": False, + "load_all_images": False, + }, + # Test case 6: Load all segmentations (cell + nucleus) only + { + "load_segmentations_only": True, + "load_nucleus_segmentations": True, + "bins_as_squares": True, + "annotate_table_by_labels": False, + "load_all_images": False, + }, + ], +) +def test_visium_hd_data_integrity(params: dict[str, bool]) -> None: + """Check the integrity of various components of the loaded SpatialData object.""" + f = Path("./data") / DATASET_FOLDER + if not f.is_dir(): + pytest.skip(f"Test data not found at '{f}'. 
Skipping integrity test.") + + sdata = visium_hd(f, dataset_id=DATASET_ID, **params) + + # --- IMAGE CHECKS --- + if params.get("load_all_images", False): + assert f"{DATASET_ID}_hires_image" in sdata.images + assert f"{DATASET_ID}_lowres_image" in sdata.images + assert f"{DATASET_ID}_cytassist_image" not in sdata.images + assert f"{DATASET_ID}_full_image" not in sdata.images + + # --- SEGMENTATION CHECKS (loaded in all modes if present) --- + # TODO: Update placeholder values with actual data from your test dataset + assert VisiumHDKeys.CELL_SEG_KEY_HD in sdata.tables + assert f"{DATASET_ID}_{VisiumHDKeys.CELL_SEG_KEY_HD}" in sdata.shapes + cell_table = sdata.tables[VisiumHDKeys.CELL_SEG_KEY_HD] + assert cell_table.shape == (612, 32285) + assert "cellid_000000001-1" in cell_table.obs_names + + if params["load_nucleus_segmentations"]: + assert VisiumHDKeys.NUCLEUS_SEG_KEY_HD in sdata.tables + assert f"{DATASET_ID}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}" in sdata.shapes + nuc_table = sdata.tables[VisiumHDKeys.NUCLEUS_SEG_KEY_HD] + assert nuc_table.shape == (950, 32285) + else: + assert VisiumHDKeys.NUCLEUS_SEG_KEY_HD not in sdata.tables + + # --- BINNED DATA CHECKS --- + if params["load_segmentations_only"]: + assert "square_002um" not in sdata.tables + else: + assert "square_008um" in sdata.tables + table = sdata.tables["square_008um"] + assert table.shape == (20830, 32285) + assert "s_008um_00118_00105-1" in table.obs_names + + shape_name = f"{DATASET_ID}_square_008um" + labels_name = f"{shape_name}_labels" + if params["annotate_table_by_labels"]: + assert labels_name in sdata.labels + region, _, _ = get_table_keys(table) + assert region == labels_name + else: + assert shape_name in sdata.shapes + region, _, _ = get_table_keys(table) + assert region == shape_name + # Check for circles vs. squares + if params["bins_as_squares"]: + assert "radius" not in sdata.shapes[shape_name] + else: + assert "radius" in sdata.shapes[shape_name] + + +# --- CLI WRAPPER TEST --- + + +@skip_if_below_python_version() +@pytest.mark.parametrize( + "dataset", + ["Visium_HD_Tiny_3prime_Dataset_outs"], +) +def test_cli_visium_hd(runner: CliRunner, dataset: str) -> None: + """Test the command-line interface for the Visium HD reader.""" + f = Path("./data") / dataset[0] + + if not f.is_dir(): + pytest.skip(f"Test data not found at '{f}'. Skipping CLI test.") + + with TemporaryDirectory() as tmpdir: + output_zarr = Path(tmpdir) / "data.zarr" + result = runner.invoke( + visium_hd_wrapper, + [ + "--input", + str(f), + "--output", + str(output_zarr), + "--dataset-id", + DATASET_ID, + ], + ) + assert result.exit_code == 0, result.output + # Verify the output can be read + sdata = read_zarr(output_zarr) + + # A simple check to confirm data was loaded + assert f"{DATASET_ID}_lowres_image" in sdata.images
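
Note on the on-disk layout probed by the reader for the segmentation elements. File and folder names below are the ones taken from the VisiumHDKeys constants added in this patch; the scale factors JSON is resolved through the pre-existing VisiumHDKeys.SCALEFACTORS_FILE constant (not shown in this diff), and barcode_mappings.parquet is located with rglob, so it may also sit in a subdirectory:

    <outs>/
        binned_outputs/
            square_002um/
                filtered_feature_bc_matrix.h5    # 2um counts, used to build the nucleus-filtered table
        segmented_outputs/
            filtered_feature_cell_matrix.h5      # per-cell counts
            cell_segmentations.geojson
            nucleus_segmentations.geojson
            spatial/                             # scale factors JSON lives here
        barcode_mappings.parquet                 # 2um bin -> cell assignments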
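
A quick usage sketch for the new reader options. The path and dataset id below are placeholders; the element names are the ones constructed above, i.e. the tables "cell_segmentations"/"nucleus_segmentations" and the shapes "<dataset_id>_cell_segmentations"/"<dataset_id>_nucleus_segmentations":

    from spatialdata_io.readers.visium_hd import visium_hd

    # Load only the segmentation-derived elements and skip all binned data.
    sdata = visium_hd(
        "path/to/visium_hd_outs",          # placeholder Space Ranger output folder
        dataset_id="sample",               # placeholder dataset id
        load_segmentations_only=True,
        load_nucleus_segmentations=True,   # also aggregate 2um bins under segmented nuclei
    )
    cell_table = sdata.tables["cell_segmentations"]
    cell_shapes = sdata.shapes["sample_cell_segmentations"]
    nucleus_table = sdata.tables["nucleus_segmentations"]

The same two options are exposed on the CLI wrapper as --load-segmentations-only and --load-nucleus-segmentations.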