@@ -369,17 +369,11 @@ class WrapperGeneratorBase:
369369
370370 original_class_name INFERRED Class name for the given scikit-learn
371371 estimator.
372- estimator_class_name GENERATED Name for the new estimator class.
373- transformer_class_name GENERATED [TODO] Name for the new transformer
374- class.
375372 module_name INFERRED Name of the module that given class is
376373 is contained in.
377374 estimator_imports GENERATED Imports needed for the estimator / fit()
378375 call.
379376 fit_sproc_imports GENERATED Imports needed for the fit sproc call.
380- transform_function_name INFERRED Name for the transformer function. This
381- will be one of "transform" or
382- "predict()" depending on the class.
383377 ------------------------------------------------------------------------------------
384378 SIGNATURES AND ARGUMENTS
385379 ------------------------------------------------------------------------------------
@@ -444,9 +438,6 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:
444438
445439 # Naming of the class.
446440 self .original_class_name = ""
447- self .estimator_class_name = ""
448- self .transformer_class_name = ""
449- self .transform_function_name = ""
450441
451442 # The signature and argument passing the __init__ functions.
452443 self .original_init_signature = inspect .Signature ()
@@ -456,33 +447,32 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:
456447 self .sklearn_init_args_dict = ""
457448 self .estimator_init_member_args = ""
458449
450+ # Doc strings
459451 self .original_class_docstring = ""
460452 self .estimator_class_docstring = ""
461453 self .transformer_class_docstring = ""
462-
463- self .estimator_imports = ""
464- self .estimator_imports_list : List [str ] = []
465-
466454 self .original_fit_docstring = ""
467455 self .fit_docstring = ""
468456 self .original_transform_docstring = ""
469457 self .transform_docstring = ""
470458
459+ # Import strings
460+ self .estimator_imports = ""
461+ self .estimator_imports_list : List [str ] = []
462+ self .additional_import_statements = ""
463+
464+ # Test strings
471465 self .test_dataset_func = ""
472466 self .test_estimator_input_args = ""
473467 self .test_estimator_input_args_list : List [str ] = []
474468 self .test_class_name = ""
475469 self .test_estimator_imports = ""
476470 self .test_estimator_imports_list : List [str ] = []
477471
478- self .additional_import_statements = ""
479-
472+ # Dependencies
480473 self .predict_udf_deps = ""
481474 self .fit_sproc_deps = ""
482475
483- # TODO(amauser): Make fit a no-op if there is no internal state
484- # TODO(amauser): handling sparse input and output (LabelBinarizer)
485-
486476 def _format_default_value (self , default_value : Any ) -> str :
487477 if isinstance (default_value , str ):
488478 return f'"{ default_value } "'
@@ -561,26 +551,13 @@ def split_long_lines(line: str) -> str:
561551 self .estimator_class_docstring = class_docstring
562552
563553 def _populate_class_names (self ) -> None :
564- # TODO(snandamuri): All the 3 fields have exact same value. Do we really need these
565- # 3 separate fields?
566554 self .original_class_name = self .class_object [0 ]
567- self .estimator_class_name = self .original_class_name
568- self .transformer_class_name = self .estimator_class_name
569-
570555 self .test_class_name = f"{ self .original_class_name } Test"
571556
572557 def _populate_function_names_and_signatures (self ) -> None :
573558 for member in inspect .getmembers (self .class_object [1 ]):
574559 if member [0 ] == "__init__" :
575560 self .original_init_signature = inspect .signature (member [1 ])
576- elif member [0 ] == "predict" or member [0 ] == "transform" :
577- if self .transform_function_name != "" :
578- print ("ERROR: Class has both transform() and predict() methods." )
579- # TODO(snandamuri): Add support for both transform() and predict() methods in estimators.
580- # For now, resolve to predict() method when both predict() and transform() are available.
581- self .transform_function_name = "predict"
582- else :
583- self .transform_function_name = member [0 ]
584561
585562 signature_lines = []
586563 sklearn_init_lines = []
@@ -642,6 +619,7 @@ def _populate_function_names_and_signatures(self) -> None:
642619 self .estimator_init_member_args = "\n " .join (init_member_args )
643620 self .estimator_args_transform_calls = "\n " .join (arg_transform_calls )
644621
622+ # TODO(snandamuri): Implement type inference for classifiers.
645623 self .udf_datatype = "float" if self ._from_data_py or self ._is_regressor else ""
646624
647625 def _populate_file_paths (self ) -> None :
@@ -825,7 +803,7 @@ def generate(self) -> "SklearnWrapperGenerator":
825803 self .test_estimator_input_args_list .extend (["min_samples_leaf=1" , "max_leaf_nodes=100" ])
826804
827805 self .fit_sproc_deps = self .predict_udf_deps = (
828- "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}',"
806+ "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}', "
829807 "f'xgboost=={xgboost.__version__}', f'joblib=={joblib.__version__}'"
830808 )
831809 self ._construct_string_from_lists ()
@@ -842,7 +820,7 @@ def generate(self) -> "XGBoostWrapperGenerator":
842820 self .test_estimator_input_args_list .extend (["random_state=0" , "subsample=1.0" , "colsample_bynode=1.0" ])
843821 self .fit_sproc_imports = "import xgboost"
844822 self .fit_sproc_deps = self .predict_udf_deps = (
845- "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}',"
823+ "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}', "
846824 "f'joblib=={joblib.__version__}'"
847825 )
848826 self ._construct_string_from_lists ()
@@ -859,7 +837,7 @@ def generate(self) -> "LightGBMWrapperGenerator":
859837 self .test_estimator_input_args_list .extend (["random_state=0" ])
860838 self .fit_sproc_imports = "import lightgbm"
861839 self .fit_sproc_deps = self .predict_udf_deps = (
862- "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}',"
840+ "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}', "
863841 "f'joblib=={joblib.__version__}'"
864842 )
865843 self ._construct_string_from_lists ()
0 commit comments