Skip to content

Commit 2848a42

Browse files
authored
Remove model_missing_values from ClusterBasedNormalizer call (#310)
* Warning fix * Feedback * Up rdt version
1 parent 6373a47 commit 2848a42

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

ctgan/data_transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def _fit_continuous(self, data):
4646
A ``ColumnTransformInfo`` object.
4747
"""
4848
column_name = data.columns[0]
49-
gm = ClusterBasedNormalizer(model_missing_values=True, max_clusters=min(len(data), 10))
49+
gm = ClusterBasedNormalizer(
50+
missing_value_generation='from_column', max_clusters=min(len(data), 10))
5051
gm.fit(data, column_name)
5152
num_components = sum(gm.valid_component_indicator)
5253

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"torch>=1.8.0;python_version<'3.10'",
2222
"torch>=1.11.0;python_version>='3.10' and python_version<'3.11'",
2323
"torch>=2.0.0;python_version>='3.11'",
24-
'rdt>=1.3.0,<2.0',
24+
'rdt>=1.6.1,<2.0',
2525
]
2626

2727
setup_requires = [

tests/unit/test_data_transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def test__fit_continuous_max_clusters(self, MockCBN):
7575
transformer._fit_continuous(data)
7676

7777
# Assert
78-
MockCBN.assert_called_once_with(model_missing_values=True, max_clusters=len(data))
78+
MockCBN.assert_called_once_with(
79+
missing_value_generation='from_column', max_clusters=len(data))
7980

8081
@patch('ctgan.data_transformer.OneHotEncoder')
8182
def test___fit_discrete(self, MockOHE):

0 commit comments

Comments
 (0)