Skip to content

Commit cd38c89

Browse files
snowflake-provisioner and Snowflake Authors authored
Project import generated by Copybara. (#16)
GitOrigin-RevId: 1c09d7ecb92720c6367448f920684dabf40d2813 Co-authored-by: Snowflake Authors <[email protected]>
1 parent 2a14eaf commit cd38c89

37 files changed

+2596
-542
lines changed

ci/conda_recipe/meta.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,25 @@ requirements:
2222
- python
2323
- absl-py>=0.15,<2
2424
- anyio>=3.5.0,<4
25+
- cloudpickle
2526
- fsspec>=2022.11,<=2023.1
2627
- numpy>=1.23,<1.24
28+
- packaging>=23.0,<24
2729
- pyyaml>=6.0,<7
2830
- scipy>=1.9,<2
2931
- snowflake-connector-python
3032
- snowflake-snowpark-python>=1.4.0,<=2
3133
- sqlparse>=0.4,<1
34+
- typing-extensions>=4.1.0,<5
35+
36+
# conda-libmamba-solver is conda-specific requirement, and should not appear in wheel's dependency.
37+
- conda-libmamba-solver>=23.1.0,<24
3238

3339
# TODO(snandamuri): Versions of these packages must be exactly same between user's workspace and
3440
# snowpark sandbox. Generic definitions like scikit-learn>=1.1.0,<2 wont work because snowflake conda channel
3541
# only has a few allowlisted versions of scikit-learn available, so we must force users to use scikit-learn
3642
# versions that are available in the snowflake conda channel. Since there is no way to specify allow list of
3743
# versions in the requirements file, we are pinning the versions here.
38-
- joblib>=1.0.0,<=1.1.1
3944
- scikit-learn>=1.2.1,<2
4045
- xgboost==1.7.3
4146
about:

codegen/sklearn_wrapper_generator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -802,9 +802,10 @@ def generate(self) -> "SklearnWrapperGenerator":
802802
if self._is_hist_gradient_boosting_regressor:
803803
self.test_estimator_input_args_list.extend(["min_samples_leaf=1", "max_leaf_nodes=100"])
804804

805+
# TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
805806
self.fit_sproc_deps = self.predict_udf_deps = (
806807
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}', "
807-
"f'xgboost=={xgboost.__version__}', f'joblib=={joblib.__version__}'"
808+
"f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'"
808809
)
809810
self._construct_string_from_lists()
810811
return self
@@ -819,9 +820,10 @@ def generate(self) -> "XGBoostWrapperGenerator":
819820
self.estimator_imports_list.append("import xgboost")
820821
self.test_estimator_input_args_list.extend(["random_state=0", "subsample=1.0", "colsample_bynode=1.0"])
821822
self.fit_sproc_imports = "import xgboost"
823+
# TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
822824
self.fit_sproc_deps = self.predict_udf_deps = (
823825
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}', "
824-
"f'joblib=={joblib.__version__}'"
826+
"f'cloudpickle=={cp.__version__}'"
825827
)
826828
self._construct_string_from_lists()
827829
return self
@@ -836,9 +838,10 @@ def generate(self) -> "LightGBMWrapperGenerator":
836838
self.estimator_imports_list.append("import lightgbm")
837839
self.test_estimator_input_args_list.extend(["random_state=0"])
838840
self.fit_sproc_imports = "import lightgbm"
841+
# TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
839842
self.fit_sproc_deps = self.predict_udf_deps = (
840843
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}', "
841-
"f'joblib=={joblib.__version__}'"
844+
"f'cloudpickle=={cp.__version__}'"
842845
)
843846
self._construct_string_from_lists()
844847
return self

codegen/sklearn_wrapper_template.py_template

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import os
66
from typing import Iterable, Optional, Union, List, Any, Dict, Callable
77
from uuid import uuid4
88

9-
import joblib
9+
import cloudpickle as cp
1010
import pandas as pd
1111
import numpy as np
1212
{transform.estimator_imports}
@@ -183,7 +183,8 @@ class {transform.original_class_name}(BaseTransformer):
183183

184184
# Create a temp file and dump the transform to that file.
185185
local_transform_file_name = get_temp_file_path()
186-
joblib.dump(self._sklearn_object, local_transform_file_name)
186+
with open(local_transform_file_name, mode="w+b") as local_transform_file:
187+
cp.dump(self._sklearn_object, local_transform_file)
187188

188189
# Create temp stage to run fit.
189190
transform_stage_name = "SNOWML_TRANSFORM_{{safe_id}}".format(safe_id=self.id)
@@ -214,7 +215,13 @@ class {transform.original_class_name}(BaseTransformer):
214215
custom_tags=dict([("autogen", True)]),
215216
)
216217
# Put locally serialized transform on stage.
217-
session.file.put(local_transform_file_name, stage_transform_file_name, auto_compress=False, overwrite=True, statement_params=statement_params)
218+
session.file.put(
219+
local_transform_file_name,
220+
stage_transform_file_name,
221+
auto_compress=False,
222+
overwrite=True,
223+
statement_params=statement_params
224+
)
218225

219226
@sproc(
220227
is_permanent=False,
@@ -233,7 +240,7 @@ class {transform.original_class_name}(BaseTransformer):
233240
label_cols: List[str],
234241
sample_weight_col: Optional[str]
235242
) -> str:
236-
import joblib
243+
import cloudpickle as cp
237244
import numpy as np
238245
import os
239246
import pandas
@@ -251,7 +258,12 @@ class {transform.original_class_name}(BaseTransformer):
251258

252259
session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
253260

254-
estimator = joblib.load(os.path.join(local_transform_file_name, os.listdir(local_transform_file_name)[0]))
261+
local_transform_file_path = os.path.join(
262+
local_transform_file_name,
263+
os.listdir(local_transform_file_name)[0]
264+
)
265+
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
266+
estimator = cp.load(local_transform_file_obj)
255267

256268
argspec = inspect.getfullargspec(estimator.fit)
257269
args = {{'X': df[input_cols]}}
@@ -268,12 +280,20 @@ class {transform.original_class_name}(BaseTransformer):
268280
local_result_file_name = local_result_file.name
269281
local_result_file.close()
270282

271-
joblib_dump_files = joblib.dump(estimator, local_result_file_name)
272-
session.file.put(local_result_file_name, stage_result_file_name, auto_compress = False, overwrite = True, statement_params=statement_params)
283+
with open(local_result_file_name, mode="w+b") as local_result_file_obj:
284+
cp.dump(estimator, local_result_file_obj)
285+
286+
session.file.put(
287+
local_result_file_name,
288+
stage_result_file_name,
289+
auto_compress = False,
290+
overwrite = True,
291+
statement_params=statement_params
292+
)
273293

274294
# Note: you can add something like + "|" + str(df) to the return string
275295
# to pass debug information to the caller.
276-
return str(os.path.basename(joblib_dump_files[0]))
296+
return str(os.path.basename(local_result_file_name))
277297

278298
# Call fit sproc
279299
statement_params = telemetry.get_function_usage_statement_params(
@@ -302,8 +322,13 @@ class {transform.original_class_name}(BaseTransformer):
302322
if len(fields) > 1:
303323
print("\n".join(fields[1:]))
304324

305-
session.file.get(os.path.join(stage_result_file_name, sproc_export_file_name), local_result_file_name, statement_params=statement_params)
306-
self._sklearn_object = joblib.load(os.path.join(local_result_file_name, sproc_export_file_name))
325+
session.file.get(
326+
os.path.join(stage_result_file_name, sproc_export_file_name),
327+
local_result_file_name,
328+
statement_params=statement_params
329+
)
330+
with open(os.path.join(local_result_file_name, sproc_export_file_name),mode="r+b") as result_file_obj:
331+
self._sklearn_object = cp.load(result_file_obj)
307332

308333
cleanup_temp_files([local_transform_file_name, local_result_file_name])
309334

@@ -843,7 +868,8 @@ class {transform.original_class_name}(BaseTransformer):
843868

844869
# Create a temp file and dump the score to that file.
845870
local_score_file_name = get_temp_file_path()
846-
joblib.dump(self._sklearn_object, local_score_file_name)
871+
with open(local_score_file_name, mode="w+b") as local_score_file:
872+
cp.dump(self._sklearn_object, local_score_file)
847873

848874
# Create temp stage to run score.
849875
score_stage_name = "SNOWML_SCORE_{{safe_id}}".format(safe_id=self.id)
@@ -872,7 +898,13 @@ class {transform.original_class_name}(BaseTransformer):
872898
custom_tags=dict([("autogen", True)]),
873899
)
874900
# Put locally serialized score on stage.
875-
session.file.put(local_score_file_name, stage_score_file_name, auto_compress=False, overwrite=True, statement_params=statement_params)
901+
session.file.put(
902+
local_score_file_name,
903+
stage_score_file_name,
904+
auto_compress=False,
905+
overwrite=True,
906+
statement_params=statement_params
907+
)
876908

877909
@sproc(
878910
is_permanent=False,
@@ -890,7 +922,7 @@ class {transform.original_class_name}(BaseTransformer):
890922
label_cols: List[str],
891923
sample_weight_col: Optional[str]
892924
) -> float:
893-
import joblib
925+
import cloudpickle as cp
894926
import numpy as np
895927
import os
896928
import pandas
@@ -905,7 +937,11 @@ class {transform.original_class_name}(BaseTransformer):
905937
local_score_file.close()
906938

907939
session.file.get(stage_score_file_name, local_score_file_name, statement_params=statement_params)
908-
estimator = joblib.load(os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0]))
940+
941+
local_score_file_name_path = os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0])
942+
with open(local_score_file_name_path, mode="r+b") as local_score_file_obj:
943+
estimator = cp.load(local_score_file_obj)
944+
909945
argspec = inspect.getfullargspec(estimator.score)
910946
if "X" in argspec.args:
911947
args = {{'X': df[input_cols]}}

conda-env-snowflake.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies:
2525
- lightgbm==3.3.5
2626
- networkx==2.8.4
2727
- numpy==1.23.4
28+
- packaging==23.0
2829
- pandas==1.4.4
2930
- pytest==7.1.2
3031
- python==3.8.13
@@ -35,6 +36,6 @@ dependencies:
3536
- scikit-learn==1.2.2
3637
- snowflake-snowpark-python==1.4.0
3738
- sqlparse==0.4.3
38-
- typing-extensions==4.3.0
39+
- typing-extensions==4.5.0
3940
- xgboost==1.7.3
4041
- mypy==0.981 # not a package dependency.

conda-env.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies:
2222
- mypy==0.981
2323
- networkx==2.8.4
2424
- numpy==1.23.4
25+
- packaging==23.0
2526
- pandas==1.4.4
2627
- pytest==7.1.2
2728
- python==3.8.13
@@ -36,5 +37,5 @@ dependencies:
3637
- torchdata==0.4.1
3738
- transformers==4.27.1
3839
- types-PyYAML==6.0.12
39-
- typing-extensions==4.3.0
40+
- typing-extensions==4.5.0
4041
- xgboost==1.7.3

snowflake/ml/BUILD.bazel

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,16 @@ snowml_wheel(
3838
requires = [
3939
"absl-py>=0.15,<2",
4040
"anyio>=3.5.0,<4",
41+
"cloudpickle", # Version range is specified by snowpark. We are implicitly depending on it.
4142
"fsspec[http]>=2022.11,<=2023.1",
4243
"numpy>=1.23,<1.24",
44+
"packaging>=23.0,<24",
4345
"pyyaml>=6.0,<7",
4446
"scipy>=1.9,<2",
4547
"snowflake-connector-python[pandas]",
4648
"snowflake-snowpark-python>=1.4.0,<2",
4749
"sqlparse>=0.4,<1",
50+
"typing-extensions>=4.1.0,<5",
4851

4952
# TODO(snandamuri): Versions of these packages must be exactly same between user's workspace and
5053
# snowpark sandbox. Generic definitions like scikit-learn>=1.1.0,<2 wont work because snowflake conda channel
@@ -53,7 +56,6 @@ snowml_wheel(
5356
# versions in the requirements file, we are pinning the versions here.
5457
"scikit-learn>=1.2.1,<2",
5558
"xgboost==1.7.3",
56-
"joblib>=1.0.0,<=1.1.1", # All the release versions between 1.0.0 and 1.1.1 are available in SF Conda channel.
5759
],
5860
version = VERSION,
5961
deps = [

snowflake/ml/_internal/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ py_test(
4646
srcs = ["env_utils_test.py"],
4747
deps = [
4848
":env_utils",
49+
":env",
4950
"//snowflake/ml/test_utils:mock_data_frame",
5051
"//snowflake/ml/test_utils:mock_session",
5152
],

snowflake/ml/_internal/env_utils.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from snowflake.ml._internal.utils import query_result_checker
1212
from snowflake.snowpark import session
1313

14+
_INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION: Optional[bool] = None
1415
_SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
1516

1617

@@ -219,13 +220,16 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req
219220
return new_req
220221

221222

222-
def resolve_conda_environment(packages: List[requirements.Requirement], channels: List[str]) -> Optional[List[str]]:
223+
def resolve_conda_environment(
224+
packages: List[requirements.Requirement], channels: List[str], python_version: str
225+
) -> Optional[List[str]]:
223226
"""Use conda api to check if given packages are resolvable in given channels. Only work when conda is
224227
locally installed.
225228
226229
Args:
227230
packages: Packages to be installed.
228231
channels: Anaconda channels (name or url) where conda should search into.
232+
python_version: A string of python version where model is run.
229233
230234
Returns:
231235
List of frozen dependencies represented in PEP 508 form if resolvable, None otherwise.
@@ -234,7 +238,7 @@ def resolve_conda_environment(packages: List[requirements.Requirement], channels
234238
from conda_libmamba_solver import solver
235239

236240
package_names = list(map(lambda x: x.name, packages))
237-
specs = list(map(str, packages))
241+
specs = list(map(str, packages)) + [f"python=={python_version}"]
238242

239243
conda_solver = solver.LibMambaSolver("snow-env", channels=channels, specs_to_add=specs)
240244
try:
@@ -252,18 +256,38 @@ def resolve_conda_environment(packages: List[requirements.Requirement], channels
252256
)
253257

254258

259+
def _check_runtime_version_column_existence(session: session.Session) -> bool:
260+
sql = textwrap.dedent(
261+
"""
262+
SHOW COLUMNS
263+
LIKE 'runtime_version'
264+
IN TABLE information_schema.packages;
265+
"""
266+
)
267+
result = session.sql(sql).count()
268+
return result == 1
269+
270+
255271
def validate_requirements_in_snowflake_conda_channel(
256-
session: session.Session, reqs: List[requirements.Requirement]
272+
session: session.Session, reqs: List[requirements.Requirement], python_version: str
257273
) -> Optional[List[str]]:
258274
"""Search the snowflake anaconda channel for packages with version meet the specifier.
259275
260276
Args:
261277
session: Snowflake connection session.
262278
reqs: List of requirement specifiers.
279+
python_version: A string of python version where model is run.
280+
281+
Raises:
282+
ValueError: Raised when the specifier cannot be supported when creating UDF.
263283
264284
Returns:
265285
A list of pinned latest version that available in Snowflake anaconda channel and meet the version specifier.
266286
"""
287+
global _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION
288+
289+
if _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION is None:
290+
_INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION = _check_runtime_version_column_existence(session)
267291
ret_list = []
268292
reqs_to_request = []
269293
for req in reqs:
@@ -273,14 +297,26 @@ def validate_requirements_in_snowflake_conda_channel(
273297
pkg_names_str = " OR ".join(
274298
f"package_name = '{req_name}'" for req_name in sorted(req.name for req in reqs_to_request)
275299
)
276-
sql = textwrap.dedent(
277-
f"""
278-
SELECT PACKAGE_NAME, VERSION
279-
FROM information_schema.packages
280-
WHERE ({pkg_names_str})
281-
AND language = 'python';
282-
"""
283-
)
300+
if _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION:
301+
parsed_python_version = version.Version(python_version)
302+
sql = textwrap.dedent(
303+
f"""
304+
SELECT PACKAGE_NAME, VERSION
305+
FROM information_schema.packages
306+
WHERE ({pkg_names_str})
307+
AND language = 'python'
308+
AND runtime_version = '{parsed_python_version.major}.{parsed_python_version.minor}';
309+
"""
310+
)
311+
else:
312+
sql = textwrap.dedent(
313+
f"""
314+
SELECT PACKAGE_NAME, VERSION
315+
FROM information_schema.packages
316+
WHERE ({pkg_names_str})
317+
AND language = 'python';
318+
"""
319+
)
284320

285321
try:
286322
result = (
@@ -301,10 +337,11 @@ def validate_requirements_in_snowflake_conda_channel(
301337
except snowflake.connector.DataError:
302338
return None
303339
for req in reqs:
304-
available_versions = list(req.specifier.filter(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, [])))
340+
if len(req.specifier) > 1 or any(spec.operator != "==" for spec in req.specifier):
341+
raise ValueError("At most 1 version specifier using == operator is supported without local conda resolver.")
342+
available_versions = list(req.specifier.filter(set(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, []))))
305343
if not available_versions:
306344
return None
307345
else:
308-
latest_version = max(available_versions)
309-
ret_list.append(f"{req.name}=={latest_version}")
346+
ret_list.append(str(req))
310347
return sorted(ret_list)

0 commit comments

Comments (0)