@@ -6,7 +6,7 @@ import os
66from typing import Iterable, Optional, Union, List, Any, Dict, Callable
77from uuid import uuid4
88
9- import joblib
9+ import cloudpickle as cp
1010import pandas as pd
1111import numpy as np
1212{transform.estimator_imports}
@@ -183,7 +183,8 @@ class {transform.original_class_name}(BaseTransformer):
183183
184184 # Create a temp file and dump the transform to that file.
185185 local_transform_file_name = get_temp_file_path()
186- joblib.dump(self._sklearn_object, local_transform_file_name)
186+ with open(local_transform_file_name, mode="w+b") as local_transform_file:
187+ cp.dump(self._sklearn_object, local_transform_file)
187188
188189 # Create temp stage to run fit.
189190 transform_stage_name = "SNOWML_TRANSFORM_{{safe_id}}".format(safe_id=self.id)
@@ -214,7 +215,13 @@ class {transform.original_class_name}(BaseTransformer):
214215 custom_tags=dict([("autogen", True)]),
215216 )
216217 # Put locally serialized transform on stage.
217- session.file.put(local_transform_file_name, stage_transform_file_name, auto_compress=False, overwrite=True, statement_params=statement_params)
218+ session.file.put(
219+ local_transform_file_name,
220+ stage_transform_file_name,
221+ auto_compress=False,
222+ overwrite=True,
223+ statement_params=statement_params
224+ )
218225
219226 @sproc(
220227 is_permanent=False,
@@ -233,7 +240,7 @@ class {transform.original_class_name}(BaseTransformer):
233240 label_cols: List[str],
234241 sample_weight_col: Optional[str]
235242 ) -> str:
236- import joblib
243+ import cloudpickle as cp
237244 import numpy as np
238245 import os
239246 import pandas
@@ -251,7 +258,12 @@ class {transform.original_class_name}(BaseTransformer):
251258
252259 session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
253260
254- estimator = joblib.load(os.path.join(local_transform_file_name, os.listdir(local_transform_file_name)[0]))
261+ local_transform_file_path = os.path.join(
262+ local_transform_file_name,
263+ os.listdir(local_transform_file_name)[0]
264+ )
265+ with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
266+ estimator = cp.load(local_transform_file_obj)
255267
256268 argspec = inspect.getfullargspec(estimator.fit)
257269 args = {{'X': df[input_cols]}}
@@ -268,12 +280,20 @@ class {transform.original_class_name}(BaseTransformer):
268280 local_result_file_name = local_result_file.name
269281 local_result_file.close()
270282
271- joblib_dump_files = joblib.dump(estimator, local_result_file_name)
272- session.file.put(local_result_file_name, stage_result_file_name, auto_compress = False, overwrite = True, statement_params=statement_params)
283+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
284+ cp.dump(estimator, local_result_file_obj)
285+
286+ session.file.put(
287+ local_result_file_name,
288+ stage_result_file_name,
289+ auto_compress = False,
290+ overwrite = True,
291+ statement_params=statement_params
292+ )
273293
274294 # Note: you can add something like + "|" + str(df) to the return string
275295 # to pass debug information to the caller.
276- return str(os.path.basename(joblib_dump_files[0] ))
296+ return str(os.path.basename(local_result_file_name ))
277297
278298 # Call fit sproc
279299 statement_params = telemetry.get_function_usage_statement_params(
@@ -302,8 +322,13 @@ class {transform.original_class_name}(BaseTransformer):
302322 if len(fields) > 1:
303323 print("\n".join(fields[1:]))
304324
305- session.file.get(os.path.join(stage_result_file_name, sproc_export_file_name), local_result_file_name, statement_params=statement_params)
306- self._sklearn_object = joblib.load(os.path.join(local_result_file_name, sproc_export_file_name))
325+ session.file.get(
326+ os.path.join(stage_result_file_name, sproc_export_file_name),
327+ local_result_file_name,
328+ statement_params=statement_params
329+ )
330+ with open(os.path.join(local_result_file_name, sproc_export_file_name),mode="r+b") as result_file_obj:
331+ self._sklearn_object = cp.load(result_file_obj)
307332
308333 cleanup_temp_files([local_transform_file_name, local_result_file_name])
309334
@@ -843,7 +868,8 @@ class {transform.original_class_name}(BaseTransformer):
843868
844869 # Create a temp file and dump the estimator used for scoring to that file.
845870 local_score_file_name = get_temp_file_path()
846- joblib.dump(self._sklearn_object, local_score_file_name)
871+ with open(local_score_file_name, mode="w+b") as local_score_file:
872+ cp.dump(self._sklearn_object, local_score_file)
847873
848874 # Create temp stage to run score.
849875 score_stage_name = "SNOWML_SCORE_{{safe_id}}".format(safe_id=self.id)
@@ -872,7 +898,13 @@ class {transform.original_class_name}(BaseTransformer):
872898 custom_tags=dict([("autogen", True)]),
873899 )
874900 # Put the locally serialized estimator on stage for scoring.
875- session.file.put(local_score_file_name, stage_score_file_name, auto_compress=False, overwrite=True, statement_params=statement_params)
901+ session.file.put(
902+ local_score_file_name,
903+ stage_score_file_name,
904+ auto_compress=False,
905+ overwrite=True,
906+ statement_params=statement_params
907+ )
876908
877909 @sproc(
878910 is_permanent=False,
@@ -890,7 +922,7 @@ class {transform.original_class_name}(BaseTransformer):
890922 label_cols: List[str],
891923 sample_weight_col: Optional[str]
892924 ) -> float:
893- import joblib
925+ import cloudpickle as cp
894926 import numpy as np
895927 import os
896928 import pandas
@@ -905,7 +937,11 @@ class {transform.original_class_name}(BaseTransformer):
905937 local_score_file.close()
906938
907939 session.file.get(stage_score_file_name, local_score_file_name, statement_params=statement_params)
908- estimator = joblib.load(os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0]))
940+
941+ local_score_file_name_path = os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0])
942+ with open(local_score_file_name_path, mode="r+b") as local_score_file_obj:
943+ estimator = cp.load(local_score_file_obj)
944+
909945 argspec = inspect.getfullargspec(estimator.score)
910946 if "X" in argspec.args:
911947 args = {{'X': df[input_cols]}}
0 commit comments