
Commit 6484bf8

[ADD] Allow users to pass feat types to tabular validator (#441)

* add tests and make get_columns_to_encode in tabular validator
* fix flake and mypy and silly bug
* pass feat types to search function of the api
* add example
* add openml to requirements
* add task ids to populate cache
* add check for feat types
* fix mypy and flake

1 parent dcd2bc5 commit 6484bf8

12 files changed (+359 −34 lines)

autoPyTorch/api/base_task.py

Lines changed: 8 additions & 1 deletion
@@ -307,6 +307,7 @@ def _get_dataset_input_validator(
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
         dataset_compression: Optional[DatasetCompressionSpec] = None,
+        **kwargs: Any
     ) -> Tuple[BaseDataset, BaseInputValidator]:
         """
         Returns an object of a child class of `BaseDataset` and
@@ -353,6 +354,7 @@ def get_dataset(
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
         dataset_compression: Optional[DatasetCompressionSpec] = None,
+        **kwargs: Any
     ) -> BaseDataset:
         """
         Returns an object of a child class of `BaseDataset` according to the current task.
@@ -407,6 +409,10 @@
                 Subsampling takes into account classification labels and stratifies
                 accordingly. We guarantee that at least one occurrence of each
                 label is included in the sampled set.
+            kwargs (Any):
+                can be used to pass task specific dataset arguments. Currently supports
+                passing `feat_types` for tabular tasks which specifies whether a feature is
+                'numerical' or 'categorical'.

         Returns:
             BaseDataset:
@@ -420,7 +426,8 @@
             resampling_strategy=resampling_strategy,
             resampling_strategy_args=resampling_strategy_args,
             dataset_name=dataset_name,
-            dataset_compression=dataset_compression)
+            dataset_compression=dataset_compression,
+            **kwargs)

         return dataset

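In `base_task.py` the change is pure plumbing: `get_dataset` forwards any extra keyword arguments down to `_get_dataset_input_validator`. A minimal sketch of the resulting call, assuming the usual `X_train`/`y_train` arguments; the toy data and column names are illustrative, not from the commit:

```python
# Sketch only: the toy data below is an illustrative assumption.
import pandas as pd

from autoPyTorch.api.tabular_classification import TabularClassificationTask

X = pd.DataFrame({
    'age': [23, 31, 54, 40, 27, 61],
    'city': pd.Series(['NYC', 'LA', 'NYC', 'SF', 'LA', 'SF'], dtype='category'),
})
y = pd.Series([0, 1, 0, 1, 1, 0])

api = TabularClassificationTask()

# `feat_types` travels through **kwargs down to the TabularInputValidator.
dataset = api.get_dataset(
    X_train=X,
    y_train=y,
    feat_types=['numerical', 'categorical'],
)
```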

autoPyTorch/api/tabular_classification.py

Lines changed: 14 additions & 2 deletions
@@ -168,6 +168,7 @@ def _get_dataset_input_validator(
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
         dataset_compression: Optional[DatasetCompressionSpec] = None,
+        **kwargs: Any,
     ) -> Tuple[TabularDataset, TabularInputValidator]:
         """
         Returns an object of `TabularDataset` and an object of
@@ -194,6 +195,9 @@ def _get_dataset_input_validator(
             dataset_compression (Optional[DatasetCompressionSpec]):
                 specifications for dataset compression. For more info check
                 documentation for `BaseTask.get_dataset`.
+            kwargs (Any):
+                Currently for tabular tasks, expects `feat_types (Optional[List[str]])`, which
+                specifies whether a feature is 'numerical' or 'categorical'.

         Returns:
             TabularDataset:
@@ -206,12 +210,14 @@ def _get_dataset_input_validator(
         resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \
             self.resampling_strategy_args

+        feat_types = kwargs.pop('feat_types', None)
         # Create a validator object to make sure that the data provided by
         # the user matches the autopytorch requirements
         input_validator = TabularInputValidator(
             is_classification=True,
             logger_port=self._logger_port,
-            dataset_compression=dataset_compression
+            dataset_compression=dataset_compression,
+            feat_types=feat_types
         )

         # Fit an input validator to check the provided data
@@ -238,6 +244,7 @@ def search(
         X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         dataset_name: Optional[str] = None,
+        feat_types: Optional[List[str]] = None,
         budget_type: str = 'epochs',
         min_budget: int = 5,
         max_budget: int = 50,
@@ -266,6 +273,10 @@ def search(
                 A pair of features (X_train) and targets (y_train) used to fit a
                 pipeline. Additionally, a holdout of these pairs (X_test, y_test) can
                 be provided to track the generalization performance of each stage.
+            feat_types (Optional[List[str]]):
+                Description about the feature types of the columns.
+                Accepts `numerical` for integers, float data and `categorical`
+                for categories, strings and bool. Defaults to None.
             optimize_metric (str):
                 name of the metric that is used to evaluate a pipeline.
             budget_type (str):
@@ -433,7 +444,8 @@ def search(
             resampling_strategy=self.resampling_strategy,
             resampling_strategy_args=self.resampling_strategy_args,
             dataset_name=dataset_name,
-            dataset_compression=self._dataset_compression)
+            dataset_compression=self._dataset_compression,
+            feat_types=feat_types)

         return self._search(
             dataset=self.dataset,
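For end users the visible change is the new `feat_types` argument to `search`. A hedged usage sketch; the dataset, metric, and time budget below are illustrative assumptions, and `TabularRegressionTask.search` (next file) accepts the same keyword:

```python
# Sketch only: data, metric, and budget values are illustrative assumptions.
import numpy as np
import pandas as pd

from autoPyTorch.api.tabular_classification import TabularClassificationTask

rng = np.random.default_rng(0)
X = pd.DataFrame({
    'age': rng.integers(18, 70, size=100),                              # numerical
    'income': rng.random(100) * 1e5,                                    # numerical
    'city': pd.Series(rng.choice(['NYC', 'LA'], size=100),
                      dtype='category'),                                # categorical
})
y = pd.Series(rng.integers(0, 2, size=100))

api = TabularClassificationTask()
api.search(
    X_train=X,
    y_train=y,
    # One entry per column; 'numerical' or 'categorical', case-insensitive.
    feat_types=['numerical', 'numerical', 'categorical'],
    optimize_metric='accuracy',
    total_walltime_limit=300,
)
```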

autoPyTorch/api/tabular_regression.py

Lines changed: 14 additions & 2 deletions
@@ -169,6 +169,7 @@ def _get_dataset_input_validator(
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
         dataset_compression: Optional[DatasetCompressionSpec] = None,
+        **kwargs: Any
     ) -> Tuple[TabularDataset, TabularInputValidator]:
         """
         Returns an object of `TabularDataset` and an object of
@@ -195,6 +196,9 @@ def _get_dataset_input_validator(
             dataset_compression (Optional[DatasetCompressionSpec]):
                 specifications for dataset compression. For more info check
                 documentation for `BaseTask.get_dataset`.
+            kwargs (Any):
+                Currently for tabular tasks, expects `feat_types (Optional[List[str]])`, which
+                specifies whether a feature is 'numerical' or 'categorical'.
         Returns:
             TabularDataset:
                 the dataset object.
@@ -206,12 +210,14 @@ def _get_dataset_input_validator(
         resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \
             self.resampling_strategy_args

+        feat_types = kwargs.pop('feat_types', None)
         # Create a validator object to make sure that the data provided by
         # the user matches the autopytorch requirements
         input_validator = TabularInputValidator(
             is_classification=False,
             logger_port=self._logger_port,
-            dataset_compression=dataset_compression
+            dataset_compression=dataset_compression,
+            feat_types=feat_types
         )

         # Fit an input validator to check the provided data
@@ -238,6 +244,7 @@ def search(
         X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         dataset_name: Optional[str] = None,
+        feat_types: Optional[List[str]] = None,
         budget_type: str = 'epochs',
         min_budget: int = 5,
         max_budget: int = 50,
@@ -266,6 +273,10 @@ def search(
                 A pair of features (X_train) and targets (y_train) used to fit a
                 pipeline. Additionally, a holdout of these pairs (X_test, y_test) can
                 be provided to track the generalization performance of each stage.
+            feat_types (Optional[List[str]]):
+                Description about the feature types of the columns.
+                Accepts `numerical` for integers, float data and `categorical`
+                for categories, strings and bool. Defaults to None.
             optimize_metric (str):
                 Name of the metric that is used to evaluate a pipeline.
             budget_type (str):
@@ -434,7 +445,8 @@ def search(
             resampling_strategy=self.resampling_strategy,
             resampling_strategy_args=self.resampling_strategy_args,
             dataset_name=dataset_name,
-            dataset_compression=self._dataset_compression)
+            dataset_compression=self._dataset_compression,
+            feat_types=feat_types)

         return self._search(
             dataset=self.dataset,

autoPyTorch/data/base_feature_validator.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def __init__(
         logger: Optional[Union[PicklableClientLogger, logging.Logger]] = None,
     ):
         # Register types to detect unsupported data format changes
-        self.feat_type: Optional[List[str]] = None
+        self.feat_types: Optional[List[str]] = None
         self.data_type: Optional[type] = None
         self.dtypes: List[str] = []
         self.column_order: List[str] = []

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 82 additions & 13 deletions
@@ -94,12 +94,18 @@ class TabularFeatureValidator(BaseFeatureValidator):
             List of indices of numerical columns
         categorical_columns (List[int]):
             List of indices of categorical columns
+        feat_types (List[str]):
+            Description about the feature types of the columns.
+            Accepts `numerical` for integers, float data and `categorical`
+            for categories, strings and bool.
     """
     def __init__(
         self,
         logger: Optional[Union[PicklableClientLogger, Logger]] = None,
+        feat_types: Optional[List[str]] = None,
     ):
         super().__init__(logger)
+        self.feat_types = feat_types

     @staticmethod
     def _comparator(cmp1: str, cmp2: str) -> int:
@@ -167,9 +173,9 @@ def _fit(
         if not X.select_dtypes(include='object').empty:
             X = self.infer_objects(X)

-        self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)
+        self.transformed_columns, self.feat_types = self.get_columns_to_encode(X)

-        assert self.feat_type is not None
+        assert self.feat_types is not None

         if len(self.transformed_columns) > 0:
@@ -186,8 +192,8 @@ def _fit(
             # The column transformer reorders the feature types
             # therefore, we need to change the order of columns as well
             # This means categorical columns are shifted to the left
-            self.feat_type = sorted(
-                self.feat_type,
+            self.feat_types = sorted(
+                self.feat_types,
                 key=functools.cmp_to_key(self._comparator)
             )

@@ -201,7 +207,7 @@ def _fit(
                 for cat in encoded_categories
             ]

-            for i, type_ in enumerate(self.feat_type):
+            for i, type_ in enumerate(self.feat_types):
                 if 'numerical' in type_:
                     self.numerical_columns.append(i)
                 else:
@@ -336,7 +342,7 @@ def _check_data(

         # Define the column to be encoded here as the feature validator is fitted once
         # per estimator
-        self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)
+        self.transformed_columns, self.feat_types = self.get_columns_to_encode(X)

         column_order = [column for column in X.columns]
         if len(self.column_order) > 0:
@@ -361,12 +367,72 @@ def _check_data(
         else:
             self.dtypes = dtypes

+    def get_columns_to_encode(
+        self,
+        X: pd.DataFrame
+    ) -> Tuple[List[str], List[str]]:
+        """
+        Return the columns to be transformed as well as
+        the type of feature for each column.
+
+        The returned values are dependent on `feat_types` passed to the `__init__`.
+
+        Args:
+            X (pd.DataFrame)
+                A set of features that are going to be validated (type and dimensionality
+                checks) and an encoder fitted in the case the data needs encoding
+
+        Returns:
+            transformed_columns (List[str]):
+                Columns to encode, if any
+            feat_type:
+                Type of each column numerical/categorical
+        """
+        transformed_columns, feat_types = self._get_columns_to_encode(X)
+        if self.feat_types is not None:
+            self._validate_feat_types(X)
+            transformed_columns = [X.columns[i] for i, col in enumerate(self.feat_types)
+                                   if col.lower() == 'categorical']
+            return transformed_columns, self.feat_types
+        else:
+            return transformed_columns, feat_types
+
+    def _validate_feat_types(self, X: pd.DataFrame) -> None:
+        """
+        Checks if the passed `feat_types` is compatible with what
+        AutoPyTorch expects, i.e. it should only contain `numerical`
+        or `categorical`, and the number of feature types should equal
+        the number of features. The case does not matter.
+
+        Args:
+            X (pd.DataFrame):
+                input features set
+
+        Raises:
+            ValueError:
+                if the number of feat_types is not equal to the number of features,
+                or if a feature type is not one of "numerical", "categorical"
+        """
+        assert self.feat_types is not None  # mypy check
+
+        if len(self.feat_types) != len(X.columns):
+            raise ValueError(f"Expected number of `feat_types`: {len(self.feat_types)}"
+                             f" to be the same as the number of features {len(X.columns)}")
+        for feat_type in set(self.feat_types):
+            if feat_type.lower() not in ['numerical', 'categorical']:
+                raise ValueError(f"Expected type of features to be in `['numerical', "
+                                 f"'categorical']`, but got {feat_type}")
+
     def _get_columns_to_encode(
         self,
         X: pd.DataFrame,
     ) -> Tuple[List[str], List[str]]:
         """
-        Return the columns to be encoded from a pandas dataframe
+        Return the columns to be transformed as well as
+        the type of feature for each column from a pandas dataframe.
+
+        If `self.feat_types` is not None, it also validates that the
+        dataframe dtypes don't disagree with the ones passed in `__init__`.

         Args:
             X (pd.DataFrame)
@@ -380,21 +446,24 @@ def _get_columns_to_encode(
                 Type of each column numerical/categorical
         """

-        if len(self.transformed_columns) > 0 and self.feat_type is not None:
-            return self.transformed_columns, self.feat_type
+        if len(self.transformed_columns) > 0 and self.feat_types is not None:
+            return self.transformed_columns, self.feat_types

         # Register if a column needs encoding
         transformed_columns = []

         # Also, register the feature types for the estimator
-        feat_type = []
+        feat_types = []

         # Make sure each column is a valid type
        for i, column in enumerate(X.columns):
             if X[column].dtype.name in ['category', 'bool']:

                 transformed_columns.append(column)
-                feat_type.append('categorical')
+                if self.feat_types is not None and self.feat_types[i].lower() == 'numerical':
+                    raise ValueError(f"Passed numerical as the feature type for column: {column} "
+                                     f"but the column is categorical")
+                feat_types.append('categorical')
             # Move away from np.issubdtype as it causes
             # TypeError: data type not understood in certain pandas types
             elif not is_numeric_dtype(X[column]):
@@ -434,8 +503,8 @@ def _get_columns_to_encode(
                     )
                 )
             else:
-                feat_type.append('numerical')
-        return transformed_columns, feat_type
+                feat_types.append('numerical')
+        return transformed_columns, feat_types

     def list_to_dataframe(
         self,
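The interplay of the new `get_columns_to_encode`, `_validate_feat_types`, and the dtype cross-check inside `_get_columns_to_encode` can be exercised directly. A small behavioral sketch, assuming a two-column toy dataframe (not from the commit); `fit` here is the pre-existing `BaseFeatureValidator.fit`:

```python
# Sketch only: the toy dataframe is an illustrative assumption.
import pandas as pd

from autoPyTorch.data.tabular_feature_validator import TabularFeatureValidator

X = pd.DataFrame({
    'age': [23, 31, 54],
    'color': pd.Series(['red', 'blue', 'red'], dtype='category'),
})

# Declared feature types override pure dtype inference.
validator = TabularFeatureValidator(feat_types=['numerical', 'categorical'])
validator.fit(X)

# Declaring a 'category'-dtype column as numerical contradicts the data,
# so _get_columns_to_encode raises a ValueError.
bad = TabularFeatureValidator(feat_types=['numerical', 'numerical'])
try:
    bad.fit(X)
except ValueError as e:
    print(e)
```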

autoPyTorch/data/tabular_validator.py

Lines changed: 9 additions & 2 deletions
@@ -1,6 +1,6 @@
 # -*- encoding: utf-8 -*-
 import logging
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np

@@ -41,18 +41,24 @@ class TabularInputValidator(BaseInputValidator):
         dataset_compression (Optional[DatasetCompressionSpec]):
             specifications for dataset compression. For more info check
             documentation for `BaseTask.get_dataset`.
+        feat_types (List[str]):
+            Description about the feature types of the columns.
+            Accepts `numerical` for integers, float data and `categorical`
+            for categories, strings and bool
     """
     def __init__(
         self,
         is_classification: bool = False,
         logger_port: Optional[int] = None,
         dataset_compression: Optional[DatasetCompressionSpec] = None,
+        feat_types: Optional[List[str]] = None,
         seed: int = 42,
     ):
         self.dataset_compression = dataset_compression
         self._reduced_dtype: Optional[DatasetDTypeContainerType] = None
         self.is_classification = is_classification
         self.logger_port = logger_port
+        self.feat_types = feat_types
         self.seed = seed
         if self.logger_port is not None:
             self.logger: Union[logging.Logger, PicklableClientLogger] = get_named_client_logger(
@@ -63,7 +69,8 @@ def __init__(
             self.logger = logging.getLogger('Validation')

         self.feature_validator = TabularFeatureValidator(
-            logger=self.logger)
+            logger=self.logger,
+            feat_types=self.feat_types)
         self.target_validator = TabularTargetValidator(
             is_classification=self.is_classification,
             logger=self.logger

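Finally, `TabularInputValidator` just stores `feat_types` and hands it to the feature validator it builds. A sketch of driving the validator stack directly, with an illustrative toy dataset; `fit` and `transform` are part of the pre-existing `BaseInputValidator` API:

```python
# Sketch only: the toy dataset is an illustrative assumption.
import pandas as pd

from autoPyTorch.data.tabular_validator import TabularInputValidator

X = pd.DataFrame({
    'f0': [1.0, 2.5, 3.1, 0.4],
    'f1': pd.Series(['a', 'b', 'a', 'b'], dtype='category'),
})
y = pd.Series([0, 1, 0, 1])

validator = TabularInputValidator(
    is_classification=True,
    # Forwarded verbatim to the underlying TabularFeatureValidator.
    feat_types=['numerical', 'categorical'],
)
validator.fit(X_train=X, y_train=y)

# Once fitted, data can be transformed into the validated representation.
X_t, y_t = validator.transform(X, y)
```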