Plausible counterfactuals

In "Convex Density Constraints for Computing Plausible Counterfactual Explanations" (Artelt et al., 2020), a general framework for computing plausible counterfactuals was proposed. CEML currently implements these methods for softmax regression and decision tree classifiers.

In order to compute plausible counterfactual explanations, some preparations are required:

Use the ceml.sklearn.plausibility.prepare_computation_of_plausible_counterfactuals() function to create a dictionary that can be passed to the functions that generate counterfactuals. You have to provide class-dependent fitted Gaussian Mixture Models (GMMs) and the training data itself. In addition, you can provide an affine preprocessing and a requested density/plausibility threshold (if you do not specify a threshold, a suitable one is selected automatically).

A complete example for computing a plausible counterfactual of a digit classifier (logistic regression) is given below:

  #!/usr/bin/env python3
  # -*- coding: utf-8 -*-
  """Compare a plausible and a closest counterfactual for a digit classifier.

  The script trains a multinomial logistic regression on a PCA projection of
  the scikit-learn digits data, fits one Gaussian mixture model per class as a
  density estimator, and then asks ceml for two counterfactuals of the first
  test sample: one computed with density (plausibility) constraints and one
  without. Both are plotted next to the original sample.
  """
  import random

  import matplotlib.pyplot as plt
  import numpy as np
  from sklearn.datasets import load_digits
  from sklearn.decomposition import PCA
  from sklearn.linear_model import LogisticRegression
  from sklearn.metrics import accuracy_score
  from sklearn.mixture import GaussianMixture
  from sklearn.model_selection import GridSearchCV, train_test_split
  from sklearn.utils import shuffle

  from ceml.sklearn.softmaxregression import softmaxregression_generate_counterfactual
  from ceml.sklearn.plausibility import prepare_computation_of_plausible_counterfactuals

  # Seed Python's RNG so that the randomly drawn target labels are reproducible.
  random.seed(424242)


  def apply_projection(X, mean, matrix):
      """Return the affine projection ``(X - mean) @ matrix.T``.

      Works for a single sample (1-d array) as well as for a 2-d data matrix.
      """
      return np.dot(X - mean, matrix.T)


  if __name__ == "__main__":
      pca_dim = 40  # Number of PCA components the classifier is trained on

      # Load data set
      X, y = load_digits(return_X_y=True)

      X, y = shuffle(X, y, random_state=42)
      X_train, X_test, y_train, y_test = train_test_split(
          X, y, test_size=0.33, random_state=4242
      )

      # Choose target labels: for every test sample draw a random label that
      # differs from its ground-truth label
      labels = np.unique(y)
      y_test_target = np.array([
          random.choice([label for label in labels if label != y_true])
          for y_true in y_test
      ])

      # Reduce dimensionality. The unprojected data is kept because the
      # counterfactuals below are requested for (and returned as) samples in
      # the original space - they are projected before calling the classifier
      # and reshaped to 8x8 images for plotting.
      X_train_orig = np.copy(X_train)
      X_test_orig = np.copy(X_test)

      pca = PCA(n_components=pca_dim)
      pca.fit(X_train)

      projection_matrix = pca.components_  # Projection matrix
      projection_mean_sub = pca.mean_      # Mean subtracted before projecting

      X_train = apply_projection(X_train, projection_mean_sub, projection_matrix)
      X_test = apply_projection(X_test, projection_mean_sub, projection_matrix)

      # Fit classifier
      # NOTE(review): `multi_class` is deprecated in recent scikit-learn
      # releases; it is kept because the ceml softmax regression wrapper
      # presumably expects a multinomial model - confirm before removing it.
      model = LogisticRegression(multi_class="multinomial", solver="lbfgs", random_state=42)
      model.fit(X_train, y_train)

      # Compute accuracy on test set
      print("Accuracy: {0}".format(accuracy_score(y_test, model.predict(X_test))))

      # For each class, fit density estimators
      density_estimators = {}
      for label in labels:
          # Get all (projected) training samples with the 'correct' label
          X_ = X_train[y_train == label, :]

          # Optimize hyperparameters. The estimator is seeded like every other
          # model in this script so that the selected number of components
          # does not change from run to run.
          cv = GridSearchCV(
              estimator=GaussianMixture(covariance_type='full', random_state=42),
              param_grid={'n_components': range(2, 10)},
              n_jobs=-1,
              cv=5,
          )
          cv.fit(X_)
          n_components = cv.best_params_["n_components"]

          # Build density estimators
          de = GaussianMixture(n_components=n_components, covariance_type='full', random_state=42)
          de.fit(X_)

          density_estimators[label] = de

      # Build dictionary for ceml
      plausibility_stuff = prepare_computation_of_plausible_counterfactuals(
          X_train_orig, y_train, density_estimators, projection_mean_sub, projection_matrix
      )

      # Compute and plot counterfactual with vs. without density constraints
      i = 0

      x_orig = X_test[i, :]
      x_orig_orig = X_test_orig[i, :]
      y_orig = y_test[i]
      y_target = y_test_target[i]
      print("Original label: {0}".format(y_orig))
      print("Target label: {0}".format(y_target))

      if model.predict([x_orig])[0] == y_target:  # Model already predicts target label!
          raise ValueError("Requested prediction already satisfied")

      # Compute plausible counterfactual
      x_cf_plausible = softmaxregression_generate_counterfactual(
          model, x_orig_orig, y_target, plausibility=plausibility_stuff
      )
      x_cf_plausible_projected = apply_projection(
          x_cf_plausible, projection_mean_sub, projection_matrix
      )
      print("Predicted label of plausible counterfactual: {0}".format(model.predict([x_cf_plausible_projected])))

      # Compute closest counterfactual: same call, density constraints turned off
      plausibility_stuff["use_density_constraints"] = False
      x_cf = softmaxregression_generate_counterfactual(
          model, x_orig_orig, y_target, plausibility=plausibility_stuff
      )
      x_cf_projected = apply_projection(x_cf, projection_mean_sub, projection_matrix)
      print("Predicted label of closest counterfactual: {0}".format(model.predict([x_cf_projected])))

      # Plot results
      fig, axes = plt.subplots(3, 1)
      axes[0].imshow(x_orig_orig.reshape(8, 8))     # Original sample
      axes[1].imshow(x_cf.reshape(8, 8))            # Closest counterfactual
      axes[2].imshow(x_cf_plausible.reshape(8, 8))  # Plausible counterfactual
      plt.show()