Skip to content

Commit 46c50a4

Browse files
snowflake-provisioner and Snowflake Authors authored
Project import generated by Copybara. (#10)
GitOrigin-RevId: 0e6fdd4a1de3105200fbe008718a93be6ef1b1d1 Co-authored-by: Snowflake Authors <[email protected]>
1 parent 0f49221 commit 46c50a4

File tree

18 files changed

+605
-278
lines changed

18 files changed

+605
-278
lines changed

snowflake/ml/model/BUILD.bazel

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ py_library(
6161
srcs = ["_deployer.py"],
6262
deps = [
6363
":_udf_util",
64-
":_model_meta",
6564
":model_signature",
6665
":type_hints",
6766
"//snowflake/ml/_internal/utils:identifier"
@@ -87,7 +86,8 @@ py_library(
8786
":model_signature",
8887
":type_hints",
8988
"//snowflake/ml/model/_handlers:custom",
90-
"//snowflake/ml/model/_handlers:sklearn"
89+
"//snowflake/ml/model/_handlers:sklearn",
90+
"//snowflake/ml/model/_handlers:xgboost"
9191
],
9292
)
9393

snowflake/ml/model/_deployer.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,7 @@
88
from typing_extensions import Required
99

1010
from snowflake.ml._internal.utils import identifier
11-
from snowflake.ml.model import (
12-
_model_meta,
13-
_udf_util,
14-
model_signature,
15-
type_hints as model_types,
16-
)
11+
from snowflake.ml.model import _udf_util, model_signature, type_hints as model_types
1712
from snowflake.snowpark import DataFrame, Session, functions as F
1813
from snowflake.snowpark._internal import type_utils
1914

@@ -28,18 +23,14 @@ class Deployment(TypedDict):
2823
2924
Attributes:
3025
name: Name of the deployment.
31-
model: The model object that get deployed.
32-
model_meta: The model metadata.
3326
platform: Target platform to deploy the model.
34-
target_method: The name of the target method to be deployed.
27+
signature: The signature of the model method.
3528
options: Additional options when deploying the model.
3629
"""
3730

3831
name: Required[str]
3932
platform: Required[TargetPlatform]
40-
model: Required[model_types.ModelType]
41-
model_meta: Required[_model_meta.ModelMetadata]
42-
target_method: str
33+
signature: model_signature.ModelSignature
4334
options: Required[model_types.DeployOptions]
4435

4536

@@ -53,19 +44,15 @@ def create(
5344
self,
5445
name: str,
5546
platform: TargetPlatform,
56-
model: model_types.ModelType,
57-
model_meta: _model_meta.ModelMetadata,
58-
target_method: str,
47+
signature: model_signature.ModelSignature,
5948
options: Optional[model_types.DeployOptions] = None,
6049
) -> Deployment:
6150
"""Create a deployment.
6251
6352
Args:
6453
name: Name of the deployment for the model.
65-
model: The model object that get deployed.
66-
model_meta: The model metadata.
6754
platform: Target platform to deploy the model.
68-
target_method: The name of the target method to be deployed.
55+
signature: The signature of the model method.
6956
options: Additional options when deploying the model.
7057
Each target platform has its own specification of options.
7158
"""
@@ -105,19 +92,15 @@ def create(
10592
self,
10693
name: str,
10794
platform: TargetPlatform,
108-
model: model_types.ModelType,
109-
model_meta: _model_meta.ModelMetadata,
110-
target_method: str,
95+
signature: model_signature.ModelSignature,
11196
options: Optional[model_types.DeployOptions] = None,
11297
) -> Deployment:
11398
"""Create a deployment.
11499
115100
Args:
116101
name: Name of the deployment for the model.
117102
platform: Target platform to deploy the model.
118-
model: The model object that get deployed.
119-
model_meta: The model metadata.
120-
target_method: The name of the target method to be deployed.
103+
signature: The signature of the model method.
121104
options: Additional options when deploying the model.
122105
Each target platform has its own specification of options.
123106
@@ -129,9 +112,7 @@ def create(
129112
info = Deployment(
130113
name=name,
131114
platform=platform,
132-
model=model,
133-
model_meta=model_meta,
134-
target_method=target_method,
115+
signature=signature,
135116
options=options,
136117
)
137118
self._storage[name] = info
@@ -203,6 +184,7 @@ def create_deployment(
203184
204185
Raises:
205186
RuntimeError: Raised when running into issues when deploying.
187+
ValueError: Raised when target method does not exist in model.
206188
207189
Returns:
208190
The deployment information.
@@ -225,7 +207,10 @@ def create_deployment(
225207
target_method=target_method,
226208
**options,
227209
)
228-
info = self._manager.create(name, platform, m, meta, target_method, options)
210+
signature = meta.signatures.get(target_method, None)
211+
if not signature:
212+
raise ValueError(f"Target method {target_method} does not exist in model.")
213+
info = self._manager.create(name=name, platform=platform, signature=signature, options=options)
229214
is_success = True
230215
except Exception as e:
231216
print(e)
@@ -303,10 +288,8 @@ def predict(self, name: str, X: Union[model_types.SupportedDataType, DataFrame])
303288
d = self.get_deployment(name)
304289
if not d:
305290
raise ValueError(f"Deployment {name} does not exist.")
306-
meta = d["model_meta"]
307-
target_method = d["target_method"]
291+
sig = d["signature"]
308292
keep_order = d["options"].get("keep_order", True)
309-
sig = meta.signatures[target_method]
310293
if not isinstance(X, DataFrame):
311294
df = model_signature._validate_data_with_features_and_convert_to_df(sig.inputs, X)
312295
s_df = self._session.create_dataframe(df)

snowflake/ml/model/_handlers/BUILD.bazel

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,16 @@ py_library(
3838
"//snowflake/ml/model:model_signature",
3939
],
4040
)
41+
42+
py_library(
43+
name = "xgboost",
44+
srcs = ["xgboost.py"],
45+
deps = [
46+
":_base",
47+
"//snowflake/ml/model:_model_meta",
48+
"//snowflake/ml/model:custom_model",
49+
"//snowflake/ml/_internal:type_utils",
50+
"//snowflake/ml/model:type_hints",
51+
"//snowflake/ml/model:model_signature",
52+
],
53+
)

snowflake/ml/model/_handlers/_base.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from abc import ABC, abstractmethod
2-
from typing import Optional
2+
from typing import Generic, Optional
33

44
from snowflake.ml.model import _model_meta, type_hints as model_types
55

66

7-
class _ModelHandler(ABC):
7+
class _ModelHandler(ABC, Generic[model_types.ModelType]):
88
"""Provides handling for a given type of model defined by `type` class property."""
99

1010
handler_type = "_base"
@@ -14,14 +14,24 @@ class _ModelHandler(ABC):
1414

1515
@staticmethod
1616
@abstractmethod
17-
def can_handle(model: model_types.ModelType) -> bool:
17+
def can_handle(model: model_types.SupportedModelType) -> bool:
1818
"""Whether this handler could support the type of the `model`.
1919
2020
Args:
2121
model: The model object.
2222
"""
2323
...
2424

25+
@staticmethod
26+
@abstractmethod
27+
def cast_model(model: model_types.SupportedModelType) -> model_types.ModelType:
28+
"""Cast the model from Union type into the type that handler could handle.
29+
30+
Args:
31+
model: The model object.
32+
"""
33+
...
34+
2535
@staticmethod
2636
@abstractmethod
2737
def _save_model(
@@ -30,6 +40,7 @@ def _save_model(
3040
model_meta: _model_meta.ModelMetadata,
3141
model_blobs_dir_path: str,
3242
sample_input: Optional[model_types.SupportedDataType] = None,
43+
is_sub_model: Optional[bool] = False,
3344
) -> None:
3445
"""Save the model.
3546
@@ -39,6 +50,7 @@ def _save_model(
3950
model_meta: The model metadata.
4051
model_blobs_dir_path: Directory path to the model.
4152
sample_input: Sample input to infer the signatures from.
53+
is_sub_model: Flag to show if it is a sub model; a sub model does not need a signature.
4254
"""
4355
...
4456

snowflake/ml/model/_handlers/custom.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import os
33
import sys
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, Dict, Optional
55

66
import anyio
77
import cloudpickle
@@ -20,48 +20,57 @@
2020
from snowflake.ml.model import custom_model
2121

2222

23-
class _CustomModelHandler(_base._ModelHandler):
23+
class _CustomModelHandler(_base._ModelHandler["custom_model.CustomModel"]):
2424
"""Handler for custom model."""
2525

2626
handler_type = "custom"
2727

2828
@staticmethod
29-
def can_handle(model: model_types.ModelType) -> bool:
29+
def can_handle(model: model_types.SupportedModelType) -> bool:
3030
return bool(type_utils.LazyType("snowflake.ml.model.custom_model.CustomModel").isinstance(model))
3131

32+
@staticmethod
33+
def cast_model(model: model_types.SupportedModelType) -> "custom_model.CustomModel":
34+
from snowflake.ml.model import custom_model
35+
36+
assert isinstance(model, custom_model.CustomModel)
37+
return model
38+
3239
@staticmethod
3340
def _save_model(
3441
name: str,
3542
model: "custom_model.CustomModel",
3643
model_meta: model_meta_api.ModelMetadata,
3744
model_blobs_dir_path: str,
3845
sample_input: Optional[model_types.SupportedDataType] = None,
46+
is_sub_model: Optional[bool] = False,
3947
**kwargs: Unpack[model_types.CustomModelSaveOption],
4048
) -> None:
4149
from snowflake.ml.model import custom_model
4250

4351
assert isinstance(model, custom_model.CustomModel)
4452

45-
if model_meta._signatures is None:
46-
# In this case sample_input should be available, because of the check in save_model.
47-
assert sample_input is not None
48-
model_meta._signatures = {}
49-
for target_method in model._get_infer_methods():
50-
if inspect.iscoroutinefunction(target_method):
51-
with anyio.start_blocking_portal() as portal:
52-
predictions_df = portal.call(target_method, model, sample_input)
53-
else:
54-
predictions_df = target_method(model, sample_input)
55-
func_name = target_method.__name__
56-
sig = model_signature.infer_signature(sample_input, predictions_df)
57-
model_meta._signatures[func_name] = sig
58-
else:
59-
method_names = [method.__name__ for method in model._get_infer_methods()]
60-
for method_name in model_meta._signatures.keys():
61-
if method_name not in method_names:
62-
raise ValueError(f"Target method {method_name} does not exists.")
63-
if not callable(getattr(model, method_name, None)):
64-
raise ValueError(f"Target method {method_name} is not callable.")
53+
if not is_sub_model:
54+
if model_meta._signatures is None:
55+
# In this case sample_input should be available, because of the check in save_model.
56+
assert sample_input is not None
57+
model_meta._signatures = {}
58+
for target_method in model._get_infer_methods():
59+
if inspect.iscoroutinefunction(target_method):
60+
with anyio.start_blocking_portal() as portal:
61+
predictions_df = portal.call(target_method, model, sample_input)
62+
else:
63+
predictions_df = target_method(model, sample_input)
64+
func_name = target_method.__name__
65+
sig = model_signature.infer_signature(sample_input, predictions_df)
66+
model_meta._signatures[func_name] = sig
67+
else:
68+
method_names = [method.__name__ for method in model._get_infer_methods()]
69+
for method_name in model_meta._signatures.keys():
70+
if method_name not in method_names:
71+
raise ValueError(f"Target method {method_name} does not exists.")
72+
if not callable(getattr(model, method_name, None)):
73+
raise ValueError(f"Target method {method_name} is not callable.")
6574

6675
model_blob_path = os.path.join(model_blobs_dir_path, name)
6776
os.makedirs(model_blob_path, exist_ok=True)
@@ -76,7 +85,14 @@ def _save_model(
7685
for sub_name, model_ref in model.context.model_refs.items():
7786
handler = _model_handler._find_handler(model_ref.model)
7887
assert handler is not None
79-
handler._save_model(sub_name, model_ref.model, model_meta, model_blobs_dir_path)
88+
sub_model = handler.cast_model(model_ref.model)
89+
handler._save_model(
90+
name=sub_name,
91+
model=sub_model,
92+
model_meta=model_meta,
93+
model_blobs_dir_path=model_blobs_dir_path,
94+
is_sub_model=True,
95+
)
8096

8197
# Make sure that the module where the model is defined get pickled by value as well.
8298
cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
@@ -115,7 +131,7 @@ def _load_model(
115131

116132
artifacts_meta = model_blob_metadata.artifacts
117133
artifacts = {name: os.path.join(model_blob_path, rel_path) for name, rel_path in artifacts_meta.items()}
118-
models = dict()
134+
models: Dict[str, model_types.SupportedModelType] = dict()
119135
for sub_model_name, _ref in m.context.model_refs.items():
120136
model_type = model_meta.models[sub_model_name].model_type
121137
handler = _model_handler._load_handler(model_type)

0 commit comments

Comments
 (0)