Skip to content

Commit a9f7fff

Browse files
authored
Merge pull request #166 from keli6/feature/TRXF-PMML-ScoreCard-export
TRXF pmml scorecard writer
2 parents f189c7a + 17fbd2e commit a9f7fff

File tree

12 files changed

+654
-94
lines changed

12 files changed

+654
-94
lines changed

aix360/algorithms/rule_induction/trxf/pmml_export/models/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,23 @@
22
from .data_dictionary import DataField
33
from .data_dictionary import DataType
44
from .data_dictionary import OpType
5+
from .data_dictionary import Value
6+
from .data_dictionary import Restriction
57
from .mining_schema import MiningField
68
from .mining_schema import MiningFieldUsageType
79
from .mining_schema import MiningSchema
810
from .predicate import BooleanOperator
911
from .predicate import CompoundPredicate
1012
from .predicate import Operator
1113
from .predicate import SimplePredicate
14+
from .predicate import TruePredicate
1215
from .rule import SimpleRule
1316
from .rule import DEFAULT_WEIGHT
1417
from .rule import DEFAULT_CONFIDENCE
1518
from .ruleset import RuleSelectionMethod
1619
from .ruleset import RuleSet
1720
from .ruleset_model import RuleSetModel
1821
from .pmml_ruleset_model import SimplePMMLRuleSetModel
22+
from .scorecard import Scorecard, Output, OutputField
23+
from .characteristics import Characteristics, Characteristic
24+
from .attribute import Attribute, ComplexPartialScore
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import typing
2+
3+
from . import predicate
4+
5+
try:
6+
# python >= 3.7
7+
from dataclasses import dataclass as dataclass
8+
from dataclasses import field as field
9+
except ImportError:
10+
from attr import s as dataclass
11+
from attr import ib as field
12+
13+
14+
@dataclass(frozen=True)
15+
class ComplexPartialScore:
16+
feature_name: str = field()
17+
multiplier: str = field(default=None)
18+
constant: str = field(default=None)
19+
20+
21+
@dataclass(frozen=True)
22+
class Attribute:
23+
score: typing.Union[str, ComplexPartialScore]
24+
predicate: typing.Union[predicate.SimplePredicate, predicate.CompoundPredicate, predicate.TruePredicate]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import typing
2+
3+
from . import attribute
4+
5+
try:
6+
# python >= 3.7
7+
from dataclasses import dataclass as dataclass
8+
except ImportError:
9+
from attr import s as dataclass
10+
11+
12+
@dataclass(frozen=True)
13+
class Characteristic:
14+
name: str
15+
attributes: typing.List[attribute.Attribute]
16+
17+
18+
@dataclass(frozen=True)
19+
class Characteristics:
20+
characteristics: typing.List[Characteristic]

aix360/algorithms/rule_induction/trxf/pmml_export/models/data_dictionary.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,16 @@ class OpType(enum.Enum):
2323
continuous = 2
2424

2525

26+
class Restriction(enum.Enum):
27+
valid = 0
28+
invalid = 1
29+
missing = 2
30+
31+
2632
@dataclass(frozen=True)
2733
class Value:
2834
value: str = field()
35+
property: Restriction = field(default=Restriction.valid)
2936

3037

3138
@dataclass(frozen=True)
@@ -35,6 +42,11 @@ class DataField:
3542
dataType: DataType = field()
3643
values: typing.Optional[typing.List[Value]] = field(default=None)
3744

45+
def __post_init__(self):
46+
if self.values and \
47+
(self.dataType is not DataType.string or self.optype not in (OpType.ordinal, OpType.categorical)):
48+
raise ValueError
49+
3850

3951
@dataclass(frozen=True)
4052
class DataDictionary:

aix360/algorithms/rule_induction/trxf/pmml_export/models/predicate.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@ class Operator(enum.Enum):
1616
lessOrEqual = 3
1717
greaterThan = 4
1818
greaterOrEqual = 5
19+
isMissing = 6
1920

2021

2122
@dataclass(frozen=True)
2223
class SimplePredicate:
2324
operator: Operator = field()
24-
value: str = field()
2525
field: str = field()
26+
value: typing.Optional[str] = None
27+
28+
29+
@dataclass(frozen=True)
30+
class TruePredicate:
31+
pass
2632

2733

2834
# Use functional api to add aliases for `or` and `and`
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import typing
2+
3+
from . import characteristics
4+
from . import data_dictionary
5+
from . import mining_schema
6+
7+
try:
8+
# python >= 3.7
9+
from dataclasses import dataclass as dataclass
10+
from dataclasses import field as field
11+
except ImportError:
12+
from attr import s as dataclass
13+
from attr import ib as field
14+
15+
16+
@dataclass(frozen=True)
17+
class OutputField:
18+
name: str
19+
feature: str
20+
dataType: data_dictionary.DataType
21+
optype: data_dictionary.OpType
22+
23+
24+
@dataclass(frozen=True)
25+
class Output:
26+
outputFields: typing.List[OutputField]
27+
28+
29+
@dataclass(frozen=True)
30+
class Scorecard:
31+
dataDictionary: data_dictionary.DataDictionary
32+
miningSchema: mining_schema.MiningSchema
33+
output: Output
34+
characteristics: characteristics.Characteristics
35+
initialScore: str = field(default="0")
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import abc
2+
import typing
23

34
from .. import models
45

56

67
class AbstractSerializer(abc.ABC):
78

89
@abc.abstractmethod
9-
def serialize(self, simple_pmml_ruleset_model: models.SimplePMMLRuleSetModel) -> str:
10+
def serialize(self, model: typing.Union[models.SimplePMMLRuleSetModel, models.Scorecard]) -> str:
1011
raise NotImplementedError

aix360/algorithms/rule_induction/trxf/pmml_export/serializer/nyoka_serializer.py

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import typing
12
import datetime
23
import io
34

@@ -16,36 +17,49 @@ class NyokaSerializer(AbstractSerializer):
1617
def __init__(self, timestamp: datetime = None):
1718
self._timestamp = timestamp
1819

19-
def serialize(self, simple_pmml_ruleset_model: models.SimplePMMLRuleSetModel, timestamp: datetime = None) -> str:
20-
pmml_model = self._nyoka_pmml_model(simple_pmml_ruleset_model)
20+
def serialize(self, model: typing.Union[models.SimplePMMLRuleSetModel, models.Scorecard]) -> str:
21+
if isinstance(model, models.SimplePMMLRuleSetModel):
22+
pmml_model = self._nyoka_pmml_model(
23+
model,
24+
RuleSetModel=None if model.ruleSetModel is None else [self._nyoka_rule_set_model(model.ruleSetModel)])
25+
elif isinstance(model, models.Scorecard):
26+
pmml_model = self._nyoka_pmml_model(
27+
model,
28+
Scorecard=None if model is None else [self._nyoka_scorecard_model(model)])
29+
else:
30+
raise NotImplemented
2131
string_io = io.StringIO()
2232
pmml_model.export(outfile=string_io, level=0)
2333
return string_io.getvalue()
2434

25-
def _nyoka_pmml_model(self, simple_pmml_ruleset_model: models.SimplePMMLRuleSetModel) -> nyoka_pmml.PMML:
26-
timestamp = datetime.datetime.now() if self._timestamp is None else self._timestamp
35+
def _nyoka_pmml_model(
36+
self,
37+
model: typing.Union[models.SimplePMMLRuleSetModel, models.Scorecard],
38+
**build_args
39+
) -> nyoka_pmml.PMML:
2740
return nyoka_pmml.PMML(
2841
version=nyoka_constants.PMML_SCHEMA.VERSION,
2942
Header=nyoka_pmml.Header(
3043
copyright=NyokaSerializer.COPYRIGHT_STRING,
3144
description=nyoka_constants.HEADER_INFO.DEFAULT_DESCRIPTION,
32-
Timestamp=nyoka_pmml.Timestamp(timestamp),
33-
Application=nyoka_pmml.Application(
34-
name=NyokaSerializer.APPLICATION_NAME, version=version.version)),
35-
DataDictionary=None if simple_pmml_ruleset_model.dataDictionary is None else self._nyoka_data_dictionary(
36-
simple_pmml_ruleset_model.dataDictionary),
37-
RuleSetModel=None if simple_pmml_ruleset_model.ruleSetModel is None else [
38-
self._nyoka_rule_set_model(simple_pmml_ruleset_model.ruleSetModel)])
45+
Timestamp=nyoka_pmml.Timestamp(datetime.datetime.now() if self._timestamp is None else self._timestamp),
46+
Application=nyoka_pmml.Application(name=NyokaSerializer.APPLICATION_NAME, version=version.version)),
47+
DataDictionary=None if model.dataDictionary is None else self._nyoka_data_dictionary(model.dataDictionary),
48+
**build_args)
3949

4050
def _nyoka_data_dictionary(self, data_dictionary: models.DataDictionary) -> nyoka_pmml.DataDictionary:
4151
return nyoka_pmml.DataDictionary(
42-
numberOfFields=0 if data_dictionary.dataFields is None else len(data_dictionary.dataFields),
43-
DataField=None if data_dictionary.dataFields is None else [
44-
nyoka_pmml.DataField(name=f.name, optype=f.optype.name, dataType=f.dataType.name,
45-
Value=None if f.values is None else [
46-
nyoka_pmml.Value(value=v.value) for v in f.values
47-
])
48-
for f in data_dictionary.dataFields])
52+
numberOfFields=0 if not data_dictionary.dataFields else len(data_dictionary.dataFields),
53+
DataField=None if not data_dictionary.dataFields else [
54+
self._nyoka_data_field(f) for f in data_dictionary.dataFields])
55+
56+
def _nyoka_data_field(self, data_field: models.DataField) -> nyoka_pmml.DataField:
57+
return nyoka_pmml.DataField(
58+
name=data_field.name,
59+
optype=data_field.optype.name,
60+
dataType=data_field.dataType.name,
61+
Value=None if not data_field.values else [
62+
nyoka_pmml.Value(value=val.value, property=val.property.name) for val in data_field.values])
4963

5064
def _nyoka_rule_set_model(self, rule_set_model: models.RuleSetModel) -> nyoka_pmml.RuleSetModel:
5165
return nyoka_pmml.RuleSetModel(
@@ -94,3 +108,55 @@ def _nyoka_rule(self, simple_rule: models.SimpleRule) -> nyoka_pmml.SimpleRule:
94108
nbCorrect=simple_rule.nbCorrect,
95109
confidence=simple_rule.confidence,
96110
weight=simple_rule.weight)
111+
112+
def _nyoka_scorecard_model(self, scorecard: models.Scorecard) -> nyoka_pmml.Scorecard:
113+
return nyoka_pmml.Scorecard(
114+
functionName='regression',
115+
algorithmName='ScoreCard',
116+
MiningSchema=None if scorecard.miningSchema is None else self._nyoka_mining_schema(
117+
scorecard.miningSchema),
118+
initialScore=scorecard.initialScore,
119+
useReasonCodes="false",
120+
Output=None if scorecard.output is None else nyoka_pmml.Output(
121+
OutputField=[
122+
nyoka_pmml.OutputField(
123+
name=outputField.name,
124+
feature=outputField.feature,
125+
dataType=outputField.dataType.name,
126+
optype=outputField.optype.name) for outputField in scorecard.output.outputFields]),
127+
Characteristics=None if scorecard.characteristics is None else self._nyoka_pmml_characteristics(
128+
scorecard.characteristics))
129+
130+
def _nyoka_pmml_characteristics(self, characteristics: models.Characteristics) -> nyoka_pmml.Characteristics:
131+
return nyoka_pmml.Characteristics(
132+
Characteristic=[
133+
nyoka_pmml.Characteristic(
134+
name=characteristic.name,
135+
Attribute=[self._nyoka_pmml_attributes(attribute) for attribute in characteristic.attributes])
136+
for characteristic in characteristics.characteristics])
137+
138+
def _nyoka_pmml_attributes(self, attribute: models.Attribute) -> nyoka_pmml.Attribute:
139+
return nyoka_pmml.Attribute(
140+
partialScore=attribute.score if not isinstance(attribute.score, models.ComplexPartialScore) else None,
141+
ComplexPartialScore=nyoka_pmml.ComplexPartialScore(
142+
Apply=nyoka_pmml.Apply(
143+
function='+',
144+
Apply_member=[nyoka_pmml.Apply(
145+
function='*',
146+
FieldRef=[nyoka_pmml.FieldRef(field=attribute.score.feature_name)],
147+
Constant=[nyoka_pmml.Constant(valueOf_=attribute.score.multiplier)])],
148+
Constant=[nyoka_pmml.Constant(valueOf_=attribute.score.constant)])) if isinstance(
149+
attribute.score, models.ComplexPartialScore) else None,
150+
SimplePredicate=None if (attribute.predicate is None or not isinstance(
151+
attribute.predicate, models.SimplePredicate)) else nyoka_pmml.SimplePredicate(
152+
field=attribute.predicate.field,
153+
operator=attribute.predicate.operator.name,
154+
value=attribute.predicate.value),
155+
CompoundPredicate=None if (attribute.predicate is None or not isinstance(
156+
attribute.predicate, models.CompoundPredicate)) else nyoka_pmml.CompoundPredicate(
157+
booleanOperator=attribute.predicate.booleanOperator.name,
158+
SimplePredicate=[
159+
nyoka_pmml.SimplePredicate(field=sp.field, operator=sp.operator.name, value=sp.value)
160+
for sp in attribute.predicate.simplePredicates]),
161+
True_=None if (attribute.predicate is None or not isinstance(
162+
attribute.predicate, models.TruePredicate)) else nyoka_pmml.True_())

0 commit comments

Comments
 (0)