diff --git a/eli5/_decision_path.py b/eli5/_decision_path.py
index cd47828a..f7e6bb7b 100644
--- a/eli5/_decision_path.py
+++ b/eli5/_decision_path.py
@@ -80,6 +80,7 @@ def get_top_features(weights, scale=1.0):
is_regression=is_regression,
targets=[],
)
+ assert explanation.targets is not None
if is_multiclass:
for label_id, label in display_names:
diff --git a/eli5/_feature_names.py b/eli5/_feature_names.py
index cfffa83d..fecb820f 100644
--- a/eli5/_feature_names.py
+++ b/eli5/_feature_names.py
@@ -102,7 +102,7 @@ def filtered(self, feature_filter, x=None):
"""
indices = []
filtered_feature_names = []
- indexed_names = None # type: Iterable[Tuple[int, Any]]
+ indexed_names = None # type: Optional[Iterable[Tuple[int, Any]]]
if isinstance(self.feature_names, (np.ndarray, list)):
indexed_names = enumerate(self.feature_names)
elif isinstance(self.feature_names, dict):
@@ -116,7 +116,8 @@ def filtered(self, feature_filter, x=None):
assert x.shape[0] == 1
flt = lambda nm, i: feature_filter(nm, x[0, i])
else:
- flt = lambda nm, i: feature_filter(nm, x[i])
+ # FIXME: mypy warns about x[i] because it thinks x can be None
+ flt = lambda nm, i: feature_filter(nm, x[i]) # type: ignore
else:
flt = lambda nm, i: feature_filter(nm)
@@ -125,6 +126,7 @@ def filtered(self, feature_filter, x=None):
indices.append(idx)
filtered_feature_names.append(name)
if self.has_bias and flt(self.bias_name, self.bias_idx):
+ assert self.bias_idx is not None # for mypy
bias_name = self.bias_name
indices.append(self.bias_idx)
else:
diff --git a/eli5/base.py b/eli5/base.py
index cdc3b0a7..a029d43f 100644
--- a/eli5/base.py
+++ b/eli5/base.py
@@ -16,15 +16,15 @@ class Explanation(object):
"""
def __init__(self,
estimator, # type: str
- description=None, # type: str
- error=None, # type: str
- method=None, # type: str
+ description=None, # type: Optional[str]
+ error=None, # type: Optional[str]
+ method=None, # type: Optional[str]
is_regression=False, # type: bool
- targets=None, # type: List[TargetExplanation]
- feature_importances=None, # type: FeatureImportances
- decision_tree=None, # type: TreeInfo
+ targets=None, # type: Optional[List[TargetExplanation]]
+ feature_importances=None, # type: Optional[FeatureImportances]
+ decision_tree=None, # type: Optional[TreeInfo]
highlight_spaces=None, # type: Optional[bool]
- transition_features=None, # type: TransitionFeatureWeights
+ transition_features=None, # type: Optional[TransitionFeatureWeights]
):
# type: (...) -> None
self.estimator = estimator
@@ -73,7 +73,7 @@ def __init__(self,
feature_weights, # type: FeatureWeights
proba=None, # type: float
score=None, # type: float
- weighted_spans=None, # type: WeightedSpans
+ weighted_spans=None, # type: Optional[WeightedSpans]
):
# type: (...) -> None
self.target = target
diff --git a/eli5/base_utils.py b/eli5/base_utils.py
index 4f3f22ca..779c6d64 100644
--- a/eli5/base_utils.py
+++ b/eli5/base_utils.py
@@ -33,4 +33,4 @@ def attrs(class_):
if idx >= defaults_shift:
attrib_kwargs['default'] = init_args.defaults[idx - defaults_shift]
these[arg] = attr.ib(**attrib_kwargs)
- return attr.s(class_, these=these, init=False, slots=True, **attrs_kwargs)
+ return attr.s(class_, these=these, init=False, slots=True, **attrs_kwargs) # type: ignore
diff --git a/eli5/formatters/html.py b/eli5/formatters/html.py
index 10acaf66..5d145748 100644
--- a/eli5/formatters/html.py
+++ b/eli5/formatters/html.py
@@ -143,22 +143,26 @@ def render_targets_weighted_spans(
targets, # type: List[TargetExplanation]
preserve_density, # type: Optional[bool]
):
- # type: (...) -> List[str]
+ # type: (...) -> List[Optional[str]]
""" Return a list of rendered weighted spans for targets.
Function must accept a list in order to select consistent weight
ranges across all targets.
"""
prepared_weighted_spans = prepare_weighted_spans(
targets, preserve_density)
- return [
- '
'.join(
- '{}{}'.format(
- '{}: '.format(pws.doc_weighted_spans.vec_name)
- if pws.doc_weighted_spans.vec_name else '',
- render_weighted_spans(pws))
- for pws in pws_lst)
- if pws_lst else None
- for pws_lst in prepared_weighted_spans]
+
+ def _fmt_pws(pws):
+ # type: (PreparedWeightedSpans) -> str
+ name = ('{}: '.format(pws.doc_weighted_spans.vec_name)
+ if pws.doc_weighted_spans.vec_name else '')
+ return '{}{}'.format(name, render_weighted_spans(pws))
+
+ def _fmt_pws_list(pws_lst):
+ # type: (List[PreparedWeightedSpans]) -> str
+ return '
'.join(_fmt_pws(pws) for pws in pws_lst)
+
+ return [_fmt_pws_list(pws_lst) if pws_lst else None
+ for pws_lst in prepared_weighted_spans]
def render_weighted_spans(pws):
diff --git a/eli5/formatters/text.py b/eli5/formatters/text.py
index 2ef52c58..44dcbdec 100644
--- a/eli5/formatters/text.py
+++ b/eli5/formatters/text.py
@@ -106,7 +106,7 @@ def _method_lines(explanation):
def _description_lines(explanation):
# type: (Explanation) -> List[str]
- return [explanation.description]
+ return [explanation.description or '']
def _error_lines(explanation):
@@ -117,6 +117,7 @@ def _error_lines(explanation):
def _feature_importances_lines(explanation, hl_spaces):
# type: (Explanation, Optional[bool]) -> Iterator[str]
max_width = 0
+ assert explanation.feature_importances is not None
for line in _fi_lines(explanation.feature_importances, hl_spaces):
max_width = max(max_width, len(line))
yield line
@@ -146,6 +147,7 @@ def _fi_lines(feature_importances, hl_spaces):
def _decision_tree_lines(explanation):
# type: (Explanation) -> List[str]
+ assert explanation.decision_tree is not None
return ["", tree2text(explanation.decision_tree)]
@@ -153,6 +155,7 @@ def _transition_features_lines(explanation):
# type: (Explanation) -> List[str]
from tabulate import tabulate # type: ignore
tf = explanation.transition_features
+ assert tf is not None
return [
"",
"Transition features:",
@@ -169,7 +172,7 @@ def _targets_lines(explanation, # type: Explanation
):
# type: (...) -> List[str]
lines = []
-
+ assert explanation.targets is not None
for target in explanation.targets:
scores = _format_scores(target.proba, target.score)
if scores:
diff --git a/eli5/formatters/text_helpers.py b/eli5/formatters/text_helpers.py
index f67e9c28..dc5ff28a 100644
--- a/eli5/formatters/text_helpers.py
+++ b/eli5/formatters/text_helpers.py
@@ -58,7 +58,7 @@ def __eq__(self, other):
def prepare_weighted_spans(targets, # type: List[TargetExplanation]
preserve_density=None, # type: Optional[bool]
):
- # type: (...) -> List[List[PreparedWeightedSpans]]
+ # type: (...) -> List[Optional[List[PreparedWeightedSpans]]]
""" Return weighted spans prepared for rendering.
Calculate a separate weight range for each different weighted
span (for each different index): each target has the same number
@@ -67,18 +67,23 @@ def prepare_weighted_spans(targets, # type: List[TargetExplanation]
targets_char_weights = [
[get_char_weights(ws, preserve_density=preserve_density)
for ws in t.weighted_spans.docs_weighted_spans]
- if t.weighted_spans else None
- for t in targets] # type: List[List[np.ndarray]]
+ if t.weighted_spans else None
+ for t in targets] # type: List[Optional[List[np.ndarray]]]
max_idx = max_or_0(len(ch_w or []) for ch_w in targets_char_weights)
+
+ targets_char_weights_not_None = [
+ cw for cw in targets_char_weights
+ if cw is not None] # type: List[List[np.ndarray]]
+
spans_weight_ranges = [
max_or_0(
- abs(x) for char_weights in targets_char_weights
- for x in char_weights[idx] if char_weights is not None)
+ abs(x) for char_weights in targets_char_weights_not_None
+ for x in char_weights[idx])
for idx in range(max_idx)]
return [
[PreparedWeightedSpans(ws, char_weights, weight_range)
for ws, char_weights, weight_range in zip(
- t.weighted_spans.docs_weighted_spans,
+ t.weighted_spans.docs_weighted_spans, # type: ignore
t_char_weights,
spans_weight_ranges)]
if t_char_weights is not None else None
diff --git a/eli5/formatters/trees.py b/eli5/formatters/trees.py
index c133c8c1..7be1e9c3 100644
--- a/eli5/formatters/trees.py
+++ b/eli5/formatters/trees.py
@@ -22,6 +22,8 @@ def p(*args):
value_repr = _format_leaf_value(tree_obj, node)
parts.append(" ---> {}".format(value_repr))
else:
+ assert node.left is not None
+ assert node.right is not None
feat_name = node.feature_name
if depth > 0:
diff --git a/eli5/formatters/utils.py b/eli5/formatters/utils.py
index 148b7bde..ab8b5b45 100644
--- a/eli5/formatters/utils.py
+++ b/eli5/formatters/utils.py
@@ -62,7 +62,7 @@ def format_signed(feature, # type: Dict[str, Any]
def should_highlight_spaces(explanation):
# type: (Explanation) -> bool
- hl_spaces = explanation.highlight_spaces
+ hl_spaces = bool(explanation.highlight_spaces)
if explanation.feature_importances:
hl_spaces = hl_spaces or any(
_has_invisible_spaces(fw.feature)
@@ -97,7 +97,7 @@ def has_any_values_for_weights(explanation):
def tabulate(data, # type: List[List[Any]]
- header=None, # type: List[Any]
+ header=None, # type: Optional[List[Any]]
col_align=None, # type: Union[str, List[str]]
):
# type: (...) -> List[str]
@@ -107,7 +107,11 @@ def tabulate(data, # type: List[List[Any]]
"""
if not data and not header:
return []
- n_cols = len(data[0] if data else header)
+ if data:
+ n_cols = len(data[0])
+ else:
+ assert header is not None
+ n_cols = len(header)
if not all(len(row) == n_cols for row in data):
raise ValueError('data is not rectangular')
diff --git a/eli5/lightgbm.py b/eli5/lightgbm.py
index 02ea4411..b4510912 100644
--- a/eli5/lightgbm.py
+++ b/eli5/lightgbm.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division
from collections import defaultdict
-from typing import DefaultDict
+from typing import DefaultDict, Optional
import numpy as np # type: ignore
import lightgbm # type: ignore
@@ -253,7 +253,7 @@ def _get_prediction_feature_weights(lgb, X, n_targets):
res = []
for target in range(n_targets):
- feature_weights = defaultdict(float) # type: DefaultDict[str, float]
+ feature_weights = defaultdict(float) # type: DefaultDict[Optional[str], float]
for info, leaf_id in zip(tree_info[:, target], pred_leafs[:, target]):
leaf_index, split_index = _get_leaf_split_indices(
info['tree_structure']
diff --git a/eli5/lime/lime.py b/eli5/lime/lime.py
index 9ab5acfc..924675bb 100644
--- a/eli5/lime/lime.py
+++ b/eli5/lime/lime.py
@@ -161,7 +161,7 @@ def __init__(self,
if char_based is None:
if token_pattern is None:
- self.char_based = False
+ self.char_based = False # type: Optional[bool]
self.token_pattern = DEFAULT_TOKEN_PATTERN
else:
self.char_based = None
@@ -335,7 +335,7 @@ def _train_local_classifier(estimator,
samples,
similarity, # type: np.ndarray
y_proba, # type: np.ndarray
- expand_factor=10, # type: int
+ expand_factor=10, # type: Optional[int]
test_size=0.3, # type: float
random_state=None,
):
diff --git a/eli5/lime/samplers.py b/eli5/lime/samplers.py
index 3a82eff1..ff72f568 100644
--- a/eli5/lime/samplers.py
+++ b/eli5/lime/samplers.py
@@ -2,7 +2,7 @@
from __future__ import absolute_import
import abc
from functools import partial
-from typing import List, Tuple, Any, Union, Dict
+from typing import List, Tuple, Any, Union, Dict, Optional
import six
import numpy as np # type: ignore
@@ -68,7 +68,7 @@ class MaskingTextSampler(BaseSampler):
Default is 1, meaning individual tokens are replaced.
"""
def __init__(self,
- token_pattern=None, # type: str
+ token_pattern=None, # type: Optional[str]
bow=True, # type: bool
random_state=None,
replacement='', # type: str
@@ -127,7 +127,7 @@ class MaskingTextSamplers(BaseSampler):
"""
def __init__(self,
sampler_params, # type: List[Dict[str, Any]]
- token_pattern=None, # type: str
+ token_pattern=None, # type: Optional[str]
random_state=None,
weights=None, # type: Union[np.ndarray, List[float]]
):
@@ -168,6 +168,7 @@ def sample_near_with_mask(self,
):
# type: (...) -> Tuple[List[str], np.ndarray, np.ndarray, TokenizedText]
assert n_samples >= 1
+ assert self.token_pattern is not None
text = TokenizedText(doc, token_pattern=self.token_pattern)
all_docs = [] # type: List[str]
similarities = []
diff --git a/eli5/lime/textutils.py b/eli5/lime/textutils.py
index dc940860..e896f347 100644
--- a/eli5/lime/textutils.py
+++ b/eli5/lime/textutils.py
@@ -5,7 +5,7 @@
from __future__ import absolute_import
import re
import math
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Optional
import numpy as np # type: ignore
from sklearn.utils import check_random_state # type: ignore
@@ -70,7 +70,7 @@ def __init__(self, text, token_pattern=DEFAULT_TOKEN_PATTERN):
# type: (str, str) -> None
self.text = text
self.split = SplitResult.fromtext(text, token_pattern)
- self._vocab = None # type: List[str]
+ self._vocab = None # type: Optional[List[str]]
def replace_random_tokens(self,
n_samples, # type: int
diff --git a/eli5/sklearn/explain_prediction.py b/eli5/sklearn/explain_prediction.py
index d6ec0938..7de18783 100644
--- a/eli5/sklearn/explain_prediction.py
+++ b/eli5/sklearn/explain_prediction.py
@@ -184,6 +184,7 @@ def explain_prediction_linear_classifier(clf, doc,
method='linear model',
targets=[],
)
+ assert res.targets is not None
_weights = _linear_weights(clf, x, top, feature_names, flt_indices)
classes = getattr(clf, "classes_", ["-1", "1"]) # OneClassSVM support
@@ -302,6 +303,7 @@ def explain_prediction_linear_regressor(reg, doc,
targets=[],
is_regression=True,
)
+ assert res.targets is not None
_weights = _linear_weights(reg, x, top, feature_names, flt_indices)
names = get_default_target_names(reg)
@@ -426,6 +428,7 @@ def _weights(label_id, scale=1.0):
description=(DESCRIPTION_TREE_CLF_MULTICLASS if is_multiclass
else DESCRIPTION_TREE_CLF_BINARY),
)
+ assert res.targets is not None
display_names = get_target_display_names(
clf.classes_, target_names, targets, top_targets,
@@ -524,6 +527,7 @@ def _weights(label_id, scale=1.0):
targets=[],
is_regression=True,
)
+ assert res.targets is not None
names = get_default_target_names(reg, num_targets=num_targets)
display_names = get_target_display_names(names, target_names, targets,
diff --git a/eli5/sklearn/text.py b/eli5/sklearn/text.py
index 23c5f870..8b961c88 100644
--- a/eli5/sklearn/text.py
+++ b/eli5/sklearn/text.py
@@ -49,7 +49,7 @@ def add_weighted_spans(doc, vec, vectorized, target_expl):
def _get_doc_weighted_spans(doc,
vec,
feature_weights, # type: FeatureWeights
- feature_fn=None # type: Callable[[str], str]
+ feature_fn=None # type: Optional[Callable[[str], str]]
):
# type: (...) -> Optional[Tuple[FoundFeatures, DocWeightedSpans]]
if isinstance(vec, InvertableHashingVectorizer):
@@ -85,7 +85,7 @@ def _get_doc_weighted_spans(doc,
def _get_feature_weights_dict(feature_weights, # type: FeatureWeights
- feature_fn # type: Callable[[str], str]
+ feature_fn # type: Optional[Callable[[str], str]]
):
# type: (...) -> Dict[str, Tuple[float, Tuple[str, int]]]
""" Return {feat_name: (weight, (group, idx))} mapping. """
diff --git a/eli5/sklearn/utils.py b/eli5/sklearn/utils.py
index 45ea6466..ba3680aa 100644
--- a/eli5/sklearn/utils.py
+++ b/eli5/sklearn/utils.py
@@ -69,7 +69,7 @@ def has_intercept(estimator):
def get_feature_names(clf, vec=None, bias_name='', feature_names=None,
num_features=None, estimator_feature_names=None):
- # type: (Any, Any, str, Any, int, Any) -> FeatureNames
+ # type: (Any, Any, Optional[str], Any, int, Any) -> FeatureNames
"""
Return a FeatureNames instance that holds all feature names
and a bias feature.
diff --git a/eli5/xgboost.py b/eli5/xgboost.py
index 56db0504..86d3ad58 100644
--- a/eli5/xgboost.py
+++ b/eli5/xgboost.py
@@ -218,7 +218,7 @@ def explain_prediction_xgboost(
def _check_booster_args(xgb, is_regression=None):
- # type: (Any, bool) -> Tuple[Booster, bool]
+ # type: (Any, Optional[bool]) -> Tuple[Booster, Optional[bool]]
if isinstance(xgb, Booster):
booster = xgb
else:
diff --git a/tox.ini b/tox.ini
index c8054c69..fe7c4648 100644
--- a/tox.ini
+++ b/tox.ini
@@ -79,7 +79,7 @@ commands={[testenv:py35-extra]commands}
basepython=python3.6
deps=
{[testenv]deps}
- mypy == 0.550
+ mypy == 0.641
lxml
commands=
mypy --html-report ./mypy-cov --check-untyped-defs eli5