Skip to content

Commit e8b4acf

Browse files
committed
fix tests, calculate features and X_prediction
1 parent 38a68bf commit e8b4acf

File tree

5 files changed

+80
-67
lines changed

5 files changed

+80
-67
lines changed

ml_garden/core/steps/calculate_features.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,16 @@ def __init__(
6565
f" features: {list(self.feature_extractors.keys())}"
6666
)
6767

68-
def _convert_column_to_datetime(self, df: pd.DataFrame, column: str) -> pd.DataFrame:
68+
def _convert_column_to_datetime(self, df: pd.DataFrame, column: str, log: bool) -> pd.DataFrame:
6969
"""Convert a column to datetime.
7070
Parameters
7171
----------
7272
df : pd.DataFrame
7373
The DataFrame containing the column to convert
7474
column : str
7575
The name of the column to convert
76+
log: bool
77+
If True, logs information.
7678
Returns
7779
-------
7880
pd.DataFrame
@@ -85,14 +87,15 @@ def _convert_column_to_datetime(self, df: pd.DataFrame, column: str) -> pd.DataF
8587
df[column],
8688
errors="raise",
8789
)
88-
self.logger.info(f"Column '{column}' automatically converted to datetime.")
90+
if log:
91+
self.logger.info(f"Column '{column}' automatically converted to datetime.")
8992
except ValueError as e:
9093
self.logger.error(f"Error converting column '{column}' to datetime: {e}")
9194
except Exception as e:
9295
self.logger.error(f"Unexpected error converting column '{column}' to datetime: {e}")
9396
else:
94-
self.logger.debug(f"Column '{column}' is already a datetime type.")
95-
97+
if log:
98+
self.logger.debug(f"Column '{column}' is already a datetime type.")
9699
return df
97100

98101
def _extract_feature(self, df: pd.DataFrame, column: str, feature: str) -> None:
@@ -122,6 +125,14 @@ def _extract_feature(self, df: pd.DataFrame, column: str, feature: str) -> None:
122125
)
123126
raise ValueError(error_message)
124127

128+
def _drop_datetime_columns(self, df: pd.DataFrame, log: bool) -> pd.DataFrame:
129+
"""Drop the datetime columns from the `df`."""
130+
if self.datetime_columns:
131+
if log:
132+
self.logger.info(f"Dropping original datetime columns: {self.datetime_columns}")
133+
return df.drop(columns=self.datetime_columns)
134+
return df
135+
125136
def execute(self, data: DataContainer) -> DataContainer:
126137
"""Execute the step.
127138
Parameters
@@ -135,21 +146,18 @@ def execute(self, data: DataContainer) -> DataContainer:
135146
"""
136147
self.logger.info("Calculating features")
137148

138-
if not data.is_train:
139-
data.flow = self._create_datetime_features(data.flow, log=True)
149+
datasets = [
150+
("X_prediction", data.X_prediction, True),
151+
("X_train", data.X_train, True),
152+
("X_validation", data.X_validation, False),
153+
("X_test", data.X_test, False),
154+
]
140155

141-
if data.train is not None:
142-
data.train = self._create_datetime_features(data.train, log=True)
143-
144-
if data.validation is not None:
145-
data.validation = self._create_datetime_features(data.validation)
146-
147-
if data.test is not None:
148-
data.test = self._create_datetime_features(data.test)
149-
150-
## add datetime columns to ignore columns for training
151-
if self.datetime_columns:
152-
data.columns_to_ignore_for_training.extend(self.datetime_columns)
156+
for attr_name, dataset, should_log in datasets:
157+
if dataset is not None:
158+
dataset = self._create_datetime_features(dataset, log=should_log)
159+
dataset = self._drop_datetime_columns(dataset, log=should_log)
160+
setattr(data, attr_name, dataset)
153161

154162
return data
155163

@@ -173,7 +181,7 @@ def _create_datetime_features(
173181
if self.datetime_columns:
174182
for column in self.datetime_columns:
175183
if column in df.columns:
176-
df = self._convert_column_to_datetime(df, column)
184+
df = self._convert_column_to_datetime(df, column, log)
177185

178186
if self.features:
179187
for feature in self.features:
@@ -191,4 +199,7 @@ def _create_datetime_features(
191199
if log:
192200
self.logger.warning("No datetime columns specified. Skipping feature extraction.")
193201

202+
if log:
203+
self.logger.info(f"Created new features: {self.features}")
204+
194205
return df

ml_garden/core/steps/encode.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,9 @@ def execute(self, data: DataContainer) -> DataContainer:
5252
self.logger.info("Encoding data")
5353

5454
if not data.is_train:
55-
categorical_features, numeric_features = self._get_feature_types(
56-
data.flow.drop(columns=data.columns_to_ignore_for_training)
57-
)
55+
categorical_features, numeric_features = self._get_feature_types(data.X_prediction)
5856
data.X_prediction, _, _ = self._apply_encoding(
59-
X=data.flow,
57+
X=data.X_prediction,
6058
y=None,
6159
categorical_features=categorical_features,
6260
numeric_features=numeric_features,
@@ -66,8 +64,6 @@ def execute(self, data: DataContainer) -> DataContainer:
6664
return data
6765

6866
categorical_features, numeric_features = self._get_feature_types(data.X_train)
69-
self.logger.info(f"New categorical features: {categorical_features}")
70-
self.logger.info(f"New numeric features: {numeric_features}")
7167

7268
data.X_train, data.y_train, data._encoder = self._apply_encoding(
7369
X=data.X_train,

ml_garden/core/steps/fit_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,6 @@ def predict(self, data: DataContainer) -> DataContainer:
298298
The updated data container
299299
"""
300300
self.logger.info(f"Predicting with {self.model_class.__name__} model")
301-
data.X_prediction = data.flow.drop(columns=data.columns_to_ignore_for_training)
302301
data.flow[data.prediction_column] = data.model.predict(data.X_prediction)
303302
data.predictions = data.flow[data.prediction_column]
304303
return data

ml_garden/core/steps/tabular_split.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class TabularSplitStep(PipelineStep):
1414
"""Split the data."""
1515

16-
used_for_prediction = False
16+
used_for_prediction = True
1717
used_for_training = True
1818

1919
def __init__(
@@ -129,6 +129,13 @@ def execute(self, data: DataContainer) -> DataContainer:
129129
130130
Where df is the DataFrame used as input to the SplitStep
131131
"""
132+
if not data.is_train:
133+
data.X_prediction = data.flow
134+
if data.columns_to_ignore_for_training:
135+
data.X_prediction = data.X_prediction.drop(
136+
columns=data.columns_to_ignore_for_training
137+
)
138+
return data
132139

133140
self.logger.info("Splitting tabular data...")
134141
df = data.flow

tests/core/steps/test_calculate_features.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def input_data() -> pd.DataFrame:
6262
def data(input_data: pd.DataFrame) -> DataContainer:
6363
data = DataContainer({"is_train": True})
6464
data.columns_to_ignore_for_training = []
65-
data.train = input_data
65+
data.X_train = input_data
6666
return data
6767

6868

@@ -72,7 +72,7 @@ def test_skipping_with_no_parameters(data: DataContainer):
7272
result = calculate_features_step.execute(data)
7373

7474
assert isinstance(result, DataContainer)
75-
assert result.train.equals(data.train)
75+
assert result.X_train.equals(data.X_train)
7676

7777

7878
def test_feature_names(data: DataContainer):
@@ -87,22 +87,22 @@ def test_feature_names(data: DataContainer):
8787
result = calculate_features_step.execute(data)
8888

8989
assert isinstance(result, DataContainer)
90-
assert "creation_date_year" in result.train.columns
91-
assert "creation_date_month" in result.train.columns
92-
assert "creation_date_day" in result.train.columns
93-
assert "creation_date_hour" in result.train.columns
94-
assert "creation_date_minute" in result.train.columns
95-
assert "creation_date_second" in result.train.columns
96-
assert "creation_date_weekday" in result.train.columns
97-
assert "creation_date_dayofyear" in result.train.columns
98-
assert "deletion_date_year" in result.train.columns
99-
assert "deletion_date_month" in result.train.columns
100-
assert "deletion_date_day" in result.train.columns
101-
assert "deletion_date_hour" in result.train.columns
102-
assert "deletion_date_minute" in result.train.columns
103-
assert "deletion_date_second" in result.train.columns
104-
assert "deletion_date_weekday" in result.train.columns
105-
assert "deletion_date_dayofyear" in result.train.columns
90+
assert "creation_date_year" in result.X_train.columns
91+
assert "creation_date_month" in result.X_train.columns
92+
assert "creation_date_day" in result.X_train.columns
93+
assert "creation_date_hour" in result.X_train.columns
94+
assert "creation_date_minute" in result.X_train.columns
95+
assert "creation_date_second" in result.X_train.columns
96+
assert "creation_date_weekday" in result.X_train.columns
97+
assert "creation_date_dayofyear" in result.X_train.columns
98+
assert "deletion_date_year" in result.X_train.columns
99+
assert "deletion_date_month" in result.X_train.columns
100+
assert "deletion_date_day" in result.X_train.columns
101+
assert "deletion_date_hour" in result.X_train.columns
102+
assert "deletion_date_minute" in result.X_train.columns
103+
assert "deletion_date_second" in result.X_train.columns
104+
assert "deletion_date_weekday" in result.X_train.columns
105+
assert "deletion_date_dayofyear" in result.X_train.columns
106106

107107

108108
def test_date_columns_are_ignored_for_training(data: DataContainer):
@@ -117,8 +117,8 @@ def test_date_columns_are_ignored_for_training(data: DataContainer):
117117
result = calculate_features_step.execute(data)
118118

119119
assert isinstance(result, DataContainer)
120-
assert "creation_date" in result.columns_to_ignore_for_training
121-
assert "deletion_date" in result.columns_to_ignore_for_training
120+
assert "creation_date" not in result.X_train.columns
121+
assert "deletion_date" not in result.X_train.columns
122122

123123

124124
def test_output_dtypes(data: DataContainer):
@@ -133,14 +133,14 @@ def test_output_dtypes(data: DataContainer):
133133
result = calculate_features_step.execute(data)
134134

135135
assert isinstance(result, DataContainer)
136-
assert result.train["creation_date_year"].dtype == np.dtype("uint16")
137-
assert result.train["creation_date_month"].dtype == np.dtype("uint8")
138-
assert result.train["creation_date_day"].dtype == np.dtype("uint8")
139-
assert result.train["creation_date_hour"].dtype == np.dtype("uint8")
140-
assert result.train["creation_date_minute"].dtype == np.dtype("uint8")
141-
assert result.train["creation_date_second"].dtype == np.dtype("uint8")
142-
assert result.train["creation_date_weekday"].dtype == np.dtype("uint8")
143-
assert result.train["creation_date_dayofyear"].dtype == np.dtype("uint16")
136+
assert result.X_train["creation_date_year"].dtype == np.dtype("uint16")
137+
assert result.X_train["creation_date_month"].dtype == np.dtype("uint8")
138+
assert result.X_train["creation_date_day"].dtype == np.dtype("uint8")
139+
assert result.X_train["creation_date_hour"].dtype == np.dtype("uint8")
140+
assert result.X_train["creation_date_minute"].dtype == np.dtype("uint8")
141+
assert result.X_train["creation_date_second"].dtype == np.dtype("uint8")
142+
assert result.X_train["creation_date_weekday"].dtype == np.dtype("uint8")
143+
assert result.X_train["creation_date_dayofyear"].dtype == np.dtype("uint16")
144144

145145

146146
def test_output_values(data: DataContainer):
@@ -155,28 +155,28 @@ def test_output_values(data: DataContainer):
155155
result = calculate_features_step.execute(data)
156156

157157
assert isinstance(result, DataContainer)
158-
assert result.train["creation_date_year"].equals(
158+
assert result.X_train["creation_date_year"].equals(
159159
pd.Series([2023, 2023, 2023, 2023, 2023, 2023, 2024, 2024], dtype="uint16")
160160
)
161-
assert result.train["creation_date_month"].equals(
161+
assert result.X_train["creation_date_month"].equals(
162162
pd.Series([1, 1, 1, 1, 1, 11, 2, 3], dtype="uint8")
163163
)
164-
assert result.train["creation_date_day"].equals(
164+
assert result.X_train["creation_date_day"].equals(
165165
pd.Series([1, 2, 3, 4, 5, 1, 28, 28], dtype="uint8")
166166
)
167-
assert result.train["creation_date_hour"].equals(
167+
assert result.X_train["creation_date_hour"].equals(
168168
pd.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="uint8")
169169
)
170-
assert result.train["creation_date_minute"].equals(
170+
assert result.X_train["creation_date_minute"].equals(
171171
pd.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="uint8")
172172
)
173-
assert result.train["creation_date_second"].equals(
173+
assert result.X_train["creation_date_second"].equals(
174174
pd.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="uint8")
175175
)
176-
assert result.train["creation_date_weekday"].equals(
176+
assert result.X_train["creation_date_weekday"].equals(
177177
pd.Series([6, 0, 1, 2, 3, 2, 2, 3], dtype="uint8")
178178
)
179-
assert result.train["creation_date_dayofyear"].equals(
179+
assert result.X_train["creation_date_dayofyear"].equals(
180180
pd.Series([1, 2, 3, 4, 5, 305, 59, 88], dtype="uint16")
181181
)
182182

@@ -214,7 +214,7 @@ def test_init_with_unsupported_features():
214214

215215
def test_execute_with_prediction(data: DataContainer):
216216
data.is_train = False
217-
data.flow = data.train.copy()
217+
data.X_prediction = data.X_train.copy()
218218

219219
datetime_columns = ["creation_date"]
220220
features = ["year", "month", "day"]
@@ -226,6 +226,6 @@ def test_execute_with_prediction(data: DataContainer):
226226
result = calculate_features_step.execute(data)
227227

228228
assert isinstance(result, DataContainer)
229-
assert "creation_date_year" in result.flow.columns
230-
assert "creation_date_month" in result.flow.columns
231-
assert "creation_date_day" in result.flow.columns
229+
assert "creation_date_year" in result.X_prediction.columns
230+
assert "creation_date_month" in result.X_prediction.columns
231+
assert "creation_date_day" in result.X_prediction.columns

0 commit comments

Comments
 (0)