diff --git a/README.md b/README.md index ac17da91..e5aed36e 100644 --- a/README.md +++ b/README.md @@ -31,13 +31,20 @@ apt update && apt install -y libvips libvips-tools libvips-dev base_url = 'https://raw.githubusercontent.com/broadinstitute/celldega_Xenium_Prime_Human_Skin_FFPE_outs/main/Xenium_Prime_Human_Skin_FFPE_outs' landscape_ist = dega.viz.Landscape( - technology='Xenium', - ini_zoom = -4.5, + technology="Xenium", + ini_zoom=-4.5, ini_x=6000, ini_y=8000, - base_url = base_url, - height = 700, - width= 600 + base_url=base_url, + height=700, + width=600, +) + +# Alternatively pass an AnnData object to auto-populate cell metadata +# including "leiden" clusters, colors and UMAP coordinates. +landscape_from_adata = dega.viz.Landscape( + base_url=base_url, + AnnData=adata, ) file_path = 'https://raw.githubusercontent.com/broadinstitute/celldega_Xenium_Prime_Human_Skin_FFPE_outs/main/Xenium_Prime_Human_Skin_FFPE_outs/df_sig.parquet' diff --git a/js/deck-gl/layers/cell_layer.js b/js/deck-gl/layers/cell_layer.js index 600d5547..f99ecb26 100644 --- a/js/deck-gl/layers/cell_layer.js +++ b/js/deck-gl/layers/cell_layer.js @@ -107,12 +107,16 @@ export const ini_cell_layer = async (base_url, viz_state) => { set_cell_name_to_index_map(viz_state.cats); if (viz_state.cats.has_meta_cell) { - viz_state.cats.cell_cats = viz_state.cats.cell_names_array.map( - (name) => viz_state.cats.meta_cell[name] + // look up the index of the inst_cell_attr in the meta_cell_attr array + const inst_index = viz_state.cats.meta_cell_attr.indexOf( + viz_state.cats.inst_cell_attr ); - } else { - // default clustering + viz_state.cats.cell_cats = viz_state.cats.cell_names_array.map((name) => { + const attrs = viz_state.cats.meta_cell[name]; + return attrs?.[inst_index] ?? 'N.A.'; + }); + } else { const cluster_arrow_table = await get_arrow_table( `${base_url}/cell_clusters${viz_state.seg.version && viz_state.seg.version !== 'default' ? `_${viz_state.seg.version}` : ''}/cluster.parquet`, options.fetch, diff --git a/js/global_variables/meta_cluster.js b/js/global_variables/meta_cluster.js index 8d1b921c..f7eb8731 100644 --- a/js/global_variables/meta_cluster.js +++ b/js/global_variables/meta_cluster.js @@ -29,18 +29,27 @@ export const update_meta_cluster = (cats, new_meta_cluster) => { export const set_cluster_metadata = async (viz_state) => { if (viz_state.cats.has_meta_cluster) { + // find the index of color in viz_state.cats.meta_cluster_attr + const color_index = viz_state.cats.meta_cluster_attr.indexOf('color'); + // loop through the keys of meta_cluster and assemble a dictionary of colors use a map or something functional for (const cluster_name in viz_state.cats.meta_cluster) { viz_state.cats.color_dict_cluster[cluster_name] = hexToRgb( - viz_state.cats.meta_cluster[cluster_name]['color'] + viz_state.cats.meta_cluster[cluster_name][color_index] || '#000000' ); } - // loop through the keys and assembe cluster_counts + // find the index of count in viz_state.cats.meta_cluster_attr + const count_index = viz_state.cats.meta_cluster_attr.indexOf('count'); + for (const cluster_name in viz_state.cats.meta_cluster) { + + const raw = viz_state.cats.meta_cluster[cluster_name][count_index]; + const value = raw !== undefined ? Number(raw) : 0; + viz_state.cats.cluster_counts.push({ name: cluster_name, - value: viz_state.cats.meta_cluster[cluster_name]['count'], + value, }); } } else { diff --git a/js/read_parquet/objects_from_parquet.js b/js/read_parquet/objects_from_parquet.js new file mode 100644 index 00000000..58abbf98 --- /dev/null +++ b/js/read_parquet/objects_from_parquet.js @@ -0,0 +1,36 @@ +import { arrayBufferToArrowTable } from './arrayBufferToArrowTable'; + +/** + * Converts a Parquet-encoded ArrayBuffer into an object using the specified key field. + * + * @param {ArrayBuffer} bytes - The Parquet bytes. + * @param {string} keyField - The name of the field to use as the key. + * @returns {Promise<{ result: Object, attr: string[] }>} + */ +export const objects_from_parquet = async ( + bytes, + keyField = '__index_level_0__' +) => { + const table = await arrayBufferToArrowTable(bytes.buffer); + const fields = table.schema.fields.map((f) => f.name); + + if (fields.length < 2) return {}; + + if (!fields.includes(keyField)) { + throw new Error( + `Key field "${keyField}" not found in Parquet fields: ${fields.join(', ')}` + ); + } + + const keyCol = table.getChild(keyField).toArray(); + const valueFields = fields.filter((f) => f !== keyField); + const valueCols = valueFields.map((f) => table.getChild(f).toArray()); + + const result = {}; + for (let i = 0; i < table.numRows; i++) { + const key = String(keyCol[i]); + result[key] = valueCols.map((col) => col[i]); + } + + return { result, attr: valueFields }; +}; diff --git a/js/viz/landscape_ist.js b/js/viz/landscape_ist.js index 37317c3b..aaf80530 100644 --- a/js/viz/landscape_ist.js +++ b/js/viz/landscape_ist.js @@ -78,8 +78,12 @@ export const landscape_ist = async ( width = 0, height = 800, meta_cell = {}, + meta_cell_attr = [], meta_cluster = {}, + meta_cluster_attr = [], + // meta_cluster_attr = [], umap = {}, + // umap_attr = [], landscape_state = 'spatial', segmentation = 'default', creds = {}, @@ -212,13 +216,14 @@ export const landscape_ist = async ( viz_state.cats.cluster_counts = []; viz_state.cats.polygon_cell_names = []; - // check if meta_cell is an empty object if (Object.keys(meta_cell).length === 0) { viz_state.cats.has_meta_cell = false; } else { viz_state.cats.has_meta_cell = true; } viz_state.cats.meta_cell = meta_cell; + viz_state.cats.meta_cell_attr = meta_cell_attr; + viz_state.cats.inst_cell_attr = meta_cell_attr[0] || 'N.A.'; if (Object.keys(meta_cluster).length === 0) { viz_state.cats.has_meta_cluster = false; @@ -226,6 +231,8 @@ export const landscape_ist = async ( viz_state.cats.has_meta_cluster = true; } viz_state.cats.meta_cluster = meta_cluster; + viz_state.cats.meta_cluster_attr = meta_cluster_attr; + viz_state.cats.inst_cluster_attr = meta_cluster_attr[0] || 'N.A.'; viz_state.umap = {}; if (Object.keys(umap).length === 0) { diff --git a/js/widget.js b/js/widget.js index 8a2c371c..4776dbb8 100644 --- a/js/widget.js +++ b/js/widget.js @@ -1,6 +1,7 @@ import './widget.css'; import { networkFromParquet } from './read_parquet/network_from_parquet'; +import { objects_from_parquet } from './read_parquet/objects_from_parquet'; import { handleAsyncError, handleValidationWarning, @@ -23,9 +24,26 @@ const render_landscape_ist = async ({ model, el }) => { const dataset_name = model.get('dataset_name'); const width = model.get('width'); const height = model.get('height'); - const meta_cell = model.get('meta_cell'); - const meta_cluster = model.get('meta_cluster'); - const umap = model.get('umap'); + + let meta_cell_data; + let meta_cluster_data; + // let umap_data; + + const metaCellBytes = model.get('meta_cell_parquet'); + if (metaCellBytes && metaCellBytes.byteLength > 0) { + meta_cell_data = await objects_from_parquet(metaCellBytes, 'cell_id'); + } + + const metaClusterBytes = model.get('meta_cluster_parquet'); + if (metaClusterBytes && metaClusterBytes.byteLength > 0) { + meta_cluster_data = await objects_from_parquet(metaClusterBytes, 'leiden'); + } + + const umapBytes = model.get('umap_parquet'); + if (umapBytes && umapBytes.byteLength > 0) { + // umap_data = await objects_from_parquet(umapBytes); + } + const landscape_state = model.get('landscape_state'); const segmentation = model.get('segmentation'); @@ -42,9 +60,12 @@ const render_landscape_ist = async ({ model, el }) => { 0.25, width, height, - meta_cell, - meta_cluster, - umap, + meta_cell_data.result, + meta_cell_data.attr, + // {}, // meta_cluster, + meta_cluster_data.result, + meta_cluster_data.attr, + {}, // umap, landscape_state, segmentation, creds diff --git a/src/celldega/viz/widget.py b/src/celldega/viz/widget.py index 5e4fdc59..5494e27d 100644 --- a/src/celldega/viz/widget.py +++ b/src/celldega/viz/widget.py @@ -7,13 +7,21 @@ import warnings import anywidget +import pandas as pd import traitlets +import colorsys _clustergram_registry = {} # maps names to widget instances _enrich_registry = {} # maps names to widget instances +def _hsv_to_hex(h: float) -> str: + """Convert HSV color to hex string.""" + r, g, b = colorsys.hsv_to_rgb(h, 0.65, 0.9) + return f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}" + + class Landscape(anywidget.AnyWidget): """ A widget for interactive visualization of spatial omics data. This widget @@ -25,8 +33,13 @@ class Landscape(anywidget.AnyWidget): ini_zoom (float): The initial zoom level of the view. token (str): The token traitlet. base_url (str): The base URL for the widget. + AnnData (AnnData, optional): AnnData object to derive metadata from. dataset_name (str, optional): The name of the dataset to visualize. This will show up in the user interface bar. + The AnnData input automatically extracts cell attributes (e.g., ``leiden`` + clusters), the corresponding colors (or derives them when missing), and any + available UMAP coordinates. + Attributes: component (str): The name of the component. technology (str): The technology used. @@ -61,19 +74,100 @@ class Landscape(anywidget.AnyWidget): region = traitlets.Dict({}).tag(sync=True) nbhd = traitlets.Dict({}).tag(sync=True) - meta_cell = traitlets.Dict({}).tag(sync=True) meta_cluster = traitlets.Dict({}).tag(sync=True) - umap = traitlets.Dict({}).tag(sync=True) landscape_state = traitlets.Unicode("spatial").tag(sync=True) update_trigger = traitlets.Dict().tag(sync=True) cell_clusters = traitlets.Dict({}).tag(sync=True) + # make a traitlet for cell_attr a list that will have the AnnData obs columns + # cell_attr = traitlets.List(['leiden']).tag(sync=True) + cell_attr = traitlets.List(trait=traitlets.Unicode(), default_value=["leiden"]).tag(sync=True) + segmentation = traitlets.Unicode("default").tag(sync=True) width = traitlets.Int(0).tag(sync=True) height = traitlets.Int(800).tag(sync=True) + def __init__(self, **kwargs): + adata = kwargs.pop("adata", None) or kwargs.pop("AnnData", None) + pq_meta_cell = kwargs.pop("meta_cell_parquet", None) + pq_meta_cluster = kwargs.pop("meta_cluster_parquet", None) + pq_umap = kwargs.pop("umap_parquet", None) + + meta_cell_df = kwargs.pop("meta_cell", None) + meta_cluster = kwargs.get("meta_cluster") + umap_df = kwargs.pop("umap", None) + meta_cluster_df = None + # cell_attr = kwargs.pop("cell_attr", "leiden") + cell_attr = kwargs.pop("cell_attr", ["leiden"]) + + def _df_to_bytes(df): + import io + + import pyarrow as pa + import pyarrow.parquet as pq + + df.columns = df.columns.map(str) + buf = io.BytesIO() + pq.write_table(pa.Table.from_pandas(df), buf, compression="zstd") + return buf.getvalue() + + if adata is not None: + meta_cell_df = adata.obs[cell_attr].copy() + # meta_cell_df.reset_index(inplace=True) + pq_meta_cell = _df_to_bytes(meta_cell_df) + + if "leiden" in adata.obs.columns: + cluster_counts = adata.obs["leiden"].value_counts().sort_index() + colors = adata.uns.get("leiden_colors") + if colors is None: + n = len(cluster_counts) + colors = [_hsv_to_hex(i / max(n, 1)) for i in range(n)] + meta_cluster_df = pd.DataFrame( + { + "color": list(colors)[: len(cluster_counts)], + "count": cluster_counts.values, + }, + index=cluster_counts.index, + ) + + pq_meta_cluster = _df_to_bytes(meta_cluster_df) + + if "X_umap" in adata.obsm: + umap_df = pd.DataFrame(adata.obsm["X_umap"], index=adata.obs.index).reset_index() + pq_umap = _df_to_bytes(umap_df) + + if isinstance(meta_cell_df, pd.DataFrame): + pq_meta_cell = _df_to_bytes(meta_cell_df.reset_index()) + + if isinstance(meta_cluster, pd.DataFrame): + pq_meta_cluster = _df_to_bytes(meta_cluster.reset_index()) + kwargs.pop("meta_cluster") + meta_cluster_df = meta_cluster + + if isinstance(umap_df, pd.DataFrame): + pq_umap = _df_to_bytes(umap_df.reset_index()) + + parquet_traits = {} + if pq_meta_cell is not None: + parquet_traits["meta_cell_parquet"] = traitlets.Bytes(pq_meta_cell).tag(sync=True) + if pq_meta_cluster is not None: + parquet_traits["meta_cluster_parquet"] = traitlets.Bytes(pq_meta_cluster).tag(sync=True) + if pq_umap is not None: + parquet_traits["umap_parquet"] = traitlets.Bytes(pq_umap).tag(sync=True) + + if parquet_traits: + self.add_traits(**parquet_traits) + + super().__init__(**kwargs) + + # store DataFrames locally without syncing to the frontend + self.meta_cell = meta_cell_df + self.umap = umap_df + if meta_cluster_df is not None: + self.meta_cluster_df = meta_cluster_df + def trigger_update(self, new_value): """ Update the update_trigger traitlet with a new value. diff --git a/tests/unit/test_viz/test_landscape_anndata.py b/tests/unit/test_viz/test_landscape_anndata.py new file mode 100644 index 00000000..5233f437 --- /dev/null +++ b/tests/unit/test_viz/test_landscape_anndata.py @@ -0,0 +1,59 @@ +"""Tests for Landscape widget initialization with AnnData.""" + +import io +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData + +try: + from celldega.viz import Landscape +except Exception as e: # pragma: no cover - skip if deps missing + pytest.skip(f"celldega modules unavailable: {e}", allow_module_level=True) + + +def make_simple_anndata() -> AnnData: + """Create a small AnnData object for testing.""" + np.random.seed(0) + X = np.random.rand(5, 3) + obs = pd.DataFrame({"leiden": pd.Categorical(["0", "1", "0", "1", "0"])}) + obs.index = [f"cell{i}" for i in range(5)] + var = pd.DataFrame(index=[f"gene{i}" for i in range(3)]) + adata = AnnData(X=X, obs=obs, var=var) + adata.obsm["X_umap"] = np.random.rand(5, 2) + adata.uns["leiden_colors"] = ["#ff0000", "#00ff00"] + return adata + + +def test_landscape_initializes_with_anndata() -> None: + """Landscape should accept AnnData and expose parquet traitlets.""" + adata = make_simple_anndata() + widget = Landscape(base_url="https://example.com", AnnData=adata) + + assert hasattr(widget, "meta_cell_parquet") + assert hasattr(widget, "meta_cluster_parquet") + assert hasattr(widget, "umap_parquet") + + meta_cell = pd.read_parquet(io.BytesIO(widget.meta_cell_parquet)) + meta_cluster = pd.read_parquet(io.BytesIO(widget.meta_cluster_parquet)) + umap_df = pd.read_parquet(io.BytesIO(widget.umap_parquet)) + + pd.testing.assert_frame_equal( + meta_cell, + adata.obs[["leiden"]].reset_index(), + ) + + cluster_counts = adata.obs["leiden"].value_counts().sort_index() + expected_cluster = pd.DataFrame( + { + "color": adata.uns["leiden_colors"][: len(cluster_counts)], + "count": cluster_counts.values, + }, + index=cluster_counts.index, + ).reset_index() + pd.testing.assert_frame_equal(meta_cluster, expected_cluster) + + pd.testing.assert_frame_equal( + umap_df, + pd.DataFrame(adata.obsm["X_umap"], index=adata.obs.index).reset_index(), + )