diff --git a/eli5/_decision_path.py b/eli5/_decision_path.py index cd47828a..f7e6bb7b 100644 --- a/eli5/_decision_path.py +++ b/eli5/_decision_path.py @@ -80,6 +80,7 @@ def get_top_features(weights, scale=1.0): is_regression=is_regression, targets=[], ) + assert explanation.targets is not None if is_multiclass: for label_id, label in display_names: diff --git a/eli5/_feature_names.py b/eli5/_feature_names.py index cfffa83d..fecb820f 100644 --- a/eli5/_feature_names.py +++ b/eli5/_feature_names.py @@ -102,7 +102,7 @@ def filtered(self, feature_filter, x=None): """ indices = [] filtered_feature_names = [] - indexed_names = None # type: Iterable[Tuple[int, Any]] + indexed_names = None # type: Optional[Iterable[Tuple[int, Any]]] if isinstance(self.feature_names, (np.ndarray, list)): indexed_names = enumerate(self.feature_names) elif isinstance(self.feature_names, dict): @@ -116,7 +116,8 @@ def filtered(self, feature_filter, x=None): assert x.shape[0] == 1 flt = lambda nm, i: feature_filter(nm, x[0, i]) else: - flt = lambda nm, i: feature_filter(nm, x[i]) + # FIXME: mypy warns about x[i] because it thinks x can be None + flt = lambda nm, i: feature_filter(nm, x[i]) # type: ignore else: flt = lambda nm, i: feature_filter(nm) @@ -125,6 +126,7 @@ def filtered(self, feature_filter, x=None): indices.append(idx) filtered_feature_names.append(name) if self.has_bias and flt(self.bias_name, self.bias_idx): + assert self.bias_idx is not None # for mypy bias_name = self.bias_name indices.append(self.bias_idx) else: diff --git a/eli5/base.py b/eli5/base.py index cdc3b0a7..a029d43f 100644 --- a/eli5/base.py +++ b/eli5/base.py @@ -16,15 +16,15 @@ class Explanation(object): """ def __init__(self, estimator, # type: str - description=None, # type: str - error=None, # type: str - method=None, # type: str + description=None, # type: Optional[str] + error=None, # type: Optional[str] + method=None, # type: Optional[str] is_regression=False, # type: bool - targets=None, # type: List[TargetExplanation] - feature_importances=None, # type: FeatureImportances - decision_tree=None, # type: TreeInfo + targets=None, # type: Optional[List[TargetExplanation]] + feature_importances=None, # type: Optional[FeatureImportances] + decision_tree=None, # type: Optional[TreeInfo] highlight_spaces=None, # type: Optional[bool] - transition_features=None, # type: TransitionFeatureWeights + transition_features=None, # type: Optional[TransitionFeatureWeights] ): # type: (...) -> None self.estimator = estimator @@ -73,7 +73,7 @@ def __init__(self, feature_weights, # type: FeatureWeights proba=None, # type: float score=None, # type: float - weighted_spans=None, # type: WeightedSpans + weighted_spans=None, # type: Optional[WeightedSpans] ): # type: (...) -> None self.target = target diff --git a/eli5/base_utils.py b/eli5/base_utils.py index 4f3f22ca..779c6d64 100644 --- a/eli5/base_utils.py +++ b/eli5/base_utils.py @@ -33,4 +33,4 @@ def attrs(class_): if idx >= defaults_shift: attrib_kwargs['default'] = init_args.defaults[idx - defaults_shift] these[arg] = attr.ib(**attrib_kwargs) - return attr.s(class_, these=these, init=False, slots=True, **attrs_kwargs) + return attr.s(class_, these=these, init=False, slots=True, **attrs_kwargs) # type: ignore diff --git a/eli5/formatters/html.py b/eli5/formatters/html.py index 10acaf66..5d145748 100644 --- a/eli5/formatters/html.py +++ b/eli5/formatters/html.py @@ -143,22 +143,26 @@ def render_targets_weighted_spans( targets, # type: List[TargetExplanation] preserve_density, # type: Optional[bool] ): - # type: (...) -> List[str] + # type: (...) -> List[Optional[str]] """ Return a list of rendered weighted spans for targets. Function must accept a list in order to select consistent weight ranges across all targets. """ prepared_weighted_spans = prepare_weighted_spans( targets, preserve_density) - return [ - '
'.join( - '{}{}'.format( - '{}: '.format(pws.doc_weighted_spans.vec_name) - if pws.doc_weighted_spans.vec_name else '', - render_weighted_spans(pws)) - for pws in pws_lst) - if pws_lst else None - for pws_lst in prepared_weighted_spans] + + def _fmt_pws(pws): + # type: (PreparedWeightedSpans) -> str + name = ('{}: '.format(pws.doc_weighted_spans.vec_name) + if pws.doc_weighted_spans.vec_name else '') + return '{}{}'.format(name, render_weighted_spans(pws)) + + def _fmt_pws_list(pws_lst): + # type: (List[PreparedWeightedSpans]) -> str + return '
'.join(_fmt_pws(pws) for pws in pws_lst) + + return [_fmt_pws_list(pws_lst) if pws_lst else None + for pws_lst in prepared_weighted_spans] def render_weighted_spans(pws): diff --git a/eli5/formatters/text.py b/eli5/formatters/text.py index 2ef52c58..44dcbdec 100644 --- a/eli5/formatters/text.py +++ b/eli5/formatters/text.py @@ -106,7 +106,7 @@ def _method_lines(explanation): def _description_lines(explanation): # type: (Explanation) -> List[str] - return [explanation.description] + return [explanation.description or ''] def _error_lines(explanation): @@ -117,6 +117,7 @@ def _error_lines(explanation): def _feature_importances_lines(explanation, hl_spaces): # type: (Explanation, Optional[bool]) -> Iterator[str] max_width = 0 + assert explanation.feature_importances is not None for line in _fi_lines(explanation.feature_importances, hl_spaces): max_width = max(max_width, len(line)) yield line @@ -146,6 +147,7 @@ def _fi_lines(feature_importances, hl_spaces): def _decision_tree_lines(explanation): # type: (Explanation) -> List[str] + assert explanation.decision_tree is not None return ["", tree2text(explanation.decision_tree)] @@ -153,6 +155,7 @@ def _transition_features_lines(explanation): # type: (Explanation) -> List[str] from tabulate import tabulate # type: ignore tf = explanation.transition_features + assert tf is not None return [ "", "Transition features:", @@ -169,7 +172,7 @@ def _targets_lines(explanation, # type: Explanation ): # type: (...) -> List[str] lines = [] - + assert explanation.targets is not None for target in explanation.targets: scores = _format_scores(target.proba, target.score) if scores: diff --git a/eli5/formatters/text_helpers.py b/eli5/formatters/text_helpers.py index f67e9c28..dc5ff28a 100644 --- a/eli5/formatters/text_helpers.py +++ b/eli5/formatters/text_helpers.py @@ -58,7 +58,7 @@ def __eq__(self, other): def prepare_weighted_spans(targets, # type: List[TargetExplanation] preserve_density=None, # type: Optional[bool] ): - # type: (...) -> List[List[PreparedWeightedSpans]] + # type: (...) -> List[Optional[List[PreparedWeightedSpans]]] """ Return weighted spans prepared for rendering. Calculate a separate weight range for each different weighted span (for each different index): each target has the same number @@ -67,18 +67,23 @@ def prepare_weighted_spans(targets, # type: List[TargetExplanation] targets_char_weights = [ [get_char_weights(ws, preserve_density=preserve_density) for ws in t.weighted_spans.docs_weighted_spans] - if t.weighted_spans else None - for t in targets] # type: List[List[np.ndarray]] + if t.weighted_spans else None + for t in targets] # type: List[Optional[List[np.ndarray]]] max_idx = max_or_0(len(ch_w or []) for ch_w in targets_char_weights) + + targets_char_weights_not_None = [ + cw for cw in targets_char_weights + if cw is not None] # type: List[List[np.ndarray]] + spans_weight_ranges = [ max_or_0( - abs(x) for char_weights in targets_char_weights - for x in char_weights[idx] if char_weights is not None) + abs(x) for char_weights in targets_char_weights_not_None + for x in char_weights[idx]) for idx in range(max_idx)] return [ [PreparedWeightedSpans(ws, char_weights, weight_range) for ws, char_weights, weight_range in zip( - t.weighted_spans.docs_weighted_spans, + t.weighted_spans.docs_weighted_spans, # type: ignore t_char_weights, spans_weight_ranges)] if t_char_weights is not None else None diff --git a/eli5/formatters/trees.py b/eli5/formatters/trees.py index c133c8c1..7be1e9c3 100644 --- a/eli5/formatters/trees.py +++ b/eli5/formatters/trees.py @@ -22,6 +22,8 @@ def p(*args): value_repr = _format_leaf_value(tree_obj, node) parts.append(" ---> {}".format(value_repr)) else: + assert node.left is not None + assert node.right is not None feat_name = node.feature_name if depth > 0: diff --git a/eli5/formatters/utils.py b/eli5/formatters/utils.py index 148b7bde..ab8b5b45 100644 --- a/eli5/formatters/utils.py +++ b/eli5/formatters/utils.py @@ -62,7 +62,7 @@ def format_signed(feature, # type: Dict[str, Any] def should_highlight_spaces(explanation): # type: (Explanation) -> bool - hl_spaces = explanation.highlight_spaces + hl_spaces = bool(explanation.highlight_spaces) if explanation.feature_importances: hl_spaces = hl_spaces or any( _has_invisible_spaces(fw.feature) @@ -97,7 +97,7 @@ def has_any_values_for_weights(explanation): def tabulate(data, # type: List[List[Any]] - header=None, # type: List[Any] + header=None, # type: Optional[List[Any]] col_align=None, # type: Union[str, List[str]] ): # type: (...) -> List[str] @@ -107,7 +107,11 @@ def tabulate(data, # type: List[List[Any]] """ if not data and not header: return [] - n_cols = len(data[0] if data else header) + if data: + n_cols = len(data[0]) + else: + assert header is not None + n_cols = len(header) if not all(len(row) == n_cols for row in data): raise ValueError('data is not rectangular') diff --git a/eli5/lightgbm.py b/eli5/lightgbm.py index 02ea4411..b4510912 100644 --- a/eli5/lightgbm.py +++ b/eli5/lightgbm.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division from collections import defaultdict -from typing import DefaultDict +from typing import DefaultDict, Optional import numpy as np # type: ignore import lightgbm # type: ignore @@ -253,7 +253,7 @@ def _get_prediction_feature_weights(lgb, X, n_targets): res = [] for target in range(n_targets): - feature_weights = defaultdict(float) # type: DefaultDict[str, float] + feature_weights = defaultdict(float) # type: DefaultDict[Optional[str], float] for info, leaf_id in zip(tree_info[:, target], pred_leafs[:, target]): leaf_index, split_index = _get_leaf_split_indices( info['tree_structure'] diff --git a/eli5/lime/lime.py b/eli5/lime/lime.py index 9ab5acfc..924675bb 100644 --- a/eli5/lime/lime.py +++ b/eli5/lime/lime.py @@ -161,7 +161,7 @@ def __init__(self, if char_based is None: if token_pattern is None: - self.char_based = False + self.char_based = False # type: Optional[bool] self.token_pattern = DEFAULT_TOKEN_PATTERN else: self.char_based = None @@ -335,7 +335,7 @@ def _train_local_classifier(estimator, samples, similarity, # type: np.ndarray y_proba, # type: np.ndarray - expand_factor=10, # type: int + expand_factor=10, # type: Optional[int] test_size=0.3, # type: float random_state=None, ): diff --git a/eli5/lime/samplers.py b/eli5/lime/samplers.py index 3a82eff1..ff72f568 100644 --- a/eli5/lime/samplers.py +++ b/eli5/lime/samplers.py @@ -2,7 +2,7 @@ from __future__ import absolute_import import abc from functools import partial -from typing import List, Tuple, Any, Union, Dict +from typing import List, Tuple, Any, Union, Dict, Optional import six import numpy as np # type: ignore @@ -68,7 +68,7 @@ class MaskingTextSampler(BaseSampler): Default is 1, meaning individual tokens are replaced. """ def __init__(self, - token_pattern=None, # type: str + token_pattern=None, # type: Optional[str] bow=True, # type: bool random_state=None, replacement='', # type: str @@ -127,7 +127,7 @@ class MaskingTextSamplers(BaseSampler): """ def __init__(self, sampler_params, # type: List[Dict[str, Any]] - token_pattern=None, # type: str + token_pattern=None, # type: Optional[str] random_state=None, weights=None, # type: Union[np.ndarray, List[float]] ): @@ -168,6 +168,7 @@ def sample_near_with_mask(self, ): # type: (...) -> Tuple[List[str], np.ndarray, np.ndarray, TokenizedText] assert n_samples >= 1 + assert self.token_pattern is not None text = TokenizedText(doc, token_pattern=self.token_pattern) all_docs = [] # type: List[str] similarities = [] diff --git a/eli5/lime/textutils.py b/eli5/lime/textutils.py index dc940860..e896f347 100644 --- a/eli5/lime/textutils.py +++ b/eli5/lime/textutils.py @@ -5,7 +5,7 @@ from __future__ import absolute_import import re import math -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Optional import numpy as np # type: ignore from sklearn.utils import check_random_state # type: ignore @@ -70,7 +70,7 @@ def __init__(self, text, token_pattern=DEFAULT_TOKEN_PATTERN): # type: (str, str) -> None self.text = text self.split = SplitResult.fromtext(text, token_pattern) - self._vocab = None # type: List[str] + self._vocab = None # type: Optional[List[str]] def replace_random_tokens(self, n_samples, # type: int diff --git a/eli5/sklearn/explain_prediction.py b/eli5/sklearn/explain_prediction.py index d6ec0938..7de18783 100644 --- a/eli5/sklearn/explain_prediction.py +++ b/eli5/sklearn/explain_prediction.py @@ -184,6 +184,7 @@ def explain_prediction_linear_classifier(clf, doc, method='linear model', targets=[], ) + assert res.targets is not None _weights = _linear_weights(clf, x, top, feature_names, flt_indices) classes = getattr(clf, "classes_", ["-1", "1"]) # OneClassSVM support @@ -302,6 +303,7 @@ def explain_prediction_linear_regressor(reg, doc, targets=[], is_regression=True, ) + assert res.targets is not None _weights = _linear_weights(reg, x, top, feature_names, flt_indices) names = get_default_target_names(reg) @@ -426,6 +428,7 @@ def _weights(label_id, scale=1.0): description=(DESCRIPTION_TREE_CLF_MULTICLASS if is_multiclass else DESCRIPTION_TREE_CLF_BINARY), ) + assert res.targets is not None display_names = get_target_display_names( clf.classes_, target_names, targets, top_targets, @@ -524,6 +527,7 @@ def _weights(label_id, scale=1.0): targets=[], is_regression=True, ) + assert res.targets is not None names = get_default_target_names(reg, num_targets=num_targets) display_names = get_target_display_names(names, target_names, targets, diff --git a/eli5/sklearn/text.py b/eli5/sklearn/text.py index 23c5f870..8b961c88 100644 --- a/eli5/sklearn/text.py +++ b/eli5/sklearn/text.py @@ -49,7 +49,7 @@ def add_weighted_spans(doc, vec, vectorized, target_expl): def _get_doc_weighted_spans(doc, vec, feature_weights, # type: FeatureWeights - feature_fn=None # type: Callable[[str], str] + feature_fn=None # type: Optional[Callable[[str], str]] ): # type: (...) -> Optional[Tuple[FoundFeatures, DocWeightedSpans]] if isinstance(vec, InvertableHashingVectorizer): @@ -85,7 +85,7 @@ def _get_doc_weighted_spans(doc, def _get_feature_weights_dict(feature_weights, # type: FeatureWeights - feature_fn # type: Callable[[str], str] + feature_fn # type: Optional[Callable[[str], str]] ): # type: (...) -> Dict[str, Tuple[float, Tuple[str, int]]] """ Return {feat_name: (weight, (group, idx))} mapping. """ diff --git a/eli5/sklearn/utils.py b/eli5/sklearn/utils.py index 45ea6466..ba3680aa 100644 --- a/eli5/sklearn/utils.py +++ b/eli5/sklearn/utils.py @@ -69,7 +69,7 @@ def has_intercept(estimator): def get_feature_names(clf, vec=None, bias_name='', feature_names=None, num_features=None, estimator_feature_names=None): - # type: (Any, Any, str, Any, int, Any) -> FeatureNames + # type: (Any, Any, Optional[str], Any, int, Any) -> FeatureNames """ Return a FeatureNames instance that holds all feature names and a bias feature. diff --git a/eli5/xgboost.py b/eli5/xgboost.py index 56db0504..86d3ad58 100644 --- a/eli5/xgboost.py +++ b/eli5/xgboost.py @@ -218,7 +218,7 @@ def explain_prediction_xgboost( def _check_booster_args(xgb, is_regression=None): - # type: (Any, bool) -> Tuple[Booster, bool] + # type: (Any, Optional[bool]) -> Tuple[Booster, Optional[bool]] if isinstance(xgb, Booster): booster = xgb else: diff --git a/tox.ini b/tox.ini index c8054c69..fe7c4648 100644 --- a/tox.ini +++ b/tox.ini @@ -79,7 +79,7 @@ commands={[testenv:py35-extra]commands} basepython=python3.6 deps= {[testenv]deps} - mypy == 0.550 + mypy == 0.641 lxml commands= mypy --html-report ./mypy-cov --check-untyped-defs eli5