Skip to content

Commit

Permalink
Merge pull request TeamHG-Memex#283 from TeamHG-Memex/update-mypy
Browse files (browse the repository at this point in the history)
TST upgrade mypy
  • Loading branch information
lopuhin committed Nov 19, 2018
2 parents 988f066 + 11c3ea9 commit 054e656
Show file tree
Hide file tree
Showing 18 changed files with 72 additions and 46 deletions.
1 change: 1 addition & 0 deletions eli5/_decision_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def get_top_features(weights, scale=1.0):
is_regression=is_regression,
targets=[],
)
assert explanation.targets is not None

if is_multiclass:
for label_id, label in display_names:
Expand Down
6 changes: 4 additions & 2 deletions eli5/_feature_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def filtered(self, feature_filter, x=None):
"""
indices = []
filtered_feature_names = []
indexed_names = None # type: Iterable[Tuple[int, Any]]
indexed_names = None # type: Optional[Iterable[Tuple[int, Any]]]
if isinstance(self.feature_names, (np.ndarray, list)):
indexed_names = enumerate(self.feature_names)
elif isinstance(self.feature_names, dict):
Expand All @@ -116,7 +116,8 @@ def filtered(self, feature_filter, x=None):
assert x.shape[0] == 1
flt = lambda nm, i: feature_filter(nm, x[0, i])
else:
flt = lambda nm, i: feature_filter(nm, x[i])
# FIXME: mypy warns about x[i] because it thinks x can be None
flt = lambda nm, i: feature_filter(nm, x[i]) # type: ignore
else:
flt = lambda nm, i: feature_filter(nm)

Expand All @@ -125,6 +126,7 @@ def filtered(self, feature_filter, x=None):
indices.append(idx)
filtered_feature_names.append(name)
if self.has_bias and flt(self.bias_name, self.bias_idx):
assert self.bias_idx is not None # for mypy
bias_name = self.bias_name
indices.append(self.bias_idx)
else:
Expand Down
16 changes: 8 additions & 8 deletions eli5/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class Explanation(object):
"""
def __init__(self,
estimator, # type: str
description=None, # type: str
error=None, # type: str
method=None, # type: str
description=None, # type: Optional[str]
error=None, # type: Optional[str]
method=None, # type: Optional[str]
is_regression=False, # type: bool
targets=None, # type: List[TargetExplanation]
feature_importances=None, # type: FeatureImportances
decision_tree=None, # type: TreeInfo
targets=None, # type: Optional[List[TargetExplanation]]
feature_importances=None, # type: Optional[FeatureImportances]
decision_tree=None, # type: Optional[TreeInfo]
highlight_spaces=None, # type: Optional[bool]
transition_features=None, # type: TransitionFeatureWeights
transition_features=None, # type: Optional[TransitionFeatureWeights]
):
# type: (...) -> None
self.estimator = estimator
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self,
feature_weights, # type: FeatureWeights
proba=None, # type: float
score=None, # type: float
weighted_spans=None, # type: WeightedSpans
weighted_spans=None, # type: Optional[WeightedSpans]
):
# type: (...) -> None
self.target = target
Expand Down
2 changes: 1 addition & 1 deletion eli5/base_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ def attrs(class_):
if idx >= defaults_shift:
attrib_kwargs['default'] = init_args.defaults[idx - defaults_shift]
these[arg] = attr.ib(**attrib_kwargs)
return attr.s(class_, these=these, init=False, slots=True, **attrs_kwargs)
return attr.s(class_, these=these, init=False, slots=True, **attrs_kwargs) # type: ignore
24 changes: 14 additions & 10 deletions eli5/formatters/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,26 @@ def render_targets_weighted_spans(
targets, # type: List[TargetExplanation]
preserve_density, # type: Optional[bool]
):
# type: (...) -> List[str]
# type: (...) -> List[Optional[str]]
""" Return a list of rendered weighted spans for targets.
Function must accept a list in order to select consistent weight
ranges across all targets.
"""
prepared_weighted_spans = prepare_weighted_spans(
targets, preserve_density)
return [
'<br/>'.join(
'{}{}'.format(
'<b>{}:</b> '.format(pws.doc_weighted_spans.vec_name)
if pws.doc_weighted_spans.vec_name else '',
render_weighted_spans(pws))
for pws in pws_lst)
if pws_lst else None
for pws_lst in prepared_weighted_spans]

def _fmt_pws(pws):
# type: (PreparedWeightedSpans) -> str
name = ('<b>{}:</b> '.format(pws.doc_weighted_spans.vec_name)
if pws.doc_weighted_spans.vec_name else '')
return '{}{}'.format(name, render_weighted_spans(pws))

def _fmt_pws_list(pws_lst):
# type: (List[PreparedWeightedSpans]) -> str
return '<br/>'.join(_fmt_pws(pws) for pws in pws_lst)

return [_fmt_pws_list(pws_lst) if pws_lst else None
for pws_lst in prepared_weighted_spans]


def render_weighted_spans(pws):
Expand Down
7 changes: 5 additions & 2 deletions eli5/formatters/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _method_lines(explanation):

def _description_lines(explanation):
# type: (Explanation) -> List[str]
return [explanation.description]
return [explanation.description or '']


def _error_lines(explanation):
Expand All @@ -117,6 +117,7 @@ def _error_lines(explanation):
def _feature_importances_lines(explanation, hl_spaces):
# type: (Explanation, Optional[bool]) -> Iterator[str]
max_width = 0
assert explanation.feature_importances is not None
for line in _fi_lines(explanation.feature_importances, hl_spaces):
max_width = max(max_width, len(line))
yield line
Expand Down Expand Up @@ -146,13 +147,15 @@ def _fi_lines(feature_importances, hl_spaces):

def _decision_tree_lines(explanation):
# type: (Explanation) -> List[str]
assert explanation.decision_tree is not None
return ["", tree2text(explanation.decision_tree)]


def _transition_features_lines(explanation):
# type: (Explanation) -> List[str]
from tabulate import tabulate # type: ignore
tf = explanation.transition_features
assert tf is not None
return [
"",
"Transition features:",
Expand All @@ -169,7 +172,7 @@ def _targets_lines(explanation, # type: Explanation
):
# type: (...) -> List[str]
lines = []

assert explanation.targets is not None
for target in explanation.targets:
scores = _format_scores(target.proba, target.score)
if scores:
Expand Down
17 changes: 11 additions & 6 deletions eli5/formatters/text_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __eq__(self, other):
def prepare_weighted_spans(targets, # type: List[TargetExplanation]
preserve_density=None, # type: Optional[bool]
):
# type: (...) -> List[List[PreparedWeightedSpans]]
# type: (...) -> List[Optional[List[PreparedWeightedSpans]]]
""" Return weighted spans prepared for rendering.
Calculate a separate weight range for each different weighted
span (for each different index): each target has the same number
Expand All @@ -67,18 +67,23 @@ def prepare_weighted_spans(targets, # type: List[TargetExplanation]
targets_char_weights = [
[get_char_weights(ws, preserve_density=preserve_density)
for ws in t.weighted_spans.docs_weighted_spans]
if t.weighted_spans else None
for t in targets] # type: List[List[np.ndarray]]
if t.weighted_spans else None
for t in targets] # type: List[Optional[List[np.ndarray]]]
max_idx = max_or_0(len(ch_w or []) for ch_w in targets_char_weights)

targets_char_weights_not_None = [
cw for cw in targets_char_weights
if cw is not None] # type: List[List[np.ndarray]]

spans_weight_ranges = [
max_or_0(
abs(x) for char_weights in targets_char_weights
for x in char_weights[idx] if char_weights is not None)
abs(x) for char_weights in targets_char_weights_not_None
for x in char_weights[idx])
for idx in range(max_idx)]
return [
[PreparedWeightedSpans(ws, char_weights, weight_range)
for ws, char_weights, weight_range in zip(
t.weighted_spans.docs_weighted_spans,
t.weighted_spans.docs_weighted_spans, # type: ignore
t_char_weights,
spans_weight_ranges)]
if t_char_weights is not None else None
Expand Down
2 changes: 2 additions & 0 deletions eli5/formatters/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def p(*args):
value_repr = _format_leaf_value(tree_obj, node)
parts.append(" ---> {}".format(value_repr))
else:
assert node.left is not None
assert node.right is not None
feat_name = node.feature_name

if depth > 0:
Expand Down
10 changes: 7 additions & 3 deletions eli5/formatters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def format_signed(feature, # type: Dict[str, Any]

def should_highlight_spaces(explanation):
# type: (Explanation) -> bool
hl_spaces = explanation.highlight_spaces
hl_spaces = bool(explanation.highlight_spaces)
if explanation.feature_importances:
hl_spaces = hl_spaces or any(
_has_invisible_spaces(fw.feature)
Expand Down Expand Up @@ -97,7 +97,7 @@ def has_any_values_for_weights(explanation):


def tabulate(data, # type: List[List[Any]]
header=None, # type: List[Any]
header=None, # type: Optional[List[Any]]
col_align=None, # type: Union[str, List[str]]
):
# type: (...) -> List[str]
Expand All @@ -107,7 +107,11 @@ def tabulate(data, # type: List[List[Any]]
"""
if not data and not header:
return []
n_cols = len(data[0] if data else header)
if data:
n_cols = len(data[0])
else:
assert header is not None
n_cols = len(header)
if not all(len(row) == n_cols for row in data):
raise ValueError('data is not rectangular')

Expand Down
4 changes: 2 additions & 2 deletions eli5/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division
from collections import defaultdict
from typing import DefaultDict
from typing import DefaultDict, Optional

import numpy as np # type: ignore
import lightgbm # type: ignore
Expand Down Expand Up @@ -253,7 +253,7 @@ def _get_prediction_feature_weights(lgb, X, n_targets):

res = []
for target in range(n_targets):
feature_weights = defaultdict(float) # type: DefaultDict[str, float]
feature_weights = defaultdict(float) # type: DefaultDict[Optional[str], float]
for info, leaf_id in zip(tree_info[:, target], pred_leafs[:, target]):
leaf_index, split_index = _get_leaf_split_indices(
info['tree_structure']
Expand Down
4 changes: 2 additions & 2 deletions eli5/lime/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(self,

if char_based is None:
if token_pattern is None:
self.char_based = False
self.char_based = False # type: Optional[bool]
self.token_pattern = DEFAULT_TOKEN_PATTERN
else:
self.char_based = None
Expand Down Expand Up @@ -335,7 +335,7 @@ def _train_local_classifier(estimator,
samples,
similarity, # type: np.ndarray
y_proba, # type: np.ndarray
expand_factor=10, # type: int
expand_factor=10, # type: Optional[int]
test_size=0.3, # type: float
random_state=None,
):
Expand Down
7 changes: 4 additions & 3 deletions eli5/lime/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import absolute_import
import abc
from functools import partial
from typing import List, Tuple, Any, Union, Dict
from typing import List, Tuple, Any, Union, Dict, Optional
import six

import numpy as np # type: ignore
Expand Down Expand Up @@ -68,7 +68,7 @@ class MaskingTextSampler(BaseSampler):
Default is 1, meaning individual tokens are replaced.
"""
def __init__(self,
token_pattern=None, # type: str
token_pattern=None, # type: Optional[str]
bow=True, # type: bool
random_state=None,
replacement='', # type: str
Expand Down Expand Up @@ -127,7 +127,7 @@ class MaskingTextSamplers(BaseSampler):
"""
def __init__(self,
sampler_params, # type: List[Dict[str, Any]]
token_pattern=None, # type: str
token_pattern=None, # type: Optional[str]
random_state=None,
weights=None, # type: Union[np.ndarray, List[float]]
):
Expand Down Expand Up @@ -168,6 +168,7 @@ def sample_near_with_mask(self,
):
# type: (...) -> Tuple[List[str], np.ndarray, np.ndarray, TokenizedText]
assert n_samples >= 1
assert self.token_pattern is not None
text = TokenizedText(doc, token_pattern=self.token_pattern)
all_docs = [] # type: List[str]
similarities = []
Expand Down
4 changes: 2 additions & 2 deletions eli5/lime/textutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import absolute_import
import re
import math
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Optional

import numpy as np # type: ignore
from sklearn.utils import check_random_state # type: ignore
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, text, token_pattern=DEFAULT_TOKEN_PATTERN):
# type: (str, str) -> None
self.text = text
self.split = SplitResult.fromtext(text, token_pattern)
self._vocab = None # type: List[str]
self._vocab = None # type: Optional[List[str]]

def replace_random_tokens(self,
n_samples, # type: int
Expand Down
4 changes: 4 additions & 0 deletions eli5/sklearn/explain_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def explain_prediction_linear_classifier(clf, doc,
method='linear model',
targets=[],
)
assert res.targets is not None

_weights = _linear_weights(clf, x, top, feature_names, flt_indices)
classes = getattr(clf, "classes_", ["-1", "1"]) # OneClassSVM support
Expand Down Expand Up @@ -302,6 +303,7 @@ def explain_prediction_linear_regressor(reg, doc,
targets=[],
is_regression=True,
)
assert res.targets is not None

_weights = _linear_weights(reg, x, top, feature_names, flt_indices)
names = get_default_target_names(reg)
Expand Down Expand Up @@ -426,6 +428,7 @@ def _weights(label_id, scale=1.0):
description=(DESCRIPTION_TREE_CLF_MULTICLASS if is_multiclass
else DESCRIPTION_TREE_CLF_BINARY),
)
assert res.targets is not None

display_names = get_target_display_names(
clf.classes_, target_names, targets, top_targets,
Expand Down Expand Up @@ -524,6 +527,7 @@ def _weights(label_id, scale=1.0):
targets=[],
is_regression=True,
)
assert res.targets is not None

names = get_default_target_names(reg, num_targets=num_targets)
display_names = get_target_display_names(names, target_names, targets,
Expand Down
4 changes: 2 additions & 2 deletions eli5/sklearn/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def add_weighted_spans(doc, vec, vectorized, target_expl):
def _get_doc_weighted_spans(doc,
vec,
feature_weights, # type: FeatureWeights
feature_fn=None # type: Callable[[str], str]
feature_fn=None # type: Optional[Callable[[str], str]]
):
# type: (...) -> Optional[Tuple[FoundFeatures, DocWeightedSpans]]
if isinstance(vec, InvertableHashingVectorizer):
Expand Down Expand Up @@ -85,7 +85,7 @@ def _get_doc_weighted_spans(doc,


def _get_feature_weights_dict(feature_weights, # type: FeatureWeights
feature_fn # type: Callable[[str], str]
feature_fn # type: Optional[Callable[[str], str]]
):
# type: (...) -> Dict[str, Tuple[float, Tuple[str, int]]]
""" Return {feat_name: (weight, (group, idx))} mapping. """
Expand Down
2 changes: 1 addition & 1 deletion eli5/sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def has_intercept(estimator):

def get_feature_names(clf, vec=None, bias_name='<BIAS>', feature_names=None,
num_features=None, estimator_feature_names=None):
# type: (Any, Any, str, Any, int, Any) -> FeatureNames
# type: (Any, Any, Optional[str], Any, int, Any) -> FeatureNames
"""
Return a FeatureNames instance that holds all feature names
and a bias feature.
Expand Down
2 changes: 1 addition & 1 deletion eli5/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def explain_prediction_xgboost(


def _check_booster_args(xgb, is_regression=None):
# type: (Any, bool) -> Tuple[Booster, bool]
# type: (Any, Optional[bool]) -> Tuple[Booster, Optional[bool]]
if isinstance(xgb, Booster):
booster = xgb
else:
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ commands={[testenv:py35-extra]commands}
basepython=python3.6
deps=
{[testenv]deps}
mypy == 0.550
mypy == 0.641
lxml
commands=
mypy --html-report ./mypy-cov --check-untyped-defs eli5
Expand Down

0 comments on commit 054e656

Please sign in to comment.