Plausible counterfactuals
In "Convex Density Constraints for Computing Plausible Counterfactual Explanations" (Artelt et al. 2020), a general framework for computing plausible counterfactuals was proposed. CEML currently implements these methods for softmax regression and decision tree classifiers.
In order to compute plausible counterfactual explanations, some preparations are required:
Use the ceml.sklearn.plausibility.prepare_computation_of_plausible_counterfactuals() function to create a dictionary that can be passed to the functions that generate counterfactuals.
You have to provide class-dependent fitted Gaussian Mixture Models (GMMs), i.e. one fitted GMM per class label, and the training data itself. In addition, you can optionally provide an affine preprocessing (e.g. a PCA projection, as in the example below) and a requested density/plausibility threshold (if you do not specify one, a suitable threshold will be selected automatically).
A complete example for computing a plausible counterfactual of a digit classifier (logistic regression) is given below:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Plausible vs. closest counterfactual for a digit classifier.

The script trains a softmax (multinomial logistic) regression on a PCA
projection of the scikit-learn digits data, fits one Gaussian Mixture Model
per class as a density estimator, and then asks CEML for two counterfactuals
of the same test sample: one computed with the density (plausibility)
constraints and one without. The original sample and both counterfactuals
are finally shown as 8x8 images.
"""
import random

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.utils import shuffle

from ceml.sklearn.plausibility import prepare_computation_of_plausible_counterfactuals
from ceml.sklearn.softmaxregression import softmaxregression_generate_counterfactual


if __name__ == "__main__":
    # Seed Python's RNG right before it is used (target label selection below)
    random.seed(424242)

    # Number of principal components the classifier operates on
    pca_dim = 40

    # Load data set
    X, y = load_digits(return_X_y=True)

    X, y = shuffle(X, y, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=4242
    )

    # Choose target labels: for every test sample pick a random label that
    # differs from its ground truth label
    labels = np.unique(y)
    y_test_target = np.array(
        [
            random.choice([label for label in labels if label != y_true])
            for y_true in y_test
        ]
    )

    # Reduce dimensionality - keep the untransformed data because the
    # counterfactuals are computed from (and plotted in) the original space
    X_train_orig = np.copy(X_train)
    X_test_orig = np.copy(X_test)

    pca = PCA(n_components=pca_dim)
    pca.fit(X_train)

    # The PCA is applied by hand as the affine map x -> (x - mean) @ A^T so
    # that the very same matrix and mean can be handed over to ceml below
    projection_matrix = pca.components_
    projection_mean_sub = pca.mean_

    X_train = np.dot(X_train - projection_mean_sub, projection_matrix.T)
    X_test = np.dot(X_test - projection_mean_sub, projection_matrix.T)

    # Fit classifier
    model = LogisticRegression(multi_class="multinomial", solver="lbfgs", random_state=42)
    model.fit(X_train, y_train)

    # Compute accuracy on test set
    print("Accuracy: {0}".format(accuracy_score(y_test, model.predict(X_test))))

    # For each class, fit a density estimator on the projected training data
    density_estimators = {}
    for label in labels:
        # Get all samples with the 'correct' label
        X_ = X_train[y_train == label, :]

        # Optimize hyperparameters - the estimator is seeded so that the
        # selected number of components is reproducible between runs
        cv = GridSearchCV(
            estimator=GaussianMixture(covariance_type="full", random_state=42),
            param_grid={"n_components": range(2, 10)},
            n_jobs=-1,
            cv=5,
        )
        cv.fit(X_)
        n_components = cv.best_params_["n_components"]

        # Build density estimator
        de = GaussianMixture(n_components=n_components, covariance_type="full", random_state=42)
        de.fit(X_)

        density_estimators[label] = de

    # Build dictionary for ceml
    plausibility_stuff = prepare_computation_of_plausible_counterfactuals(
        X_train_orig, y_train, density_estimators, projection_mean_sub, projection_matrix
    )

    # Compute and plot counterfactual with vs. without density constraints
    i = 0

    x_orig = X_test[i, :]            # Sample in the projected (PCA) space
    x_orig_orig = X_test_orig[i, :]  # Same sample in the original space
    y_orig = y_test[i]
    y_target = y_test_target[i]
    print("Original label: {0}".format(y_orig))
    print("Target label: {0}".format(y_target))

    if model.predict([x_orig])[0] == y_target:  # Model already predicts target label!
        raise ValueError("Requested prediction already satisfied")

    # Compute plausible counterfactual - the result is projected by hand
    # before it is passed to the classifier
    x_cf_plausible = softmaxregression_generate_counterfactual(
        model, x_orig_orig, y_target, plausibility=plausibility_stuff
    )
    x_cf_plausible_projected = np.dot(x_cf_plausible - projection_mean_sub, projection_matrix.T)
    print("Predicted label of plausible counterfactual: {0}".format(model.predict([x_cf_plausible_projected])))

    # Compute closest counterfactual - same call, but with the density
    # constraints switched off in the dictionary
    plausibility_stuff["use_density_constraints"] = False
    x_cf = softmaxregression_generate_counterfactual(
        model, x_orig_orig, y_target, plausibility=plausibility_stuff
    )
    x_cf_projected = np.dot(x_cf - projection_mean_sub, projection_matrix.T)
    print("Predicted label of closest counterfactual: {0}".format(model.predict([x_cf_projected])))

    # Plot results
    _, axes = plt.subplots(3, 1)
    axes[0].imshow(x_orig_orig.reshape(8, 8))     # Original sample
    axes[1].imshow(x_cf.reshape(8, 8))            # Closest counterfactual
    axes[2].imshow(x_cf_plausible.reshape(8, 8))  # Plausible counterfactual
    plt.show()