1616LOAD_DIABETES = "load_diabetes"
1717
1818
19- ADDITIONAL_PARAM_DESCRIPTIONS = """
20-
19+ ADDITIONAL_PARAM_DESCRIPTIONS = {
20+ "input_cols" : """
2121input_cols: Optional[Union[str, List[str]]]
2222 A string or list of strings representing column names that contain features.
2323 If this parameter is not specified, all columns in the input DataFrame except
2424 the columns specified by label_cols, sample_weight_col, and passthrough_cols
25- parameters are considered input columns.
26-
25+ parameters are considered input columns. Input columns can also be set after
26+ initialization with the `set_input_cols` method.
27+ """ ,
28+ "label_cols" : """
2729label_cols: Optional[Union[str, List[str]]]
2830 A string or list of strings representing column names that contain labels.
29- This is a required param for estimators, as there is no way to infer these
30- columns. If this parameter is not specified, then object is fitted without
31- labels (like a transformer).
32-
31+ Label columns must be specified with this parameter during initialization
32+ or with the `set_label_cols` method before fitting.
33+ """ ,
34+ "output_cols" : """
3335output_cols: Optional[Union[str, List[str]]]
3436 A string or list of strings representing column names that will store the
3537 output of predict and transform operations. The length of output_cols must
36- match the expected number of output columns from the specific estimator or
38+ match the expected number of output columns from the specific predictor or
3739 transformer class used.
38- If this parameter is not specified, output column names are derived by
39- adding an OUTPUT_ prefix to the label column names. These inferred output
40- column names work for estimator's predict() method, but output_cols must
41- be set explicitly for transformers.
42-
40+ If you omit this parameter, output column names are derived by adding an
41+ OUTPUT_ prefix to the label column names for supervised estimators, or
42+ OUTPUT_<IDX> for unsupervised estimators. These inferred output column names
43+ work for predictors, but output_cols must be set explicitly for transformers.
44+ In general, explicitly specifying output column names is clearer, especially
45+ if you don't specify the input column names.
46+ To transform in place, pass the same names for input_cols and output_cols.
47+ Output columns can also be set after
48+ initialization with the `set_output_cols` method.
49+ """ ,
50+ "sample_weight_col" : """
4351sample_weight_col: Optional[str]
4452 A string representing the column name containing the sample weights.
45- This argument is only required when working with weighted datasets.
46-
53+ This argument is only required when working with weighted datasets. Sample
54+ weight column can also be set after initialization with the
55+ `set_sample_weight_col` method.
56+ """ ,
57+ "passthrough_cols" : """
4758passthrough_cols: Optional[Union[str, List[str]]]
4859 A string or a list of strings indicating column names to be excluded from any
4960 operations (such as train, transform, or inference). These specified column(s)
5061 will remain untouched throughout the process. This option is helpful in scenarios
5162 requiring automatic input_cols inference while avoiding the use of specific
52- columns, like index columns, during training or inference.
53-
63+ columns, like index columns, during training or inference. Passthrough columns
64+ can also be set after initialization with the `set_passthrough_cols` method.
65+ """ ,
66+ "drop_input_cols" : """
5467drop_input_cols: Optional[bool], default=False
5568 If set, the response of predict(), transform() methods will not contain input columns.
56- """
69+ """ ,
70+ }
5771
5872ADDITIONAL_METHOD_DESCRIPTION = """
5973Raises:
@@ -448,7 +462,6 @@ class WrapperGeneratorBase:
448462 is contained in.
449463 estimator_imports GENERATED Imports needed for the estimator / fit()
450464 call.
451- wrapper_provider_class GENERATED Class name of wrapper provider.
452465 ------------------------------------------------------------------------------------
453466 SIGNATURES AND ARGUMENTS
454467 ------------------------------------------------------------------------------------
@@ -545,7 +558,6 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:
545558 self .estimator_imports = ""
546559 self .estimator_imports_list : List [str ] = []
547560 self .score_sproc_imports : List [str ] = []
548- self .wrapper_provider_class = ""
549561 self .additional_import_statements = ""
550562
551563 # Test strings
@@ -630,10 +642,11 @@ def _populate_class_doc_fields(self) -> None:
630642 class_docstring = inspect .getdoc (self .class_object [1 ]) or ""
631643 class_docstring = class_docstring .rsplit ("Attributes\n " , 1 )[0 ]
632644
645+ parameters_heading = "Parameters\n ----------\n "
633646 class_description , param_description = (
634- class_docstring .rsplit ("Parameters \n " , 1 )
635- if len (class_docstring .rsplit ("Parameters \n " , 1 )) == 2
636- else (class_docstring , "---------- \n " )
647+ class_docstring .rsplit (parameters_heading , 1 )
648+ if len (class_docstring .rsplit (parameters_heading , 1 )) == 2
649+ else (class_docstring , "" )
637650 )
638651
639652 # Extract the first sentence of the class description
@@ -645,9 +658,11 @@ def _populate_class_doc_fields(self) -> None:
645658 f"]\n ({ self .get_doc_link ()} )"
646659 )
647660
648- # Add SnowML specific param descriptions.
649- param_description = "Parameters\n " + param_description .strip ()
650- param_description += ADDITIONAL_PARAM_DESCRIPTIONS
661+ # Add SnowML specific param descriptions before third party parameters.
662+ snowml_parameters = ""
663+ for d in ADDITIONAL_PARAM_DESCRIPTIONS .values ():
664+ snowml_parameters += d
665+ param_description = f"{ parameters_heading } { snowml_parameters } \n { param_description .strip ()} "
651666
652667 class_docstring = f"{ class_description } \n \n { param_description } "
653668 class_docstring = textwrap .indent (class_docstring , " " ).strip ()
@@ -718,12 +733,23 @@ def _populate_function_names_and_signatures(self) -> None:
718733 for member in inspect .getmembers (self .class_object [1 ]):
719734 if member [0 ] == "__init__" :
720735 self .original_init_signature = inspect .signature (member [1 ])
736+ elif member [0 ] == "fit" :
737+ original_fit_signature = inspect .signature (member [1 ])
738+ if original_fit_signature .parameters ["y" ].default is None :
739+ # The fit does not require labels, so our label_cols argument is optional.
740+ ADDITIONAL_PARAM_DESCRIPTIONS [
741+ "label_cols"
742+ ] = """
743+ label_cols: Optional[Union[str, List[str]]]
744+ This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
745+ """
721746
722747 signature_lines = []
723748 sklearn_init_lines = []
724749 init_member_args = []
725750 has_kwargs = False
726751 sklearn_init_args_dict_list = []
752+
727753 for k , v in self .original_init_signature .parameters .items ():
728754 if k == "self" :
729755 signature_lines .append ("self" )
@@ -855,9 +881,9 @@ def generate(self) -> "WrapperGeneratorBase":
855881 self ._populate_flags ()
856882 self ._populate_class_names ()
857883 self ._populate_import_statements ()
858- self ._populate_class_doc_fields ()
859884 self ._populate_function_doc_fields ()
860885 self ._populate_function_names_and_signatures ()
886+ self ._populate_class_doc_fields ()
861887 self ._populate_file_paths ()
862888 self ._populate_integ_test_fields ()
863889 return self
@@ -876,13 +902,8 @@ def generate(self) -> "SklearnWrapperGenerator":
876902 # Populate all the common values
877903 super ().generate ()
878904
879- is_model_selector = WrapperGeneratorFactory ._is_class_of_type (self .class_object [1 ], "BaseSearchCV" )
880-
881905 # Populate SKLearn specific values
882906 self .estimator_imports_list .extend (["import sklearn" , f"import { self .root_module_name } " ])
883- self .wrapper_provider_class = (
884- "SklearnModelSelectionWrapperProvider" if is_model_selector else "SklearnWrapperProvider"
885- )
886907 self .score_sproc_imports = ["sklearn" ]
887908
888909 if "random_state" in self .original_init_signature .parameters .keys ():
@@ -982,6 +1003,9 @@ def generate(self) -> "SklearnWrapperGenerator":
9821003 if self ._is_hist_gradient_boosting_regressor :
9831004 self .test_estimator_input_args_list .extend (["min_samples_leaf=1" , "max_leaf_nodes=100" ])
9841005
1006+ self .deps = (
1007+ "f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'"
1008+ )
9851009 self .supported_export_method = "to_sklearn"
9861010 self .unsupported_export_methods = ["to_xgboost" , "to_lightgbm" ]
9871011 self ._construct_string_from_lists ()
@@ -1010,10 +1034,10 @@ def generate(self) -> "XGBoostWrapperGenerator":
10101034 ["random_state=0" , "subsample=1.0" , "colsample_bynode=1.0" , "n_jobs=1" ]
10111035 )
10121036 self .score_sproc_imports = ["xgboost" ]
1013- self .wrapper_provider_class = "XGBoostWrapperProvider"
10141037 # TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
10151038 self .supported_export_method = "to_xgboost"
10161039 self .unsupported_export_methods = ["to_sklearn" , "to_lightgbm" ]
1040+ self .deps = "f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'"
10171041 self ._construct_string_from_lists ()
10181042 return self
10191043
@@ -1039,8 +1063,8 @@ def generate(self) -> "LightGBMWrapperGenerator":
10391063 self .estimator_imports_list .append ("import lightgbm" )
10401064 self .test_estimator_input_args_list .extend (["random_state=0" , "n_jobs=1" ])
10411065 self .score_sproc_imports = ["lightgbm" ]
1042- self .wrapper_provider_class = "LightGBMWrapperProvider"
10431066
1067+ self .deps = "f'numpy=={np.__version__}', f'lightgbm=={lightgbm.__version__}', f'cloudpickle=={cp.__version__}'"
10441068 self .supported_export_method = "to_lightgbm"
10451069 self .unsupported_export_methods = ["to_sklearn" , "to_xgboost" ]
10461070 self ._construct_string_from_lists ()
0 commit comments