Skip to content

Commit be26e89

Browse files
committed
Feed data as DataFrame instead of as an array to some models
1 parent 06d7d71 commit be26e89

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

python/model/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ def explain(self, features, samples=None):
111111
raise NotImplementedError()
112112

113113
# Private
114+
def _get_predictor_type(self):
115+
return str(type(self._get_predictor()))
116+
114117
def _hydrate(self, model, metadata):
115118
# Fill attributes
116119
self._model = model
@@ -329,7 +332,7 @@ def info(self):
329332
# Info from model
330333
result['model'] = {
331334
'type': str(type(self._model)),
332-
'predictor_type': str(type(self._get_predictor())),
335+
'predictor_type': self._get_predictor_type(),
333336
'is_explainable': self._is_explainable,
334337
'task': self.task_type(as_text=True),
335338
'family': self.family

python/model/sklearn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,12 @@ def explain(self, features, samples=None):
190190
# Explainer
191191
explainer = shap.TreeExplainer(self._get_predictor(), **params)
192192
colnames = self._feature_names()
193-
shap_values = explainer.shap_values(preprocessed[colnames].values)
193+
# Feed the input to some models (e.g. LightGBM) as a pandas DataFrame
194+
# instead of as a NumPy array.
195+
input_data = preprocessed[colnames]
196+
predictor_type = self._get_predictor_type()
197+
use_pandas = any(c in predictor_type for c in ('LGBMClassifier', 'LGBMRegressor'))
198+
shap_values = explainer.shap_values(input_data if use_pandas else input_data.values)
194199

195200
# Create an index to handle multiple samples input
196201
index = preprocessed.index

0 commit comments

Comments
 (0)