Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openproblems/tasks/label_projection/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .scvi_tools import scanvi_hvg
from .scvi_tools import scarches_scanvi_all_genes
from .scvi_tools import scarches_scanvi_hvg
from .scvi_tools import scarches_scanvi_xgb_all_genes
from .scvi_tools import scarches_scanvi_xgb_hvg
from .seurat import seurat
from .xgboost import xgboost_log_cpm
from .xgboost import xgboost_scran
98 changes: 97 additions & 1 deletion openproblems/tasks/label_projection/methods/scvi_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ....tools.decorators import method
from ....tools.utils import check_version
from typing import Optional

import functools

Expand Down Expand Up @@ -86,7 +87,14 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
return preds


def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
def _scanvi_scarches(
adata,
test=False,
n_hidden=None,
n_latent=None,
n_layers=None,
prediction_method="scanvi",
):
import scvi

if test:
Expand Down Expand Up @@ -138,6 +146,15 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
train_kwargs["limit_val_batches"] = 10
query_model.train(plan_kwargs=dict(weight_decay=0.0), **train_kwargs)

if prediction_method == "scanvi":
preds = _pred_scanvi(adata, query_model)
elif prediction_method == "xgboost":
preds = _pred_xgb(adata, adata_train, adata_test, query_model, test=test)

return preds


def _pred_scanvi(adata, query_model):
# this is temporary and won't be used
adata.obs["scanvi_labels"] = "Unknown"
preds = query_model.predict(adata)
Expand All @@ -146,6 +163,63 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
return preds


# note: could extend test option
def _pred_xgb(
adata,
adata_train,
adata_test,
query_model,
label_col="labels",
test=False,
num_round: Optional[int] = None,
):
import numpy as np
import xgboost as xgb

df = _classif_df(adata_train, query_model, label_col)

df["labels_int"] = df["labels"].cat.codes
categories = df["labels"].cat.categories

# X_train = df.drop(columns="labels")
X_train = df.drop(columns=["labels", "labels_int"])
# y_train = df["labels"].astype("category")
y_train = df["labels_int"].astype(int)

X_test = query_model.get_latent_representation(adata_test)

if test:
num_round = num_round or 2
else:
num_round = num_round or 5

xgbc = xgb.XGBClassifier(tree_method="hist", objective="multi:softprob")

xgbc.fit(X_train, y_train)

# adata_test.obs["preds_test"] = xgbc.predict(X_test)
adata_test.obs["preds_test"] = categories[xgbc.predict(X_test)]

preds = [
adata_test.obs["preds_test"][idx] if idx in adata_test.obs_names else np.nan
for idx in adata.obs_names
]

return preds


def _classif_df(adata, trained_model, label_col):
import pandas as pd

emb_data = trained_model.get_latent_representation(adata)

df = pd.DataFrame(data=emb_data, index=adata.obs_names)

df["labels"] = adata.obs[label_col]

return df


@_scanvi_method(method_name="scANVI (All genes)")
def scanvi_all_genes(adata, test=False):
adata.obs["labels_pred"] = _scanvi(adata, test=test)
Expand Down Expand Up @@ -176,3 +250,25 @@ def scarches_scanvi_hvg(adata, test=False):
adata.obs["labels_pred"] = _scanvi_scarches(bdata, test=test)
adata.uns["method_code_version"] = check_version("scvi-tools")
return adata


@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (All genes)")
def scarches_scanvi_xgb_all_genes(adata, test=False):
adata.obs["labels_pred"] = _scanvi_scarches(
adata, test=test, prediction_method="xgboost"
)

adata.uns["method_code_version"] = check_version("scvi-tools")
return adata


@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (Seurat v3 2000 HVG)")
def scarches_scanvi_xgb_hvg(adata, test=False):
hvg_df = _hvg(adata, test)
bdata = adata[:, hvg_df.highly_variable].copy()
adata.obs["labels_pred"] = _scanvi_scarches(
bdata, test=test, prediction_method="xgboost"
)

adata.uns["method_code_version"] = check_version("scvi-tools")
return adata