Skip to content

Commit f189c7a

Browse files
authored
Merge pull request #165 from kmyusk/pmml_cat
PMML export enhancements and 3.8-3.6 compatibility of rule induction code
2 parents 5f9b1eb + 511b1de commit f189c7a

File tree

8 files changed

+499
-69
lines changed

8 files changed

+499
-69
lines changed

aix360/algorithms/rule_induction/trxf/pmml_export/models/data_dictionary.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,17 @@ class OpType(enum.Enum):
2323
continuous = 2
2424

2525

26+
@dataclass(frozen=True)
27+
class Value:
28+
value: str = field()
29+
30+
2631
@dataclass(frozen=True)
2732
class DataField:
2833
name: str = field()
2934
optype: OpType = field()
3035
dataType: DataType = field()
36+
values: typing.Optional[typing.List[Value]] = field(default=None)
3137

3238

3339
@dataclass(frozen=True)

aix360/algorithms/rule_induction/trxf/pmml_export/reader/trxf_reader.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from typing import Dict
2+
13
import numpy as np
24
import pandas as pd
35

46
from aix360.algorithms.rule_induction.trxf.classifier import ruleset_classifier
57
from aix360.algorithms.rule_induction.trxf.classifier.ruleset_classifier import RuleSetClassifier
68
from aix360.algorithms.rule_induction.trxf.core import Conjunction, Relation
79
from aix360.algorithms.rule_induction.trxf.pmml_export import models
10+
from aix360.algorithms.rule_induction.trxf.pmml_export.models.data_dictionary import Value
811
from aix360.algorithms.rule_induction.trxf.pmml_export.reader import AbstractReader
912
from aix360.algorithms.rule_induction.trxf.pmml_export.models import SimplePredicate, Operator, CompoundPredicate, \
1013
BooleanOperator
@@ -38,13 +41,17 @@ def read(self, trxf_classifier: RuleSetClassifier) -> models.SimplePMMLRuleSetMo
3841
assert self._data_dictionary is not None
3942
return models.SimplePMMLRuleSetModel(dataDictionary=self._data_dictionary, ruleSetModel=rule_set_model)
4043

41-
def load_data_dictionary(self, X: pd.DataFrame):
44+
def load_data_dictionary(self, X: pd.DataFrame, values: Dict = None):
4245
"""
4346
Extract the data dictionary from a feature dataframe, and store it
47+
48+
@param X: Input dataframe
49+
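@param values: Optional dict mapping a column name to the list of its possible categorical values. For a categorical column that is missing from the dict (or when no dict is given), the values are inferred from the unique entries of that column in X.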
4450
"""
4551
dtypes = X.dtypes
4652
data_fields = []
4753
for index, value in dtypes.items():
54+
vals = None
4855
if np.issubdtype(value, np.integer):
4956
data_type = models.DataType.integer
5057
op_type = models.OpType.ordinal
@@ -60,7 +67,9 @@ def load_data_dictionary(self, X: pd.DataFrame):
6067
else:
6168
data_type = models.DataType.string
6269
op_type = models.OpType.categorical
63-
data_fields.append(models.DataField(name=str(index), optype=op_type, dataType=data_type))
70+
vals = values[index] if values is not None and index in values else list(X[index].unique())
71+
wrapped_vals = list(map(lambda v: Value(v), vals)) if vals is not None else vals
72+
data_fields.append(models.DataField(name=str(index), optype=op_type, dataType=data_type, values=wrapped_vals))
6473
self._data_dictionary = models.DataDictionary(data_fields)
6574

6675

aix360/algorithms/rule_induction/trxf/pmml_export/serializer/nyoka_serializer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def _nyoka_data_dictionary(self, data_dictionary: models.DataDictionary) -> nyok
4141
return nyoka_pmml.DataDictionary(
4242
numberOfFields=0 if data_dictionary.dataFields is None else len(data_dictionary.dataFields),
4343
DataField=None if data_dictionary.dataFields is None else [
44-
nyoka_pmml.DataField(name=f.name, optype=f.optype.name, dataType=f.dataType.name)
44+
nyoka_pmml.DataField(name=f.name, optype=f.optype.name, dataType=f.dataType.name,
45+
Value=None if f.values is None else [
46+
nyoka_pmml.Value(value=v.value) for v in f.values
47+
])
4548
for f in data_dictionary.dataFields])
4649

4750
def _nyoka_rule_set_model(self, rule_set_model: models.RuleSetModel) -> nyoka_pmml.RuleSetModel:

tests/rule_induction/trxf/pmml_export/resources/adult.pmml

Lines changed: 235 additions & 44 deletions
Large diffs are not rendered by default.
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<PMML xmlns="http://www.dmg.org/PMML-4_4" version="4.4.1">
3+
<Header copyright="Copyright IBM Corp, exported to PMML by Nyoka (c) 2022 Software AG" description="Default description">
4+
<Application name="SimpleRuleSetExport" version="0.0.1"/>
5+
<Timestamp>1970-01-01 00:00:00+00:00</Timestamp>
6+
</Header>
7+
<DataDictionary numberOfFields="14">
8+
<DataField name="age" optype="continuous" dataType="double"/>
9+
<DataField name="workclass" optype="categorical" dataType="string">
10+
<Value value="Private"/>
11+
<Value value="State-gov"/>
12+
<Value value="Self-emp-not-inc"/>
13+
<Value value="Local-gov"/>
14+
<Value value="Federal-gov"/>
15+
<Value value="Self-emp-inc"/>
16+
<Value value="Without-pay"/>
17+
</DataField>
18+
<DataField name="fnlwgt" optype="continuous" dataType="double"/>
19+
<DataField name="education" optype="categorical" dataType="string">
20+
<Value value="HS-grad"/>
21+
<Value value="10th"/>
22+
<Value value="Bachelors"/>
23+
<Value value="Assoc-acdm"/>
24+
<Value value="Some-college"/>
25+
<Value value="Doctorate"/>
26+
<Value value="Prof-school"/>
27+
<Value value="9th"/>
28+
<Value value="Assoc-voc"/>
29+
<Value value="Masters"/>
30+
<Value value="7th-8th"/>
31+
<Value value="11th"/>
32+
<Value value="1st-4th"/>
33+
<Value value="5th-6th"/>
34+
<Value value="12th"/>
35+
<Value value="Preschool"/>
36+
</DataField>
37+
<DataField name="education_num" optype="continuous" dataType="double"/>
38+
<DataField name="marital_status" optype="categorical" dataType="string">
39+
<Value value="Married-civ-spouse"/>
40+
<Value value="Divorced"/>
41+
<Value value="Never-married"/>
42+
<Value value="Widowed"/>
43+
<Value value="Separated"/>
44+
<Value value="Married-spouse-absent"/>
45+
<Value value="Married-AF-spouse"/>
46+
</DataField>
47+
<DataField name="occupation" optype="categorical" dataType="string">
48+
<Value value="Transport-moving"/>
49+
<Value value="Craft-repair"/>
50+
<Value value="Sales"/>
51+
<Value value="Adm-clerical"/>
52+
<Value value="Prof-specialty"/>
53+
<Value value="Other-service"/>
54+
<Value value="Exec-managerial"/>
55+
<Value value="Farming-fishing"/>
56+
<Value value="Machine-op-inspct"/>
57+
<Value value="Handlers-cleaners"/>
58+
<Value value="Protective-serv"/>
59+
<Value value="Tech-support"/>
60+
<Value value="Priv-house-serv"/>
61+
<Value value="Armed-Forces"/>
62+
</DataField>
63+
<DataField name="relationship" optype="categorical" dataType="string">
64+
<Value value="Husband"/>
65+
<Value value="Not-in-family"/>
66+
<Value value="Wife"/>
67+
<Value value="Own-child"/>
68+
<Value value="Unmarried"/>
69+
<Value value="Other-relative"/>
70+
</DataField>
71+
<DataField name="race" optype="categorical" dataType="string">
72+
<Value value="White"/>
73+
<Value value="Black"/>
74+
<Value value="Other"/>
75+
<Value value="Asian-Pac-Islander"/>
76+
<Value value="Amer-Indian-Eskimo"/>
77+
</DataField>
78+
<DataField name="sex" optype="categorical" dataType="string">
79+
<Value value="Male"/>
80+
<Value value="Female"/>
81+
</DataField>
82+
<DataField name="capital_gain" optype="continuous" dataType="double"/>
83+
<DataField name="capital_loss" optype="continuous" dataType="double"/>
84+
<DataField name="hours_per_week" optype="continuous" dataType="double"/>
85+
<DataField name="native_country" optype="categorical" dataType="string">
86+
<Value value="United-States"/>
87+
<Value value="Portugal"/>
88+
<Value value="Cuba"/>
89+
<Value value="Mexico"/>
90+
<Value value="France"/>
91+
<Value value="Jamaica"/>
92+
<Value value="Haiti"/>
93+
<Value value="Honduras"/>
94+
<Value value="India"/>
95+
<Value value="Dominican-Republic"/>
96+
<Value value="Outlying-US(Guam-USVI-etc)"/>
97+
<Value value="South"/>
98+
<Value value="Scotland"/>
99+
<Value value="Italy"/>
100+
<Value value="Germany"/>
101+
<Value value="Philippines"/>
102+
<Value value="Vietnam"/>
103+
<Value value="El-Salvador"/>
104+
<Value value="Nicaragua"/>
105+
<Value value="China"/>
106+
<Value value="Trinadad&amp;Tobago"/>
107+
<Value value="Puerto-Rico"/>
108+
<Value value="Japan"/>
109+
<Value value="Iran"/>
110+
<Value value="Guatemala"/>
111+
<Value value="England"/>
112+
<Value value="Poland"/>
113+
<Value value="Canada"/>
114+
<Value value="Cambodia"/>
115+
<Value value="Greece"/>
116+
<Value value="Thailand"/>
117+
<Value value="Ireland"/>
118+
<Value value="Hong"/>
119+
<Value value="Taiwan"/>
120+
<Value value="Ecuador"/>
121+
<Value value="Peru"/>
122+
<Value value="Yugoslavia"/>
123+
<Value value="Columbia"/>
124+
<Value value="Hungary"/>
125+
<Value value="Laos"/>
126+
<Value value="Holand-Netherlands"/>
127+
</DataField>
128+
</DataDictionary>
129+
<RuleSetModel functionName="classification" algorithmName="RuleSet">
130+
<MiningSchema>
131+
<MiningField name="marital_status" usageType="active"/>
132+
<MiningField name="education_num" usageType="active"/>
133+
<MiningField name="hours_per_week" usageType="active"/>
134+
<MiningField name="education" usageType="active"/>
135+
<MiningField name="age" usageType="active"/>
136+
<MiningField name="occupation" usageType="active"/>
137+
<MiningField name="fnlwgt" usageType="active"/>
138+
<MiningField name="capital_gain" usageType="active"/>
139+
<MiningField name="capital_loss" usageType="active"/>
140+
</MiningSchema>
141+
<RuleSet defaultScore="&lt;=50K">
142+
<RuleSelectionMethod criterion="weightedMax"/>
143+
<SimpleRule id="[marital_status == Married-civ-spouse] ^ [education_num &lt;= 12.0] ^ [education_num &gt;= 8.0] ^ [hours_per_week &gt;= 38.0] ^ [education == Some-college] ^ [age &lt;= 57.0] ^ [hours_per_week &lt;= 45.0] ^ [age &gt;= 48.0]" score="&gt;50K" recordCount="15081" nbCorrect="11329" confidence="0.609271523178808" weight="0.609271523178808">
144+
<CompoundPredicate booleanOperator="and">
145+
<SimplePredicate field="marital_status" operator="equal" value="Married-civ-spouse"/>
146+
<SimplePredicate field="education_num" operator="lessOrEqual" value="12.0"/>
147+
<SimplePredicate field="education_num" operator="greaterOrEqual" value="8.0"/>
148+
<SimplePredicate field="hours_per_week" operator="greaterOrEqual" value="38.0"/>
149+
<SimplePredicate field="education" operator="equal" value="Some-college"/>
150+
<SimplePredicate field="age" operator="lessOrEqual" value="57.0"/>
151+
<SimplePredicate field="hours_per_week" operator="lessOrEqual" value="45.0"/>
152+
<SimplePredicate field="age" operator="greaterOrEqual" value="48.0"/>
153+
</CompoundPredicate>
154+
</SimpleRule>
155+
<SimpleRule id="[marital_status == Married-civ-spouse] ^ [education_num &lt;= 12.0] ^ [age &gt;= 37.0] ^ [occupation == Tech-support]" score="&gt;50K" recordCount="15081" nbCorrect="11323" confidence="0.684931506849315" weight="0.684931506849315">
156+
<CompoundPredicate booleanOperator="and">
157+
<SimplePredicate field="marital_status" operator="equal" value="Married-civ-spouse"/>
158+
<SimplePredicate field="education_num" operator="lessOrEqual" value="12.0"/>
159+
<SimplePredicate field="age" operator="greaterOrEqual" value="37.0"/>
160+
<SimplePredicate field="occupation" operator="equal" value="Tech-support"/>
161+
</CompoundPredicate>
162+
</SimpleRule>
163+
<SimpleRule id="[marital_status == Married-civ-spouse] ^ [education_num &lt;= 12.0] ^ [education_num &gt;= 9.0] ^ [hours_per_week &gt;= 40.0] ^ [age &gt;= 42.0] ^ [education == HS-grad] ^ [age &lt;= 53.0] ^ [fnlwgt &gt;= 154227.0] ^ [fnlwgt &lt;= 163948.0]" score="&gt;50K" recordCount="15081" nbCorrect="11287" confidence="0.3448275862068966" weight="0.3448275862068966">
164+
<CompoundPredicate booleanOperator="and">
165+
<SimplePredicate field="marital_status" operator="equal" value="Married-civ-spouse"/>
166+
<SimplePredicate field="education_num" operator="lessOrEqual" value="12.0"/>
167+
<SimplePredicate field="education_num" operator="greaterOrEqual" value="9.0"/>
168+
<SimplePredicate field="hours_per_week" operator="greaterOrEqual" value="40.0"/>
169+
<SimplePredicate field="age" operator="greaterOrEqual" value="42.0"/>
170+
<SimplePredicate field="education" operator="equal" value="HS-grad"/>
171+
<SimplePredicate field="age" operator="lessOrEqual" value="53.0"/>
172+
<SimplePredicate field="fnlwgt" operator="greaterOrEqual" value="154227.0"/>
173+
<SimplePredicate field="fnlwgt" operator="lessOrEqual" value="163948.0"/>
174+
</CompoundPredicate>
175+
</SimpleRule>
176+
<SimpleRule id="[marital_status == Married-civ-spouse] ^ [age &gt;= 37.0] ^ [education_num &lt;= 12.0] ^ [age &lt;= 47.0] ^ [education_num &gt;= 10.0] ^ [occupation == Sales] ^ [fnlwgt &gt;= 131827.0]" score="&gt;50K" recordCount="15081" nbCorrect="11300" confidence="0.5333333333333333" weight="0.5333333333333333">
177+
<CompoundPredicate booleanOperator="and">
178+
<SimplePredicate field="marital_status" operator="equal" value="Married-civ-spouse"/>
179+
<SimplePredicate field="age" operator="greaterOrEqual" value="37.0"/>
180+
<SimplePredicate field="education_num" operator="lessOrEqual" value="12.0"/>
181+
<SimplePredicate field="age" operator="lessOrEqual" value="47.0"/>
182+
<SimplePredicate field="education_num" operator="greaterOrEqual" value="10.0"/>
183+
<SimplePredicate field="occupation" operator="equal" value="Sales"/>
184+
<SimplePredicate field="fnlwgt" operator="greaterOrEqual" value="131827.0"/>
185+
</CompoundPredicate>
186+
</SimpleRule>
187+
<SimpleRule id="[marital_status == Married-civ-spouse] ^ [education_num &gt;= 13.0]" score="&gt;50K" recordCount="15081" nbCorrect="12283" confidence="0.7355608591885442" weight="0.7355608591885442">
188+
<CompoundPredicate booleanOperator="and">
189+
<SimplePredicate field="marital_status" operator="equal" value="Married-civ-spouse"/>
190+
<SimplePredicate field="education_num" operator="greaterOrEqual" value="13.0"/>
191+
</CompoundPredicate>
192+
</SimpleRule>
193+
<SimpleRule id="[marital_status == Married-civ-spouse] ^ [capital_gain &gt;= 5178.0]" score="&gt;50K" recordCount="15081" nbCorrect="11863" confidence="0.9913344887348353" weight="0.9913344887348353">
194+
<CompoundPredicate booleanOperator="and">
195+
<SimplePredicate field="marital_status" operator="equal" value="Married-civ-spouse"/>
196+
<SimplePredicate field="capital_gain" operator="greaterOrEqual" value="5178.0"/>
197+
</CompoundPredicate>
198+
</SimpleRule>
199+
<SimpleRule id="[marital_status == Married-civ-spouse] ^ [education_num &lt;= 12.0] ^ [capital_loss &gt;= 1741.0]" score="&gt;50K" recordCount="15081" nbCorrect="11370" confidence="0.6968085106382979" weight="0.6968085106382979">
200+
<CompoundPredicate booleanOperator="and">
201+
<SimplePredicate field="marital_status" operator="equal" value="Married-civ-spouse"/>
202+
<SimplePredicate field="education_num" operator="lessOrEqual" value="12.0"/>
203+
<SimplePredicate field="capital_loss" operator="greaterOrEqual" value="1741.0"/>
204+
</CompoundPredicate>
205+
</SimpleRule>
206+
<SimpleRule id="[marital_status == Married-civ-spouse] ^ [education_num &gt;= 10.0] ^ [fnlwgt &gt;= 118088.0] ^ [occupation == Exec-managerial] ^ [hours_per_week &gt;= 41.0]" score="&gt;50K" recordCount="15081" nbCorrect="11535" confidence="0.8103896103896104" weight="0.8103896103896104">
207+
<CompoundPredicate booleanOperator="and">
208+
<SimplePredicate field="marital_status" operator="equal" value="Married-civ-spouse"/>
209+
<SimplePredicate field="education_num" operator="greaterOrEqual" value="10.0"/>
210+
<SimplePredicate field="fnlwgt" operator="greaterOrEqual" value="118088.0"/>
211+
<SimplePredicate field="occupation" operator="equal" value="Exec-managerial"/>
212+
<SimplePredicate field="hours_per_week" operator="greaterOrEqual" value="41.0"/>
213+
</CompoundPredicate>
214+
</SimpleRule>
215+
</RuleSet>
216+
</RuleSetModel>
217+
</PMML>

tests/rule_induction/trxf/pmml_export/resources/toto.pmml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
</Header>
77
<DataDictionary numberOfFields="4">
88
<DataField name="toto0" optype="continuous" dataType="double"/>
9-
<DataField name="toto1" optype="categorical" dataType="string"/>
9+
<DataField name="toto1" optype="categorical" dataType="string">
10+
<Value value="foo"/>
11+
</DataField>
1012
<DataField name="toto2" optype="categorical" dataType="boolean"/>
1113
<DataField name="toto3" optype="ordinal" dataType="integer"/>
1214
</DataDictionary>

0 commit comments

Comments
 (0)