Advanced
CEML can be easily extended, and all major components can be customized to fit the user's needs.
Below is a (non-exhaustive) list of some (common) use cases:
Custom regularization
Instead of using one of the predefined regularizations, we can pass a custom regularization to ceml.sklearn.models.generate_counterfactual().
All regularization implementations must be classes derived from ceml.costfunctions.costfunctions.CostFunction. In the case of scikit-learn, if we want to use a gradient-based optimization algorithm, we must derive from ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax - note that ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax is already derived from ceml.costfunctions.costfunctions.CostFunction.
Note
For tensorflow/keras or PyTorch models, the base classes are ceml.backend.tensorflow.costfunctions.costfunctions.CostFunctionDifferentiableTf and ceml.backend.torch.costfunctions.costfunctions.CostFunctionDifferentiableTorch.
The computation of the regularization itself must be implemented in the score_impl function.
A complete example of a re-implementation of the l2-regularization is given below:
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3import jax.numpy as npx
4from sklearn.datasets import load_iris
5from sklearn.model_selection import train_test_split
6from sklearn.metrics import accuracy_score
7from sklearn.naive_bayes import GaussianNB
8
9from ceml.sklearn import generate_counterfactual
10from ceml.backend.jax.costfunctions import CostFunctionDifferentiableJax
11
12
13# Custom implementation of the l2-regularization. Note that this regularization is differentiable
class MyRegularization(CostFunctionDifferentiableJax):
    """Squared l2 distance between a candidate and a fixed reference point.

    The score is written with jax.numpy so that it stays differentiable.
    """

    def __init__(self, x_orig):
        # Reference point every candidate is compared against in score_impl.
        self.x_orig = x_orig

        super().__init__()

    def score_impl(self, x):
        """Return the sum of the squared element-wise differences to ``x_orig``."""
        # Note: This expression must be written in jax and it must be differentiable!
        delta = x - self.x_orig
        return npx.sum(npx.square(delta))
22
23
if __name__ == "__main__":
    # Load data
    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

    # Whitelist of features - list of features we can change/use when computing a counterfactual
    features_whitelist = None  # All features can be used.

    # Create and fit the model
    # (GaussianNB has no `multi_class` parameter, so no extra model settings are needed here.)
    model = GaussianNB()
    model.fit(X_train, y_train)

    # Select data point for explaining its prediction
    x = X_test[1,:]
    print("Prediction on x: {0}".format(model.predict([x])))

    # Create custom regularization function - it measures the distance to the original data point x
    regularization = MyRegularization(x)

    # Compute counterfactual
    print("\nCompute counterfactual ....")
    print(generate_counterfactual(model, x, y_target=0, features_whitelist=features_whitelist, regularization=regularization, optimizer="bfgs"))
Custom loss function
In order to use a custom loss function we have to do three things:
1. Implement the loss function. This is the same as implementing a custom regularization - a regularization is a loss function that works on the input rather than on the output.
2. Derive a child class from the model class and override the get_loss function to use our custom loss function.
3. Derive a child class from the counterfactual class of the model and override the rebuild_model function to use our model from the previous step.
A complete example of using a custom loss for a linear regression model is given below:
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3import numpy as np
4import jax.numpy as npx
5from sklearn.datasets import load_boston
6from sklearn.model_selection import train_test_split
7from sklearn.metrics import accuracy_score
8from sklearn.linear_model import Ridge
9
10
11from ceml.sklearn import generate_counterfactual
12from ceml.sklearn import LinearRegression, LinearRegressionCounterfactual
13from ceml.backend.jax.costfunctions import CostFunctionDifferentiableJax
14
15
# Custom implementation of a loss function. Note that this loss is differentiable.
class MyLoss(CostFunctionDifferentiableJax):
    """Loss that penalizes the fourth power of the absolute deviation from ``y_target``.

    ``input_to_output`` is forwarded unchanged to the base class constructor;
    ``y_target`` is the requested output the score is measured against.
    """

    def __init__(self, input_to_output, y_target):
        self.y_target = y_target

        super(MyLoss, self).__init__(input_to_output)

    def score_impl(self, y):
        """Return ``|y - y_target|**4`` for the output ``y``."""
        # Read the target from the instance: a bare `y_target` would only resolve
        # if a global of that name happened to exist in the running script.
        # Note: This expression must be written in jax and it must be differentiable!
        return npx.abs(y - self.y_target)**4
25
26# Derive a new class from ceml.sklearn.linearregression.LinearRegression and overwrite the get_loss method to use our custom loss MyLoss
class LinearRegressionWithMyLoss(LinearRegression):
    """Linear regression wrapper whose ``get_loss`` returns a ``MyLoss`` instance."""

    def __init__(self, model):
        super().__init__(model)

    def get_loss(self, y_target, pred=None):
        """Build the custom loss for ``y_target``.

        When ``pred`` is None the wrapper's own ``predict`` is used as the
        input-to-output mapping; otherwise ``pred`` is used.
        """
        input_to_output = self.predict if pred is None else pred
        return MyLoss(input_to_output, y_target)
36
37# Derive a new class from ceml.sklearn.linearregression.LinearRegressionCounterfactual that uses our new linear regression wrapper LinearRegressionWithMyLoss for computing counterfactuals
class MyLinearRegressionCounterfactual(LinearRegressionCounterfactual):
    """Counterfactual class that wraps the model in ``LinearRegressionWithMyLoss``."""

    def __init__(self, model):
        super().__init__(model)

    def rebuild_model(self, model):
        """Wrap ``model`` in the linear regression wrapper that uses the custom loss."""
        wrapped_model = LinearRegressionWithMyLoss(model)
        return wrapped_model
44
45
if __name__ == "__main__":
    # Load data and keep a third of it aside
    # NOTE(review): load_boston was removed from scikit-learn in version 1.2, so this
    # example needs an older scikit-learn release (or a replacement data set).
    X, y = load_boston(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

    # Whitelist of features we may change when computing a counterfactual (None = all of them)
    features_whitelist = None

    # Create and fit model
    model = Ridge()
    model.fit(X_train, y_train)

    # Select data point for explaining its prediction
    x = X_test[1]
    print("Prediction on x: {0}".format(model.predict([x])))

    # Compute counterfactual
    print("\nCompute counterfactual ....")
    y_target = 25.0

    # Since we might not be able to achieve `y_target` exactly, we tell ceml that
    # we are happy if we do not deviate more than 0.5 from it.
    def done(z):
        return np.abs(y_target - z) <= 0.5

    # Since we are using our own loss function, we can no longer use the standard
    # method generate_counterfactual and call our counterfactual class directly.
    cf = MyLinearRegressionCounterfactual(model)
    result = cf.compute_counterfactual(
        x,
        y_target=y_target,
        features_whitelist=features_whitelist,
        regularization="l2",
        C=1.0,
        optimizer="bfgs",
        done=done,
    )
    print(result)
Add a custom optimizer
We can use a custom optimization method by:
1. Deriving a new class from ceml.optim.optimizer.Optimizer and implementing the custom optimization method.
2. Creating a new instance of this class and passing it as the argument for the optimizer parameter to the function ceml.sklearn.models.generate_counterfactual() (or any other function that computes a counterfactual).
A complete example of using a custom optimization method for computing counterfactuals from a logistic regression model is given below:
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3import numpy as np
4from sklearn.datasets import load_iris
5from sklearn.model_selection import train_test_split
6from sklearn.linear_model import LogisticRegression
7from scipy.optimize import minimize
8
9from ceml.sklearn import generate_counterfactual
10from ceml.optim import Optimizer
11
12
13# Custom optimization method that simply calls the BFGS optimizer from scipy
class MyOptimizer(Optimizer):
    """Optimizer that delegates the minimization to scipy's BFGS implementation."""

    def __init__(self):
        # Nothing is configured until init() is called.
        self.f = self.f_grad = self.x0 = self.tol = self.max_iter = None

        super().__init__()

    def init(self, f, f_grad, x0, tol=None, max_iter=None):
        """Store the objective, its gradient, the start point and the stopping settings."""
        self.f, self.f_grad, self.x0 = f, f_grad, x0
        self.tol, self.max_iter = tol, max_iter

    def is_grad_based(self):
        """Report that this optimizer uses gradients."""
        return True

    def __call__(self):
        """Run BFGS from ``x0`` and return the solution found by scipy as a numpy array."""
        solver_options = {"maxiter": self.max_iter}
        result = minimize(
            self.f,
            self.x0,
            method="BFGS",
            jac=self.f_grad,
            tol=self.tol,
            options=solver_options,
        )
        return np.array(result["x"])
37
38
if __name__ == "__main__":
    # Load data and keep a third of it aside
    X, y = load_iris(return_X_y=True)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

    # Create and fit model
    model = LogisticRegression(solver='lbfgs', multi_class='multinomial')
    model.fit(X_train, y_train)

    # Select data point for explaining its prediction
    x = X_test[1]
    print("Prediction on x: {0}".format(model.predict([x])))

    # Compute counterfactual by using our custom optimizer 'MyOptimizer'
    print("\nCompute counterfactual ....")
    custom_optimizer = MyOptimizer()
    counterfactual = generate_counterfactual(
        model,
        x,
        y_target=0,
        optimizer=custom_optimizer,
        features_whitelist=None,
        regularization="l1",
        C=0.5,
    )
    print(counterfactual)