diff --git a/src/evaluation/sc_evaluate_opt.py b/src/evaluation/sc_evaluate_opt.py new file mode 100644 index 0000000..4a7e86e --- /dev/null +++ b/src/evaluation/sc_evaluate_opt.py @@ -0,0 +1,232 @@ +import os +import click +import yaml +import sys +import fnmatch +import numpy as np +import pandas as pd +import scanpy as sc +from scipy.sparse import issparse # Import for sparse checks + +src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(src_dir) + +from evaluation.utils.sc_metrics_opt import ( + filter_low_quality_cells_and_genes, + Statistics, VisualizeClassify +) + +def check_dirs(path): + if not os.path.exists(path): + os.makedirs(path) + +class SingleCellEvaluator: + def __init__(self, config): + self.config = config + self.home_dir = config["dir_list"]["home"] + self.dataset_config = config["dataset_config"] + self.dataset_name = self.dataset_config["name"] + self.cell_type_col = self.dataset_config["cell_type_col_name"] + self.cell_label_col = self.dataset_config["cell_label_col_name"] + + self.save_dir = os.path.join(self.home_dir, "data_splits") + self.random_seed = config["evaluator_config"]["random_seed"] + + ## experiment name + self.experiment_name = self.config['generator_config']['experiment_name'] + self.generator_name = self.config['generator_config']['name'] + self.res_figures_dir = os.path.join(self.home_dir, + config["dir_list"]["figures"], + self.dataset_name, + self.generator_name, + self.experiment_name + ) + self.res_files_dir = os.path.join(self.home_dir, + config["dir_list"]["res_files"], + self.dataset_name, + self.generator_name, + self.experiment_name) + check_dirs(self.res_figures_dir) + check_dirs(self.res_files_dir) + + self.synthetic_data_path = os.path.join(self.save_dir, + self.dataset_name, + "synthetic", + self.generator_name, + self.experiment_name) + self.celltypist_model_path = os.path.join(self.home_dir, + self.dataset_config["celltypist_model"]) + self.results = {} + + @staticmethod + def save_split_results(results, output_file): + df = pd.DataFrame([results]) + df.to_csv(output_file, index=False) + + def load_test_anndata(self): + try: + test_data_pth = os.path.join(self.home_dir, self.dataset_config["test_count_file"]) + test_data = sc.read_h5ad(test_data_pth) + + test_data.obs[self.cell_label_col] = ( + test_data.obs[self.cell_label_col] + .astype(str) + .str.replace(" ", "_", regex=True) + ) + + # Instead of converting to dense, check for NaN and Inf directly + X = test_data.X + if issparse(X): + nan_count = np.isnan(X.data).sum() + inf_count = np.isinf(X.data).sum() + else: + nan_count = np.isnan(X).sum() + inf_count = np.isinf(X).sum() + + if nan_count > 0 or inf_count > 0: + raise ValueError(f"Test data contains {nan_count} NaN values and {inf_count} Inf values.") + + print(test_data) + return test_data + except Exception as e: + raise Exception(f"Failed to load test anndata: {e}") + + + def load_synthetic_anndata(self): + try: + syn_data_pth = os.path.join(self.synthetic_data_path, "onek1k_annotated_synthetic.h5ad") + syn_data = sc.read_h5ad(syn_data_pth) + + # Check for NaN and Inf values without converting to dense + X = syn_data.X + if issparse(X): + nan_count = np.isnan(X.data).sum() + inf_count = np.isinf(X.data).sum() + else: + nan_count = np.isnan(X).sum() + inf_count = np.isinf(X).sum() + + if nan_count > 0 or inf_count > 0: + raise ValueError(f"Synthetic data contains {nan_count} NaN values and {inf_count} Inf values.") + + print(syn_data) + return syn_data + except Exception as e: + raise 
Exception(f"Failed to load synthetic anndata: {e}") + + def initialize_datasets(self): + test_anndata = self.load_test_anndata() + synthetic_anndata = self.load_synthetic_anndata() + + print(f"Initial gene count - Real: {test_anndata.n_vars}, Synthetic: {synthetic_anndata.n_vars}") + real_data = filter_low_quality_cells_and_genes(test_anndata) + synthetic_data = filter_low_quality_cells_and_genes(synthetic_anndata) + print(f"After filtering - Real: {real_data.n_vars}, Synthetic: {synthetic_data.n_vars}") + + # make sure both datasets have the same genes after filter + common_genes = real_data.var_names.intersection(synthetic_data.var_names) + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + + print(f"After gene alignment - Real: {real_data.n_vars}, Synthetic: {synthetic_data.n_vars}") + + return real_data, synthetic_data + + def get_statistical_evals(self): + real_data, synthetic_data = self.initialize_datasets() + stats = Statistics(self.random_seed) + print("Computing SCC...") + scc = stats.compute_scc(real_data, synthetic_data) + print("Computing MMD...") + mmd = stats.compute_mmd_optimized(real_data, synthetic_data) + print("Computing LISI...") + lisi = stats.compute_lisi(real_data, synthetic_data) + print("Computing ARI...") + ari_real_syn, ari_gt_comb = stats.compute_ari(real_data, synthetic_data, self.cell_type_col) + print("Done.") + + return { + 'scc': scc, + 'mmd': mmd, + 'lisi': lisi, + 'ari_real_vs_syn': ari_real_syn, + 'ari_gt_vs_comb': ari_gt_comb + } + + def get_umap_evals(self, n_hvgs: int): + real_data, synthetic_data = self.initialize_datasets() + visual = VisualizeClassify(self.res_figures_dir, self.random_seed) + visual.plot_umap(real_data, synthetic_data, n_hvgs) + + def get_classification_evals(self): + real_data, synthetic_data = self.initialize_datasets() + classfier = VisualizeClassify(self.res_figures_dir, self.random_seed) + ari_score, jaccard = classfier.celltypist_classification(real_data, + synthetic_data, + self.celltypist_model_path) + roc_score, _ = classfier.random_forest_eval(real_data, synthetic_data) + + return { + "celltypist_ari": ari_score, + "celltypist_jaccard": jaccard, + "randomforest_roc": roc_score, + } + + @staticmethod + def save_results_to_csv(results, output_file): + df = pd.DataFrame([results]) + df.to_csv(output_file, index=False) + +@click.group() +def cli(): + pass + +@click.command() +def run_statistical_eval(): + with open("config.yaml", 'r') as file: + config = yaml.safe_load(file) + + evaluator = SingleCellEvaluator(config=config) + results = evaluator.get_statistical_evals() + + output_file = os.path.join(evaluator.res_files_dir, f"statistics_evals.csv") + evaluator.save_results_to_csv(results, output_file) + click.echo(f"Evaluation for classification is completed. 
Results saved to {output_file}") + +@click.command() +@click.argument("n_hvgs", type=int, default=2000) +def run_umap_eval(n_hvgs): + with open("config.yaml", 'r') as file: + config = yaml.safe_load(file) + + evaluator = SingleCellEvaluator(config=config) + evaluator.get_umap_evals(n_hvgs) + +@click.command() +@click.argument("cell_label", type=str, default="CD4 ET") +def run_qq_eval(cell_label: str): + with open("config.yaml", 'r') as file: + config = yaml.safe_load(file) + + evaluator = SingleCellEvaluator(config=config) + evaluator.save_qq_evals(cell_label=cell_label) + +@click.command() +def run_classification_eval(): + with open("config.yaml", 'r') as file: + config = yaml.safe_load(file) + + evaluator = SingleCellEvaluator(config=config) + results = evaluator.get_classification_evals() + + output_file = os.path.join(evaluator.res_files_dir, f"classification_evals.csv") + evaluator.save_results_to_csv(results, output_file) + click.echo(f"Evaluation for classification is completed. Results saved to {output_file}") + +cli.add_command(run_classification_eval) +cli.add_command(run_umap_eval) +cli.add_command(run_statistical_eval) +cli.add_command(run_qq_eval) + +if __name__ == '__main__': + cli() diff --git a/src/evaluation/utils/sc_metrics_opt.py b/src/evaluation/utils/sc_metrics_opt.py new file mode 100644 index 0000000..081a7e6 --- /dev/null +++ b/src/evaluation/utils/sc_metrics_opt.py @@ -0,0 +1,375 @@ +import umap +import os +import numpy as np +import scanpy as sc +import scipy.stats as stats +import scipy.sparse +import matplotlib.pyplot as plt +import seaborn as sns +from scipy.spatial.distance import cdist +from scipy.stats import spearmanr +from sklearn.model_selection import train_test_split +from sklearn.decomposition import PCA, TruncatedSVD +from sklearn.ensemble import RandomForestClassifier +from sklearn.preprocessing import LabelBinarizer +from sklearn.metrics.pairwise import rbf_kernel +from sklearn.metrics import adjusted_rand_score, roc_auc_score, jaccard_score +from scib.metrics import ilisi_graph +import celltypist +from scipy.sparse import issparse + +_DEF_N_HVGS = 5000 + +def filter_low_quality_cells_and_genes(adata, min_counts=10, min_cells=3): + """ + Filters cells and genes based on minimum counts. + Uses Scanpy’s built-in filtering functions (which are sparse-aware). + """ + adata = adata.copy() + sc.pp.filter_cells(adata, min_counts=min_counts) + sc.pp.filter_genes(adata, min_cells=min_cells) + return adata + +def get_dense_column(adata, i): + """ + Returns the i-th column of adata.X as a dense vector. + This avoids converting the entire matrix to dense at once. + """ + X = adata.X + if issparse(X): + return X[:, i].toarray().ravel() + else: + return np.array(X[:, i]).ravel() + +def check_for_inf_nan(adata, label): + """ + Checks for NaN/Inf values in adata.X without converting the whole matrix. + """ + X = adata.X + if issparse(X): + data = X.data + else: + data = np.array(X) + print(f"==> Checking {label} dataset:") + print(f" NaNs? {np.isnan(data).any()}") + print(f" Infs? {np.isinf(data).any()}") + print(f" Min: {data.min()}, Max: {data.max()}\n") + +def check_missing_genes(real_data, synthetic_data): + """ + Compares gene names between real and synthetic datasets. 
+ """ + real_genes = set(real_data.var_names) + synthetic_genes = set(synthetic_data.var_names) + missing_in_real = synthetic_genes - real_genes + missing_in_synthetic = real_genes - synthetic_genes + + print("==> Checking gene differences:") + print(f" Genes in synthetic but not in real: {len(missing_in_real)}") + print(f" Genes in real but not in synthetic: {len(missing_in_synthetic)}") + print(f" Example missing in real: {list(missing_in_real)[:10]}") + print(f" Example missing in synthetic: {list(missing_in_synthetic)[:10]}") + print(f" real_data.var_names dtype: {real_data.var_names.dtype}") + print(f" synthetic_data.var_names dtype: {synthetic_data.var_names.dtype}\n") + +class Statistics: + def __init__(self, random_seed=42): + self.random_seed = random_seed + np.random.seed(self.random_seed) + + def compute_scc(self, real_data, synthetic_data, n_hvgs=_DEF_N_HVGS): + """ + Computes the mean Spearman correlation across highly variable genes (HVGs) + between the real and synthetic datasets. Instead of converting the whole + expression matrix to dense, each gene column is converted on the fly. + """ + np.random.seed(self.random_seed) + print("=== Starting compute_scc ===") + check_missing_genes(real_data, synthetic_data) + + # Align genes using the gene names from synthetic_data + common_genes = synthetic_data.var_names + print("Aligning real and synthetic data on common genes...") + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + + check_for_inf_nan(real_data, "Real") + check_for_inf_nan(synthetic_data, "Synthetic") + + # Normalize and log-transform both datasets + print("Normalizing and log-transforming real data...") + sc.pp.normalize_total(real_data, target_sum=1e4) + sc.pp.log1p(real_data) + print("Normalizing and log-transforming synthetic data...") + sc.pp.normalize_total(synthetic_data, target_sum=1e4) + sc.pp.log1p(synthetic_data) + + check_for_inf_nan(real_data, "Real") + check_for_inf_nan(synthetic_data, "Synthetic") + + # Identify HVGs using the combined dataset + print("Concatenating datasets to identify highly variable genes...") + combined_adata = real_data.concatenate(synthetic_data) + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + print("Identifying highly variable genes...") + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + + # Subset to HVGs + hvgs = combined_adata.var["highly_variable"] + print(f"Subsetting to HVGs: {hvgs.sum()} genes selected.") + real_hvg = real_data[:, hvgs] + synth_hvg = synthetic_data[:, hvgs] + + # Compute Spearman correlation gene-by-gene + print("Computing Spearman correlation gene-by-gene...") + scc_values = [] + total_genes = real_hvg.n_vars + progress_interval = max(1, total_genes // 100) + for i in range(total_genes): + real_vec = get_dense_column(real_hvg, i) + synth_vec = get_dense_column(synth_hvg, i) + corr, _ = stats.spearmanr(real_vec, synth_vec, nan_policy='omit') + scc_values.append(corr) + if (i + 1) % progress_interval == 0 or (i + 1) == total_genes: + percent = ((i + 1) / total_genes) * 100 + print(f" Processed {i + 1} / {total_genes} genes ({percent:.0f}%)") + scc_values = np.array(scc_values) + mean_corr = np.nanmean(scc_values) if not np.all(np.isnan(scc_values)) else np.nan + print(f"Finished compute_scc: Mean Spearman correlation = {mean_corr:.4f}\n") + return mean_corr + + def compute_mmd_optimized(self, real_data, synthetic_data, sample_size=20000, + n_pca=50, gamma=1.0, n_hvgs=_DEF_N_HVGS): + 
np.random.seed(self.random_seed) + # Align genes using synthetic_data's ordering + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + + combined_adata = real_data.concatenate(synthetic_data) + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + + hvgs = combined_adata.var["highly_variable"] + real_hvg = real_data[:, hvgs] + synth_hvg = synthetic_data[:, hvgs] + + n_real = real_hvg.n_obs + n_synth = synth_hvg.n_obs + + real_idx = np.random.choice(n_real, min(sample_size, n_real), replace=False) + synth_idx = np.random.choice(n_synth, min(sample_size, n_synth), replace=False) + + # Process sparse or dense data accordingly + if issparse(real_hvg.X): + real_sample = real_hvg.X[real_idx] + synth_sample = synth_hvg.X[synth_idx] + from scipy.sparse import vstack + combined_sample = vstack([real_sample, synth_sample]) + pca_model = TruncatedSVD(n_components=n_pca, random_state=self.random_seed) + combined_pca = pca_model.fit_transform(combined_sample) + else: + real_sample = real_hvg.X[real_idx] + synth_sample = synth_hvg.X[synth_idx] + combined_sample = np.vstack([real_sample, synth_sample]) + pca_model = PCA(n_components=n_pca, random_state=self.random_seed) + combined_pca = pca_model.fit_transform(combined_sample) + + # Use shape[0] instead of len() to get the number of real samples + num_real = real_sample.shape[0] + real_pca = combined_pca[:num_real] + synth_pca = combined_pca[num_real:] + + K_xx = rbf_kernel(real_pca, real_pca, gamma=gamma).mean() + K_yy = rbf_kernel(synth_pca, synth_pca, gamma=gamma).mean() + K_xy = rbf_kernel(real_pca, synth_pca, gamma=gamma).mean() + + return K_xx + K_yy - 2 * K_xy + + def compute_lisi(self, real_data, synthetic_data, n_hvgs=_DEF_N_HVGS): + np.random.seed(self.random_seed) + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + combined_adata = real_data.concatenate( + synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] + ) + # Create a numeric batch label (0 = real, 1 = synthetic) + combined_adata.obs["batch"] = (combined_adata.obs["source"] == "synthetic").astype(int) + + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + combined_adata = combined_adata[:, combined_adata.var['highly_variable']] + + # Dynamically determine the number of PCA components + n_obs, n_vars = combined_adata.shape + n_comps = min(n_hvgs, n_obs - 1, n_vars - 1) if n_obs > 1 and n_vars > 1 else 1 + print(f"Performing PCA with n_comps={n_comps} (n_obs={n_obs}, n_vars={n_vars})") + + sc.pp.pca(combined_adata, n_comps=n_comps, random_state=self.random_seed) + sc.pp.neighbors(combined_adata, n_neighbors=10, method='umap') + + return ilisi_graph(combined_adata, batch_key="batch", type_="knn") + + + def compute_ari(self, real_data, synthetic_data, cell_type_col, n_hvgs=_DEF_N_HVGS): + """ + Computes the Adjusted Rand Index (ARI) to measure clustering consistency + between real and synthetic data. Clusters are obtained via Scanpy's Louvain. 
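+
+        ARI is 1.0 for identical partitions and close to 0.0 for random
+        agreement. Note that the real-vs-synthetic ARI pairs cluster labels
+        positionally (cell i of the real data against cell i of the synthetic
+        data), so it is only meaningful when both datasets retain the same
+        cells in the same order after filtering; otherwise the two label
+        vectors differ in length and scikit-learn raises a ValueError.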
+ """ + np.random.seed(self.random_seed) + print("=== Starting compute_ari ===") + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + combined_adata = real_data.concatenate( + synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] + ) + + print("Normalizing, log-transforming, and selecting HVGs for ARI computation...") + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + combined_adata = combined_adata[:, combined_adata.var['highly_variable']] + + n_obs, n_vars = combined_adata.shape + n_comps = min(n_hvgs, n_obs - 1, n_vars - 1) if n_obs > 1 and n_vars > 1 else 1 + print(f"Performing PCA with n_comps={n_comps} (n_obs={n_obs}, n_vars={n_vars}) and computing neighbors") + sc.pp.pca(combined_adata, n_comps=n_comps, random_state=self.random_seed) + sc.pp.neighbors(combined_adata, n_neighbors=10, method='umap') + print("Clustering with Louvain...") + sc.tl.louvain(combined_adata) + + # Convert Louvain clusters to numerical labels + combined_adata.obs["louvain"] = combined_adata.obs["louvain"].astype("category").cat.codes + real_clusters = combined_adata.obs.loc[combined_adata.obs["source"] == "real", "louvain"].values + synthetic_clusters = combined_adata.obs.loc[combined_adata.obs["source"] == "synthetic", "louvain"].values + ari_real_vs_syn = adjusted_rand_score(real_clusters, synthetic_clusters) + ari_gt_vs_comb = adjusted_rand_score(combined_adata.obs[cell_type_col], combined_adata.obs["louvain"]) + + print(f"Finished compute_ari: ARI (real vs synthetic) = {ari_real_vs_syn:.4f}, ARI (ground truth vs clusters) = {ari_gt_vs_comb:.4f}\n") + return ari_real_vs_syn, ari_gt_vs_comb + +class VisualizeClassify: + def __init__(self, sc_figures_dir, random_seed=42): + self.random_seed = random_seed + self.sc_figures_dir = sc_figures_dir + np.random.seed(self.random_seed) + + def plot_umap(self, real_data, synthetic_data, n_hvgs=_DEF_N_HVGS): + """ + Creates and saves a UMAP plot of the combined real and synthetic data. 
+ """ + print("=== Starting UMAP plotting ===") + sc.settings.figdir = self.sc_figures_dir + np.random.seed(self.random_seed) + check_for_inf_nan(real_data, "Real") + check_for_inf_nan(synthetic_data, "Synthetic") + combined_adata = real_data.concatenate( + synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] + ) + + print("Normalizing, log-transforming, and selecting HVGs for UMAP...") + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + combined_adata = combined_adata[:, combined_adata.var['highly_variable']] + + n_obs, n_vars = combined_adata.shape + n_comps = min(n_hvgs, n_obs - 1, n_vars - 1) if n_obs > 1 and n_vars > 1 else 1 + print(f"Performing PCA with n_comps={n_comps} (n_obs={n_obs}, n_vars={n_vars}), computing neighbors, and generating UMAP...") + sc.pp.pca(combined_adata, n_comps=n_comps, random_state=self.random_seed) + sc.pp.neighbors(combined_adata) + sc.tl.umap(combined_adata, random_state=self.random_seed) + + sc.pl.umap(combined_adata, + color=["source"], + title="UMAP of Real vs Synthetic Data", + save=f"syn_test_PCA_HVG={n_hvgs}.png") + print("UMAP plot saved.\n") + + def celltypist_classification(self, real_data_test, synthetic_data, celltypist_model, n_hvgs=_DEF_N_HVGS): + """ + Uses a CellTypist model to annotate cells from both datasets and then compares + the predicted labels via ARI and Jaccard scores. + """ + np.random.seed(self.random_seed) + print("=== Starting celltypist classification ===") + combined_adata = real_data_test.concatenate(synthetic_data) + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + + # Normalize and log-transform each dataset individually + sc.pp.normalize_total(real_data_test, target_sum=1e4) + sc.pp.log1p(real_data_test) + sc.pp.normalize_total(synthetic_data, target_sum=1e4) + sc.pp.log1p(synthetic_data) + + # Subset both datasets to HVGs + real_data_test = real_data_test[:, combined_adata.var['highly_variable']] + synthetic_data = synthetic_data[:, combined_adata.var['highly_variable']] + + print("Loading CellTypist model and annotating cells...") + model = celltypist.models.Model.load(celltypist_model) + real_predictions = celltypist.annotate(real_data_test, model=model) + synthetic_predictions = celltypist.annotate(synthetic_data, model=model) + + real_labels = real_predictions.predicted_labels.values.ravel() + synthetic_labels = synthetic_predictions.predicted_labels.values.ravel() + + ari_score = adjusted_rand_score(real_labels, synthetic_labels) + + lb = LabelBinarizer() + real_onehot = lb.fit_transform(real_labels) + synthetic_onehot = lb.transform(synthetic_labels) + + jaccard_scores = [ + jaccard_score(real_onehot[:, i], synthetic_onehot[:, i]) + for i in range(real_onehot.shape[1]) + ] + jaccard = np.mean(jaccard_scores) + print(f"Finished celltypist classification: ARI = {ari_score:.4f}, Jaccard = {jaccard:.4f}\n") + return ari_score, jaccard + + def random_forest_eval(self, real_data, synthetic_data, n_hvgs=_DEF_N_HVGS): + """ + Evaluates how well a Random Forest can separate real vs. synthetic cells. + After batch correction, the expression matrix is converted to dense only once. 
+ """ + np.random.seed(self.random_seed) + print("=== Starting Random Forest evaluation ===") + real_data.obs["source"] = "real" + synthetic_data.obs["source"] = "synthetic" + + combined_adata = real_data.concatenate( + synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] + ) + + print("Normalizing, log-transforming, and selecting HVGs for Random Forest...") + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + combined_adata = combined_adata[:, combined_adata.var['highly_variable']] + + print("Applying Combat batch correction...") + sc.pp.combat(combined_adata, key="source") + + print("Converting expression matrix to dense and splitting data...") + X = combined_adata.X.A if hasattr(combined_adata.X, "A") else combined_adata.X + y = (combined_adata.obs["source"] == "synthetic").astype(int).values + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=self.random_seed) + + print("Training Random Forest classifier...") + rf = RandomForestClassifier(n_estimators=1000, max_depth=5, random_state=self.random_seed) + rf.fit(X_train, y_train) + + pred_probs = rf.predict_proba(X_test)[:, 1] + auc = roc_auc_score(y_test, pred_probs) + print(f"Finished Random Forest evaluation: AUC = {auc:.4f}\n") + + return auc, pred_probs diff --git a/src/generators/blue_team.py b/src/generators/blue_team.py index 10ff16e..d9c2944 100644 --- a/src/generators/blue_team.py +++ b/src/generators/blue_team.py @@ -17,10 +17,11 @@ 'dpcvae': ('models.cvae', 'CVAEDataGenerationPipeline'), "ctgan": ('models.sdv_ctgan', 'CTGANDataGenerationPipeline'), "dpctgan": ('models.dpctgan', 'DPCTGANDataGenerationPipeline'), - "sc_dist": ('models.sc_dist', 'ScDistributionDataGenerator') + "sc_dist": ('models.sc_dist', 'ScDistributionDataGenerator'), + "sc_dist_sparse": ('models.sc_dist_opt', 'ScDistributionDataGenerator') } -## dynamic import to avoid package versioning errors +## dynamic import to avoid package versioning errors def get_generator_class(generator_name): if generator_name in generator_classes: module_name, class_name = generator_classes[generator_name] @@ -42,17 +43,17 @@ def cli(): def generate_split_indices(): configfile = "config.yaml" config = yaml.safe_load(open(configfile)) - rdataloader = RealDataLoader(config) + rdataloader = RealDataLoader(config) rdataloader.save_split_indices() -## the real data will be split into 5 train/test pairs +## the real data will be split into 5 train/test pairs ## based on the above generated {dataset_name}_split.yaml ## the data will be saved under data_splits/{dataset_name}/real/ @click.command() def generate_data_splits(): configfile = "config.yaml" config = yaml.safe_load(open(configfile)) - rdataloader = RealDataLoader(config) + rdataloader = RealDataLoader(config) # Save dataset rdataloader.save_split_data() @@ -143,4 +144,4 @@ def run_singlecell_generator(experiment_name: str = None): # else: # print("CUDA is NOT available.") -#check_cuda_availability() \ No newline at end of file +#check_cuda_availability() diff --git a/src/generators/models/sc_dist_opt.py b/src/generators/models/sc_dist_opt.py new file mode 100644 index 0000000..7210a96 --- /dev/null +++ b/src/generators/models/sc_dist_opt.py @@ -0,0 +1,169 @@ +import os +import sys +import pandas as pd +import numpy as np +import scipy.stats as st +import scipy.sparse as sp +import scanpy as sc +import anndata as ad +from typing import Dict, Any + 
+src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(src_dir) + +from generators.models.sc_base import BaseSingleCellDataGenerator + +class ScDistributionDataGenerator(BaseSingleCellDataGenerator): + def __init__(self, config: Dict[str, Any]): + super().__init__(config) + self.noise_level = self.generator_config["noise_level"] + self.random_seed = self.generator_config["random_seed"] + self.distribution = self.generator_config["distribution"] # Either 'NB' or 'Poisson' + self.cell_type_col_name = self.dataset_config["cell_type_col_name"] + self.cell_label_col_name = self.dataset_config["cell_label_col_name"] + self.batch_size = self.generator_config.get("batch_size", None) # Optional batch size + + # Parameters for the data generation + self.gene_means = None + self.num_samples = None + self.X_train_features = None + self.cell_type_params = {} + self.max_real_value = None + + self.initialize_random_seeds() + + def initialize_random_seeds(self): + np.random.seed(self.random_seed) + + def train(self): + """Compute gene expression parameters for each cell type from training data.""" + X_train_adata = self.load_train_anndata() + counts = X_train_adata.X + cell_types = X_train_adata.obs[self.cell_type_col_name].values + cell_labels = X_train_adata.obs[self.cell_label_col_name].values + + self.cell_type_to_label = dict(set(zip(cell_types, cell_labels))) + print("Cell Type to Label Mapping:", self.cell_type_to_label) + + # Determine max real expression value without converting sparse data to dense + if sp.issparse(counts): + self.max_real_value = counts.data.max() if counts.data.size > 0 else 0 + else: + self.max_real_value = counts.max() + print(f"Max real expression value from training: {self.max_real_value}") + + unique_cell_types = np.unique(cell_types) + for cell_type in unique_cell_types: + print(f"Training on Cell Type: {cell_type}") + cell_type_mask = cell_types == cell_type + cell_type_counts = counts[cell_type_mask, :] + + if sp.issparse(cell_type_counts): + # Compute means and variances on sparse matrices: + means = np.array(cell_type_counts.mean(axis=0)).ravel() + # For variance: Var(X)=E[X^2] - (E[X])^2 + sq_means = np.array(cell_type_counts.power(2).mean(axis=0)).ravel() + variances = sq_means - means**2 + else: + means = cell_type_counts.mean(axis=0) + variances = cell_type_counts.var(axis=0) + + means = np.clip(means, 1e-6, None) # Avoid zero means + + if self.distribution == 'NB': + # Ensure variance is at least the mean + variances = np.maximum(variances, means) + dispersions = (variances - means) / (means ** 2) + dispersions = np.clip(dispersions, 1e-3, 10) # Avoid extreme values + + print(f"Dispersion values for {cell_type}: min={dispersions.min()}, max={dispersions.max()}") + + if np.any(np.isnan(dispersions)): + raise ValueError(f"NaN detected in dispersions for {cell_type}!") + + self.cell_type_params[str(cell_type)] = { + 'means': means.astype(np.float32), + 'dispersions': dispersions.astype(np.float32) + } + + elif self.distribution == 'Poisson': + self.cell_type_params[str(cell_type)] = { + 'means': means.astype(np.float32) + } + + print("Training completed successfully!") + + def generate(self): + if self.max_real_value is None: + raise ValueError("Training must be completed before generating data!") + + X_test_adata = self.load_test_anndata() + counts_shape = X_test_adata.X.shape + print("Original counts shape:", counts_shape) + + cell_types = X_test_adata.obs[self.cell_type_col_name].values + synthetic_counts = 
sp.lil_matrix(counts_shape, dtype=np.int64)
+        # Cell-type labels aligned to row indices, so each label matches the
+        # row it is written to (rows of skipped cell types stay None).
+        synthetic_cell_types = np.empty(counts_shape[0], dtype=object)
+
+        unique_cell_types = np.unique(cell_types)
+        for cell_type in unique_cell_types:
+            print(f"Generating for Cell Type: {cell_type}")
+
+            if str(cell_type) not in self.cell_type_params:
+                print(f"Cell type {cell_type} not found in training data! Skipping...")
+                continue
+
+            cell_type_mask = cell_types == cell_type
+            cell_indices = np.where(cell_type_mask)[0]
+            num_cells = len(cell_indices)
+
+            means = self.cell_type_params[str(cell_type)]['means'].astype(np.float64)
+            means = np.clip(means, 1e-6, None)  # Avoid zeros
+
+            if self.distribution == 'NB':
+                dispersions = self.cell_type_params[str(cell_type)]['dispersions'].astype(np.float64)
+                dispersions = np.clip(dispersions, 1e-3, 10)  # Prevent extreme values
+
+                # Negative Binomial parameters: scipy's nbinom has
+                # mean = n * (1 - p) / p, so n = 1 / dispersion and
+                # p = n / (n + mean).
+                n_param = np.clip(1 / (dispersions + 1e-6), 1e-2, 10)
+                p_param = np.clip(n_param / (n_param + means), 0.01, 0.99)
+
+                print(f"n_param range for {cell_type}: min={n_param.min()}, max={n_param.max()}")
+                print(f"p_param range for {cell_type}: min={p_param.min()}, max={p_param.max()}")
+
+                expected_variance = means + (means ** 2) / n_param
+                print(f"Expected variance for {cell_type}: min={expected_variance.min()}, max={expected_variance.max()}")
+
+            # Use batch processing if batch_size is specified, otherwise process all cells at once
+            batch_size = self.batch_size if self.batch_size is not None else num_cells
+
+            for start in range(0, num_cells, batch_size):
+                end = min(start + batch_size, num_cells)
+                current_batch_size = end - start
+
+                if self.distribution == 'NB':
+                    batch_generated = st.nbinom.rvs(
+                        n=n_param, p=p_param, size=(current_batch_size, means.shape[0])
+                    ).astype(np.int64)
+                elif self.distribution == 'Poisson':
+                    batch_generated = st.poisson.rvs(
+                        means, size=(current_batch_size, means.shape[0])
+                    ).astype(np.int64)
+                else:
+                    raise ValueError(f"Unknown distribution: {self.distribution}")
+
+                # Limit extreme values to prevent memory explosion
+                upper_clip = np.percentile(batch_generated, 99.5)
+                batch_generated = np.clip(batch_generated, 0, min(upper_clip, self.max_real_value * 2))
+
+                indices = cell_indices[start:end]
+                synthetic_counts[indices, :] = batch_generated
+                synthetic_cell_types[indices] = cell_type
+
+        synthetic_counts_csr = synthetic_counts.tocsr().astype(np.int64)
+        synthetic_adata = ad.AnnData(X=synthetic_counts_csr)
+        synthetic_adata.obs[self.cell_type_col_name] = synthetic_cell_types
+        synthetic_adata.var_names = X_test_adata.var_names
+
+        return synthetic_adata
+
+    def load_from_checkpoint(self):
+        pass
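
Note: the run_qq_eval command in sc_evaluate_opt.py calls SingleCellEvaluator.save_qq_evals, which is not defined anywhere in this patch. A minimal sketch of what such a method could look like is given below. It is an assumption, not part of the patch: the method body, the n_quantiles parameter, the use of self.cell_type_col for subsetting, and the output filename are all hypothetical.

    import os
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.sparse import issparse

    def save_qq_evals(self, cell_label: str, n_quantiles: int = 100):
        """Hypothetical sketch: Q-Q plot of real vs. synthetic expression for one cell label."""
        real_data, synthetic_data = self.initialize_datasets()

        # Subset both datasets to the requested cell label (assumes the label
        # lives in the cell-type column; adjust if cell_label_col is intended).
        real = real_data[real_data.obs[self.cell_type_col] == cell_label]
        synth = synthetic_data[synthetic_data.obs[self.cell_type_col] == cell_label]

        def flat(X):
            # Densify only the subset for this cell label, then flatten.
            return np.asarray(X.todense()).ravel() if issparse(X) else np.asarray(X).ravel()

        # Compare matched quantiles of the pooled expression values.
        qs = np.linspace(0.0, 1.0, n_quantiles)
        real_q = np.quantile(flat(real.X), qs)
        synth_q = np.quantile(flat(synth.X), qs)

        fig, ax = plt.subplots(figsize=(5, 5))
        ax.scatter(real_q, synth_q, s=10)
        lim = max(real_q.max(), synth_q.max())
        ax.plot([0, lim], [0, lim], "k--", lw=1)  # y = x reference line
        ax.set_xlabel("Real expression quantiles")
        ax.set_ylabel("Synthetic expression quantiles")
        ax.set_title(f"Q-Q plot: {cell_label}")
        out_name = f"qq_{cell_label.replace(' ', '_')}.png"
        fig.savefig(os.path.join(self.res_figures_dir, out_name), dpi=150)
        plt.close(fig)

Points on or near the y = x line indicate that the synthetic expression distribution for that cell label matches the real one; systematic deviation in the upper quantiles is a common sign of over- or under-dispersed counts.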