Source code for braindecode.visualization.confusion_matrices

# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm


def plot_confusion_matrix(confusion_mat, class_names=None, figsize=None,
                          colormap=cm.bwr, textcolor='black', vmin=None,
                          vmax=None, fontweight='normal',
                          rotate_row_labels=90, rotate_col_labels=0,
                          with_f1_score=False, norm_axes=(0, 1),
                          rotate_precision=False, class_names_fontsize=12):
    """
    Generates a confusion matrix with additional precision and sensitivity
    metrics as in [1]_.

    Every cell shows the raw count and the count as a percentage. An extra
    column on the right and an extra row at the bottom show, per class, the
    diagonal count divided by the row sum / column sum of the (transposed)
    matrix; the bottom-right corner shows the overall accuracy.

    Parameters
    ----------
    confusion_mat: 2d numpy array
        A square confusion matrix of integer counts, e.g. sklearn confusion
        matrix:
        https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
        The matrix is transposed before plotting.
    class_names: array, optional
        List of classes/targets. Defaults to ``"1", "2", ...``.
    figsize: tuple, optional
        Size of the generated confusion matrix figure.
    colormap: matplotlib cm colormap, optional
        Colormap for the cells; it is brightened before use so that the
        text stays readable.
    textcolor: str, optional
        Color of the text in the figure.
    vmin, vmax: float, optional
        The data range that the colormap covers. Default to the minimum and
        maximum of the normalized matrix.
    fontweight: str, optional
        Weight of the font in the figure (diagonal cells are always bold):
        [ 'normal' | 'bold' | 'heavy' | 'light' | 'ultrabold' | 'ultralight']
    rotate_row_labels: int, optional
        The rotation angle of the row labels
    rotate_col_labels: int, optional
        The rotation angle of the column labels
    with_f1_score: bool, optional
        If True, diagonal cells additionally show the per-class F1 score.
    norm_axes: tuple, optional
        Axes (of the transposed matrix) that the counts are summed over to
        obtain the percentages shown in the cells. The default ``(0, 1)``
        divides by the total number of entries.
    rotate_precision: bool, optional
        If True, the "Precision" label is drawn rotated by 90 degrees.
    class_names_fontsize: int, optional
        Font size of the class names on both axes.

    Returns
    -------
    fig: matplotlib figure

    Raises
    ------
    ValueError
        If `confusion_mat` is not a square 2d array.

    References
    ----------

    .. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
       Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping , Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
    """
    # Accept nested lists as well as arrays, and fail early with a clear
    # message instead of an IndexError deep inside the annotation loops.
    confusion_mat = np.asarray(confusion_mat)
    if confusion_mat.ndim != 2 or (
            confusion_mat.shape[0] != confusion_mat.shape[1]):
        raise ValueError(
            "confusion_mat must be a square 2d array, got shape {}".format(
                confusion_mat.shape))

    # transpose to get confusion matrix same way as matlab
    confusion_mat = confusion_mat.T
    n_classes = confusion_mat.shape[0]
    if class_names is None:
        class_names = [str(i_class + 1) for i_class in range(n_classes)]

    # normalize the counts over `norm_axes` (default: by all entries)
    normed_conf_mat = confusion_mat / np.float32(
        np.sum(confusion_mat, axis=norm_axes, keepdims=True))

    fig = plt.figure(figsize=figsize)
    plt.clf()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    if vmin is None:
        vmin = np.min(normed_conf_mat)
    if vmax is None:
        vmax = np.max(normed_conf_mat)

    # see http://stackoverflow.com/a/31397438/1469195
    # brighten so that black text remains readable
    # used alpha=0.6 before
    def _brighten(x, ):
        brightened_x = 1 - ((1 - np.array(x)) * 0.4)
        return brightened_x

    brightened_cmap = _cmap_map(_brighten, colormap)
    ax.imshow(np.array(normed_conf_mat), cmap=brightened_cmap,
              interpolation='nearest', vmin=vmin, vmax=vmax)

    # make space for precision and sensitivity
    plt.xlim(-0.5, normed_conf_mat.shape[0] + 0.5)
    plt.ylim(normed_conf_mat.shape[1] + 0.5, -0.5)
    width = len(confusion_mat)
    height = len(confusion_mat[0])

    # Row/column totals of the transposed matrix, reused for all metrics.
    row_sums = [float(np.sum(confusion_mat[i, :])) for i in range(width)]
    col_sums = [float(np.sum(confusion_mat[:, j])) for j in range(height)]

    for x in range(width):
        for y in range(height):
            if x == y:
                this_font_weight = 'bold'
            else:
                this_font_weight = fontweight
            # Raw count; the trailing newline shifts it above the percentage
            # that is drawn at the same position below.
            ax.annotate("{:d}\n".format(confusion_mat[x][y]),
                        xy=(y, x),
                        horizontalalignment='center',
                        verticalalignment='center', fontsize=12,
                        color=textcolor,
                        fontweight=this_font_weight)
            if x != y or (not with_f1_score):
                ax.annotate(
                    "\n\n{:4.1f}%".format(
                        normed_conf_mat[x][y] * 100),
                    xy=(y, x),
                    horizontalalignment='center',
                    verticalalignment='center', fontsize=10,
                    color=textcolor,
                    fontweight=this_font_weight)
            else:
                # Diagonal cell with F1 score. F1 = 2PR / (P + R) is
                # algebraically 2 * TP / (row sum + column sum); this form
                # avoids the 0/0 (nan) that arises when a class has no
                # correct predictions or an empty row/column.
                f1_denominator = row_sums[x] + col_sums[y]
                if f1_denominator > 0:
                    f1_score = 2 * confusion_mat[x][x] / f1_denominator
                else:
                    f1_score = 0.0

                # Use the same normalization (`norm_axes`) as all the
                # other cells for the percentage.
                ax.annotate("\n{:4.1f}%\n{:4.1f}% (F)".format(
                    normed_conf_mat[x][y] * 100,
                    f1_score * 100),
                    xy=(y, x + 0.1),
                    horizontalalignment='center',
                    verticalalignment='center', fontsize=10,
                    color=textcolor,
                    fontweight=this_font_weight)

    # Add values for target correctness etc.
    # Extra column on the right: diagonal count / row sum.
    for x in range(width):
        y = len(confusion_mat)
        if row_sums[x] == 0:
            annotate_str = "-"
        else:
            correctness = confusion_mat[x][x] / row_sums[x]
            annotate_str = "\n{:5.2f}%".format(correctness * 100)
        ax.annotate(annotate_str,
                    xy=(y, x),
                    horizontalalignment='center',
                    verticalalignment='center', fontsize=12)

    # Extra row at the bottom: diagonal count / column sum.
    for y in range(height):
        x = len(confusion_mat)
        if col_sums[y] == 0:
            annotate_str = "-"
        else:
            correctness = confusion_mat[y][y] / col_sums[y]
            annotate_str = "\n{:5.2f}%".format(correctness * 100)
        ax.annotate(annotate_str,
                    xy=(y, x),
                    horizontalalignment='center',
                    verticalalignment='center', fontsize=12)

    # Bottom-right corner: overall accuracy.
    overall_correctness = np.sum(np.diag(confusion_mat)) / np.sum(
        confusion_mat).astype(float)
    ax.annotate("{:5.2f}%".format(overall_correctness * 100),
                xy=(len(confusion_mat), len(confusion_mat)),
                horizontalalignment='center',
                verticalalignment='center', fontsize=12,
                fontweight='bold')

    plt.xticks(range(width), class_names, fontsize=class_names_fontsize,
               rotation=rotate_col_labels)
    plt.yticks(np.arange(0, height), class_names,
               va='center',
               fontsize=class_names_fontsize, rotation=rotate_row_labels)
    plt.grid(False)
    plt.ylabel('Predictions', fontsize=15)
    plt.xlabel('Targets', fontsize=15)

    # n classes is also shape of matrix/size
    ax.text(-1.2, n_classes + 0.2, "Recall", ha='center', va='center',
            fontsize=13)
    if rotate_precision:
        rotation = 90
        y_pos = -1.1
        va = 'center'
    else:
        rotation = 0
        y_pos = -0.8
        va = 'top'
    ax.text(n_classes, y_pos, "Precision", ha='center', va=va,
            rotation=rotation,
            fontsize=13)

    return fig
# see http://stackoverflow.com/a/31397438/1469195
def _cmap_map(function, cmap, name='colormap_mod', N=None, gamma=None):
    """
    Modify a colormap using `function` which must operate on 3-element
    arrays of [r, g, b] values.

    Parameters
    ----------
    function: callable
        Maps one [r, g, b] array to a new [r, g, b] array.
    cmap: matplotlib colormap
        The colormap to modify. Colormaps with tabulated segment data are
        transformed breakpoint by breakpoint (keeping discontinuities);
        any other colormap is sampled at its lookup-table resolution.
    name: str, optional
        Name given to the returned colormap.
    N: int, optional
        Number of colors of the returned colormap. Defaults to ``cmap.N``.
    gamma: float, optional
        Gamma correction factor of the returned colormap. Defaults to the
        gamma of `cmap` when its segment data is used, and to 1.0 when
        `cmap` is sampled (the samples already include its gamma).

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
    """
    from matplotlib.colors import LinearSegmentedColormap as lsc

    if N is None:
        N = cmap.N

    cdict = getattr(cmap, '_segmentdata', None)
    if cdict is None or any(callable(cdict[key]) for key in cdict):
        # No tabulated breakpoints to work with (e.g. a ListedColormap, or
        # segment data given as functions): sample the colormap instead and
        # interpolate linearly between the transformed samples.
        if gamma is None:
            gamma = 1.0
        positions = np.linspace(0, 1, max(cmap.N, 2))
        sampled_rgb = cmap(positions)[:, :3]
        mapped_rgb = np.array([function(rgb) for rgb in sampled_rgb])
        return lsc.from_list(name, mapped_rgb, N=N, gamma=gamma)

    if gamma is None:
        gamma = cmap._gamma
    # Breakpoint positions (first column of the segment data) per channel:
    step_dict = {key: [segment[0] for segment in cdict[key]]
                 for key in cdict}
    # Now get the unique steps over all channels. The channels may have
    # different numbers of breakpoints, so concatenate them instead of
    # stacking them into a 2d array (which would be ragged).
    step_list = np.unique(np.concatenate(list(step_dict.values())))
    # 'y0', 'y1' are as defined in LinearSegmentedColormap docstring:
    y0 = cmap(step_list)[:, :3]
    y1 = y0.copy()
    # Go back to catch the discontinuities, and place them into y0, y1
    for iclr, key in enumerate(['red', 'green', 'blue']):
        for istp, step in enumerate(step_list):
            try:
                ind = step_dict[key].index(step)
            except ValueError:
                # This step is not in this color
                continue
            y0[istp, iclr] = cdict[key][ind][1]
            y1[istp, iclr] = cdict[key][ind][2]
    # Map the colors to their new values:
    y0 = np.array([function(rgb) for rgb in y0])
    y1 = np.array([function(rgb) for rgb in y1])
    # Build the new colormap (overwriting step_dict):
    for iclr, clr in enumerate(['red', 'green', 'blue']):
        step_dict[clr] = np.vstack(
            (step_list, y0[:, iclr], y1[:, iclr])).T
    # Remove alpha, otherwise crashes...
    step_dict.pop('alpha', None)
    return lsc(name, step_dict, N=N, gamma=gamma)