@@ -12,7 +12,7 @@ import numpy as np
1212{transform.estimator_imports}
1313from sklearn.utils.metaestimators import available_if
1414
15- from snowflake.ml.framework.base import BaseTransformer
15+ from snowflake.ml.sklearn. framework.base import BaseTransformer
1616from snowflake.ml._internal import telemetry
1717from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
1818from snowflake.ml._internal.utils import pkg_version_utils, identifier
@@ -21,6 +21,14 @@ from snowflake.snowpark import DataFrame, Session
2121from snowflake.snowpark.functions import pandas_udf, sproc
2222from snowflake.snowpark.types import PandasSeries
2323
24+ from snowflake.ml.model.model_signature import (
25+ DataType,
26+ FeatureSpec,
27+ ModelSignature,
28+ _infer_signature,
29+ _rename_features,
30+ )
31+
2432_PROJECT = "ModelDevelopment"
2533# Derive subproject from module name by removing "sklearn"
2634# and converting module name from underscore to CamelCase
@@ -116,6 +124,7 @@ class {transform.original_class_name}(BaseTransformer):
116124 self._sklearn_object = {transform.root_module_name}.{transform.original_class_name}(
117125 {transform.sklearn_init_arguments}
118126 )
127+ self._model_signature_dict = None
119128 {transform.estimator_init_member_args}
120129
121130 def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
@@ -161,6 +170,7 @@ class {transform.original_class_name}(BaseTransformer):
161170 "Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
162171 )
163172 self._is_fitted = True
173+ self._get_model_signatures(dataset)
164174 return self
165175
166176 def _fit_snowpark(self, dataset: DataFrame) -> None:
@@ -310,9 +320,9 @@ class {transform.original_class_name}(BaseTransformer):
310320 query,
311321 stage_transform_file_name,
312322 stage_result_file_name,
313- identifier.get_equivalent_identifier_in_the_response_pandas_dataframe (self.input_cols),
314- identifier.get_equivalent_identifier_in_the_response_pandas_dataframe (self.label_cols),
315- identifier.get_equivalent_identifier_in_the_response_pandas_dataframe (self.sample_weight_col),
323+ identifier.get_unescaped_names (self.input_cols),
324+ identifier.get_unescaped_names (self.label_cols),
325+ identifier.get_unescaped_names (self.sample_weight_col),
316326 statement_params=statement_params,
317327 )
318328
@@ -378,7 +388,7 @@ class {transform.original_class_name}(BaseTransformer):
378388 # Input columns for UDF are sorted by column names.
379389 # We need actual order of input cols to reorder dataframe before calling inference methods.
380390 input_cols = self.input_cols
381- unquoted_input_cols = identifier.get_equivalent_identifier_in_the_response_pandas_dataframe (self.input_cols)
391+ unquoted_input_cols = identifier.get_unescaped_names (self.input_cols)
382392
383393 statement_params = telemetry.get_function_usage_statement_params(
384394 project=_PROJECT,
@@ -511,9 +521,37 @@ class {transform.original_class_name}(BaseTransformer):
511521 expected_output_cols_list: List[str]
512522 ) -> pd.DataFrame:
513523 output_cols = expected_output_cols_list.copy()
514- transformed_numpy_array = getattr(self._sklearn_object, inference_method)(
515- dataset[self.input_cols]
524+
525+ # Model expects exact same columns names in the input df for predict call.
526+ # Given the scenario that user use snowpark DataFrame in fit call, but pandas DataFrame in predict call
527+ # input cols need to match unquoted / quoted
528+ input_cols = self.input_cols
529+ unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
530+
531+ estimator = self._sklearn_object
532+
533+ input_df = dataset[input_cols] # Select input columns with quoted column names.
534+ if hasattr(estimator, "feature_names_in_"):
535+ missing_features = []
536+ for i, f in enumerate(getattr(estimator, "feature_names_in_")):
537+ if i >= len(input_cols) or (input_cols[i] != f and unquoted_input_cols[i] != f):
538+ missing_features.append(f)
539+
540+ if len(missing_features) > 0:
541+ raise ValueError(
542+ "The feature names should match with those that were passed during fit.\n"
543+ f"Features seen during fit call but not present in the input: {{missing_features}}\n"
544+ f"Features in the input dataframe : {{input_cols}}\n"
545+ )
546+ input_df.columns = getattr(estimator, "feature_names_in_")
547+ else:
548+ # Just rename the column names to unquoted identifiers.
549+ input_df.columns = unquoted_input_cols # Replace the quoted columns identifier with unquoted column ids.
550+
551+ transformed_numpy_array = getattr(estimator, inference_method)(
552+ input_df
516553 )
554+
517555 if (
518556 isinstance(transformed_numpy_array, list)
519557 and len(transformed_numpy_array) > 0
@@ -974,12 +1012,45 @@ class {transform.original_class_name}(BaseTransformer):
9741012 score_sproc_name,
9751013 query,
9761014 stage_score_file_name,
977- identifier.get_equivalent_identifier_in_the_response_pandas_dataframe (self.input_cols),
978- identifier.get_equivalent_identifier_in_the_response_pandas_dataframe (self.label_cols),
979- identifier.get_equivalent_identifier_in_the_response_pandas_dataframe (self.sample_weight_col),
1015+ identifier.get_unescaped_names (self.input_cols),
1016+ identifier.get_unescaped_names (self.label_cols),
1017+ identifier.get_unescaped_names (self.sample_weight_col),
9801018 statement_params=statement_params,
9811019 )
9821020
9831021 cleanup_temp_files([local_score_file_name])
9841022
9851023 return score
1024+
def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
    """Infer a ModelSignature for each supported inference method and cache it.

    Resets ``self._model_signature_dict`` and fills it with one entry per
    inference method exposed by this wrapper:

    * ``predict`` -- only when the wrapped estimator reports itself as a
      classifier (output types inferred from the label columns) or as a
      regressor (outputs declared as DOUBLE).
    * ``predict_log_proba``, ``predict_proba``, ``decision_function`` --
      outputs declared as DOUBLE, named via ``_get_output_column_names``.

    Args:
        dataset: The dataset given to ``fit``. Its ``input_cols`` columns
            (and ``label_cols`` columns for classifiers) are passed to
            ``_infer_signature`` to derive the feature specs.
    """
    PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]

    # `signatures` aliases the attribute, so entries become visible on the
    # instance as they are added (same observable behavior as writing to the
    # attribute directly).
    signatures: Dict[str, ModelSignature] = {}
    self._model_signature_dict = signatures

    inputs = _infer_signature(dataset[self.input_cols], "input")

    if hasattr(self, "predict"):
        # Not every wrapped estimator is guaranteed to define `_estimator_type`;
        # read it defensively so signature inference cannot fail `fit` with an
        # AttributeError after the model has already been trained.
        estimator_type = getattr(self._sklearn_object, "_estimator_type", None)

        predict_outputs = None
        if estimator_type == "classifier":
            # For a classifier, predict returns values of the same type as the labels.
            predict_outputs = _infer_signature(dataset[self.label_cols], "output")
            # Label-derived specs carry the label names; rename to the output columns.
            predict_outputs = _rename_features(predict_outputs, self.output_cols)
        elif estimator_type == "regressor":
            # For a regressor, predict output is declared as float64 (DOUBLE).
            predict_outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]

        # Any other estimator type gets no "predict" signature.
        if predict_outputs is not None:
            signatures["predict"] = ModelSignature(inputs, predict_outputs)

    for prob_func in PROB_FUNCTIONS:
        if hasattr(self, prob_func):
            # NOTE(review): the doubled braces look like escapes for the template
            # rendering step that produces the final module -- keep them as is.
            output_cols_prefix: str = f"{{prob_func}}_"
            output_column_names = self._get_output_column_names(output_cols_prefix)
            outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
            signatures[prob_func] = ModelSignature(inputs, outputs)

    # TODO: Add support for transform method
1050+
1051+
@property
def model_signatures(self) -> Dict[str, ModelSignature]:
    """Return the cached mapping from inference method name to ModelSignature.

    Raises:
        RuntimeError: If ``_model_signature_dict`` is still ``None``, i.e. no
            signatures have been computed for this estimator yet.
    """
    cached_signatures = self._model_signature_dict
    if cached_signatures is not None:
        return cached_signatures
    raise RuntimeError("Estimator not fitted before accessing property model_signatures! ")
0 commit comments