33from openprotein .base import APISession
44from openprotein .common import FeatureType , ReductionType
55from openprotein .data import AssayDataset , AssayMetadata
6- from openprotein .embeddings import EmbeddingModel
6+ from openprotein .embeddings import EmbeddingModel , EmbeddingsAPI
77from openprotein .errors import InvalidParameterError
88from openprotein .jobs import JobsAPI
9- from openprotein .svd import SVDModel
9+ from openprotein .svd import SVDAPI , SVDModel
1010
1111from . import api
1212from .models import UMAPModel
@@ -38,41 +38,67 @@ def fit_umap(
3838
3939 Parameters
4040 ----------
41- model_id : str
42- The ID of the model to fit the UMAP on.
4341 sequences: list of bytes or None, optional
4442 Optional sequences to fit UMAP with. Either use sequences or
4543 assay_id. sequences is preferred.
4644 assay : AssayMetadata or AssayDataset or str or None, optional
4745 Optional assay containing sequences to fit SVD with.
4846 Or its assay_id. Either use sequences or assay.
4947 Ignored if sequences are provided.
50- n_components: int
48+ model : EmbeddingModel or SVDModel or str
49+ Instance of either EmbeddingModel or SVDModel to use depending
50+ on feature type. Can also be a str specifying the model id,
51+ but then feature_type would have to be specified.
52+ feature_type : FeatureType or None, optional
53+ Type of features to use for encoding sequences. "SVD" or "PLM".
54+ None would require model to be EmbeddingModel or SVDModel.
55+ n_components : int, optional
5156 Number of UMAP components to fit. Defaults to 2.
52- n_neighbors: int
57+ n_neighbors : int, optional
5358 Number of neighbors to use for fitting. Defaults to 15.
54- min_dist: float
59+ min_dist : float, optional
5560 Minimum distance in UMAP fitting. Defaults to 0.1.
56- reduction: str or None, optional
61+ reduction : str or None, optional
5762 Type of embedding reduction to use for computing features.
5863 E.g. "MEAN" or "SUM". Useful when dealing with variable length
5964 sequence. Defaults to None.
60- kwargs:
65+ kwargs :
6166 Additional keyword arguments to be passed to foundational models, e.g. prompt_id for PoET models.
6267
6368 Returns
6469 -------
6570 UMAPModel
6671 The UMAP model being fit.
6772 """
68- if isinstance (model , str ):
69- if feature_type is None :
73+ # extract feature type
74+ feature_type = (
75+ FeatureType .PLM
76+ if isinstance (model , EmbeddingModel )
77+ else FeatureType .SVD if isinstance (model , SVDModel ) else feature_type
78+ )
79+ if feature_type is None :
80+ raise InvalidParameterError (
81+ "Expected feature_type to be provided if passing str model_id as model"
82+ )
83+ # get model if model_id
84+ if feature_type == FeatureType .PLM :
85+ if reduction is None :
7086 raise InvalidParameterError (
71- "Expected feature_type to be specified if using a string identifier as model "
87+ "Expected reduction if using EmbeddingModel "
7288 )
73- model_id = model
74- else :
75- model_id = model .id # for embeddings / svd model
89+ if isinstance (model , str ):
90+ embeddings_api = getattr (self .session , "embedding" , None )
91+ assert isinstance (embeddings_api , EmbeddingsAPI )
92+ model = embeddings_api .get_model (model )
93+ assert isinstance (model , EmbeddingModel ), "Expected EmbeddingModel"
94+ model_id = model .id
95+ elif feature_type == FeatureType .SVD :
96+ if isinstance (model , str ):
97+ svd_api = getattr (self .session , "svd" , None )
98+ assert isinstance (svd_api , SVDAPI )
99+ model = svd_api .get_svd (model )
100+ assert isinstance (model , SVDModel ), "Expected SVDModel"
101+ model_id = model .id
76102 # get assay_id
77103 assay_id = (
78104 assay .assay_id
@@ -84,6 +110,7 @@ def fit_umap(
84110 job = api .umap_fit_post (
85111 session = self .session ,
86112 model_id = model_id ,
113+ feature_type = feature_type ,
87114 sequences = sequences ,
88115 assay_id = assay_id ,
89116 n_components = n_components ,
0 commit comments