设计思路
更新中
输出结果
核心代码
# flake8: noqa
import warnings
import sys
__version__ = '0.37.0'
# check python version
if (sys.version_info < (3, 0)):
warnings.warn("As of version 0.29.0 shap only supports Python 3 (not 2)!")
from ._explanation import Explanation, Cohorts
# explainers
from .explainers._explainer import Explainer
from .explainers._kernel import Kernel as KernelExplainer
from .explainers._sampling import Sampling as SamplingExplainer
from .explainers._tree import Tree as TreeExplainer
from .explainers._deep import Deep as DeepExplainer
from .explainers._gradient import Gradient as GradientExplainer
from .explainers._linear import Linear as LinearExplainer
from .explainers._partition import Partition as PartitionExplainer
from .explainers._permutation import Permutation as PermutationExplainer
from .explainers._additive import Additive as AdditiveExplainer
from .explainers import other
# plotting (only loaded if matplotlib is present)
def unsupported(*args, **kwargs):
    """Stand-in for plotting functions when matplotlib cannot be imported.

    Accepts and ignores any positional and keyword arguments so it can be
    bound in place of any plotting entry point. Emits a ``UserWarning``
    telling the user how to install matplotlib, and returns ``None``.
    """
    # stacklevel=2 attributes the warning to the caller's line (the user's
    # call to e.g. ``shap.summary_plot``) instead of to this helper.
    warnings.warn(
        "matplotlib is not installed so plotting is not available! Run `pip install matplotlib` to fix this.",
        stacklevel=2,
    )
# matplotlib is an optional dependency; probe for it once so the public
# plotting names below can be bound either to the real implementations or
# to the warning stub ``unsupported``.
try:
    import matplotlib
    have_matplotlib = True
except ImportError:
    have_matplotlib = False

if have_matplotlib:
    from .plots._beeswarm import summary_legacy as summary_plot
    from .plots._decision import decision as decision_plot, multioutput_decision as multioutput_decision_plot
    from .plots._scatter import dependence_legacy as dependence_plot
    from .plots._force import force as force_plot, initjs, save_html, getjs
    from .plots._image import image as image_plot
    from .plots._monitoring import monitoring as monitoring_plot
    from .plots._embedding import embedding as embedding_plot
    from .plots._partial_dependence import partial_dependence as partial_dependence_plot
    from .plots._bar import bar_legacy as bar_plot
    from .plots._waterfall import waterfall as waterfall_plot
    from .plots._group_difference import group_difference as group_difference_plot
    from .plots._text import text as text_plot
else:
    # Every public name imported in the branch above must also be bound
    # here; otherwise users without matplotlib get an AttributeError instead
    # of the install hint. Keep this list in sync with the imports above.
    summary_plot = unsupported
    decision_plot = unsupported
    multioutput_decision_plot = unsupported
    dependence_plot = unsupported
    force_plot = unsupported
    initjs = unsupported
    save_html = unsupported
    getjs = unsupported
    image_plot = unsupported
    monitoring_plot = unsupported
    embedding_plot = unsupported
    partial_dependence_plot = unsupported
    bar_plot = unsupported
    waterfall_plot = unsupported
    group_difference_plot = unsupported
    text_plot = unsupported
# other stuff :)
from . import datasets
from . import utils
from . import links
#from . import benchmark
from .utils._legacy import kmeans
from .utils import sample, approximate_interactions
# TODO: Add support for hclustering based explanations where we sort the leaf order by magnitude and then show the dendrogram to the left
def summary_legacy(shap_values, features=None, feature_names=None, max_display=None, plot_type=None,
color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
color_bar=True, plot_size="auto", layered_violin_max_num_bins=20, class_names=None,
class_inds=None,
color_bar_label=labels["FEATURE_VALUE"],
cmap=colors.red_blue,
# depreciated
auto_size_plot=None,
use_log_scale=False):
"""Create a SHAP beeswarm plot, colored by feature values when they are provided.
Parameters
----------
shap_values : numpy.array
For single output explanations this is a matrix of SHAP values (# samples x # features).
For multi-output explanations this is a list of such matrices of SHAP values.
features : numpy.array or pandas.DataFrame or list
Matrix of feature values (# samples x # features) or a feature_names list as shorthand
feature_names : list
Names of the features (length # features)
max_display : int
How many top features to include in the plot (default is 20, or 7 for interaction plots)
plot_type : "dot" (default for single output), "bar" (default for multi-output), "violin",
or "compact_dot".
What type of summary plot to produce. Note that "compact_dot" is only used for
SHAP interaction values.
plot_size : "auto" (default), float, (float, float), or None
What size to make the plot. By default the size is auto-scaled based on the number of
features that are being displayed. Passing a single float will cause each row to be that
many inches high. Passing a pair of floats will scale the plot by that
number of inches. If None is passed then the size of the current figure will be left
unchanged.
"""
# support passing an explanation object
if str(type(shap_values)).endswith("Explanation'>"):
shap_exp = shap_values
base_value = shap_exp.base_value
shap_values = shap_exp.values
if features is None:
features = shap_exp.data
if feature_names is None:
feature_names = shap_exp.feature_names
# if out_names is None: # TODO: waiting for slicer support of this
# out_names = shap_exp.output_names
# deprecation warnings
if auto_size_plot is not None:
warnings.warn("auto_size_plot=False is deprecated and is now ignored! Use plot_size=None instead.")
multi_class = False
if isinstance(shap_values, list):
multi_class = True
if plot_type is None:
plot_type = "bar" # default for multi-output explanations
assert plot_type == "bar", "Only plot_type = 'bar' is supported for multi-output explanations!"
else:
if plot_type is None:
plot_type = "dot" # default for single output explanations
assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."
# default color:
if color is None:
if plot_type == 'layered_violin':
color = "coolwarm"
elif multi_class:
color = lambda i: colors.red_blue_circle(i/len(shap_values))
else:
color = colors.blue_rgb