Advanced

CEML can be easily extended, and all major components can be customized to fit the user's needs.

Below is a non-exhaustive list of common use cases:

Custom regularization

Instead of using one of the predefined regularizations, we can pass a custom regularization to ceml.sklearn.models.generate_counterfactual().

All regularization implementations must be classes derived from ceml.costfunctions.costfunctions.CostFunction. In the case of scikit-learn, if we want to use a gradient-based optimization algorithm, we must derive from ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax instead - note that CostFunctionDifferentiableJax is itself derived from ceml.costfunctions.costfunctions.CostFunction.

The computation of the regularization itself must be implemented in the score_impl method.
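If only a gradient-free optimizer is going to be used, it is sufficient to derive from the plain base class. A minimal sketch - assuming that, apart from the constructor, only score_impl has to be implemented - might look like this:

import numpy as np

from ceml.costfunctions.costfunctions import CostFunction


# Minimal sketch of a custom (not necessarily differentiable) regularization -
# an l1 penalty on the distance to the original input.
class MyL1Regularization(CostFunction):
    def __init__(self, x_orig):
        self.x_orig = x_orig

        super(MyL1Regularization, self).__init__()

    def score_impl(self, x):
        return np.sum(np.abs(x - self.x_orig))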

A complete example of a re-implementation of the l2-regularization is given below:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import jax.numpy as npx
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

from ceml.sklearn import generate_counterfactual
from ceml.backend.jax.costfunctions import CostFunctionDifferentiableJax


# Custom implementation of the l2-regularization. Note that this regularization is differentiable.
class MyRegularization(CostFunctionDifferentiableJax):
    def __init__(self, x_orig):
        self.x_orig = x_orig

        super(MyRegularization, self).__init__()

    def score_impl(self, x):
        return npx.sum(npx.square(x - self.x_orig))  # Note: This expression must be written in jax and it must be differentiable!


if __name__ == "__main__":
    # Load data
    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

    # Whitelist of features - list of features we can change/use when computing a counterfactual
    features_whitelist = None   # All features can be used.

    # Create and fit the model
    model = GaussianNB()
    model.fit(X_train, y_train)

    # Select a data point whose prediction is to be explained
    x = X_test[1,:]
    print("Prediction on x: {0}".format(model.predict([x])))

    # Create the custom regularization function
    regularization = MyRegularization(x)

    # Compute a counterfactual
    print("\nCompute counterfactual ....")
    print(generate_counterfactual(model, x, y_target=0, features_whitelist=features_whitelist, regularization=regularization, optimizer="bfgs"))
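
Since MyRegularization is derived from CostFunctionDifferentiableJax, jax differentiates score_impl for us - this is why it can safely be passed to a gradient-based optimizer such as "bfgs". A gradient-free optimizer (e.g. "nelder-mead") should work as well, as it never queries the gradient.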

Custom loss function

In order to use a custom loss function we have to do three things:

  1. Implement the loss function. This works exactly like implementing a custom regularization - a regularization is simply a loss function that operates on the input rather than on the output.

  2. Derive a child class from the model class and override the get_loss method so that it uses our custom loss function.

  3. Derive a child class from the counterfactual class of the model and override the rebuild_model method so that it uses our model from the previous step.

A complete example of using a custom loss for a linear regression model is given below:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import jax.numpy as npx
from sklearn.datasets import load_boston   # Note: load_boston was removed in scikit-learn 1.2 - use an older version or substitute another regression dataset
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge

from ceml.sklearn import generate_counterfactual
from ceml.sklearn import LinearRegression, LinearRegressionCounterfactual
from ceml.backend.jax.costfunctions import CostFunctionDifferentiableJax


# Custom loss function: quartic distance to the target value. Note that this loss is differentiable.
class MyLoss(CostFunctionDifferentiableJax):
    def __init__(self, input_to_output, y_target):
        self.y_target = y_target

        super(MyLoss, self).__init__(input_to_output)  # input_to_output maps the input to the model's output (here: the model's predict function)

    def score_impl(self, y):
        return npx.abs(y - self.y_target)**4

# Derive a new class from ceml.sklearn.linearregression.LinearRegression and override the get_loss method to use our custom loss MyLoss
class LinearRegressionWithMyLoss(LinearRegression):
    def __init__(self, model):
        super(LinearRegressionWithMyLoss, self).__init__(model)

    def get_loss(self, y_target, pred=None):
        if pred is None:
            return MyLoss(self.predict, y_target)
        else:
            return MyLoss(pred, y_target)

# Derive a new class from ceml.sklearn.linearregression.LinearRegressionCounterfactual that uses our linear regression wrapper LinearRegressionWithMyLoss for computing counterfactuals
class MyLinearRegressionCounterfactual(LinearRegressionCounterfactual):
    def __init__(self, model):
        super(MyLinearRegressionCounterfactual, self).__init__(model)

    def rebuild_model(self, model):
        return LinearRegressionWithMyLoss(model)


if __name__ == "__main__":
    # Load data
    X, y = load_boston(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

    # Whitelist of features - list of features we can change/use when computing a counterfactual
    features_whitelist = None   # All features can be used.

    # Create and fit the model
    model = Ridge()
    model.fit(X_train, y_train)

    # Select a data point whose prediction is to be explained
    x = X_test[1,:]
    print("Prediction on x: {0}".format(model.predict([x])))

    # Compute a counterfactual
    print("\nCompute counterfactual ....")
    y_target = 25.0
    done = lambda z: np.abs(y_target - z) <= 0.5     # Since we might not be able to achieve `y_target` exactly, we tell ceml that we are happy if we do not deviate more than 0.5 from it.

    cf = MyLinearRegressionCounterfactual(model)    # Since we are using our own loss function, we can no longer use the standard generate_counterfactual function
    print(cf.compute_counterfactual(x, y_target=y_target, features_whitelist=features_whitelist, regularization="l2", C=1.0, optimizer="bfgs", done=done))
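
If we want to process the result programmatically instead of printing it, the counterfactual can be unpacked into its components - a small sketch, assuming that compute_counterfactual supports the same return_as_dict parameter as generate_counterfactual:

# Assumption: with return_as_dict=False, compute_counterfactual returns the tuple (x_cf, y_cf, delta).
x_cf, y_cf, delta = cf.compute_counterfactual(x, y_target=y_target, features_whitelist=features_whitelist, regularization="l2", C=1.0, optimizer="bfgs", done=done, return_as_dict=False)
print(model.predict([x_cf]))    # Should deviate at most 0.5 from y_target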

Add a custom optimizer

In order to use a custom optimization method we have to do two things:

  1. Derive a new class from ceml.optim.optimizer.Optimizer and implement the custom optimization method.

  2. Create a new instance of this class and pass it as the optimizer argument to ceml.sklearn.models.generate_counterfactual() (or any other function that computes counterfactuals).

A complete example of using a custom optimization method for computing counterfactuals from a logistic regression model is given below:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from scipy.optimize import minimize

from ceml.sklearn import generate_counterfactual
from ceml.optim import Optimizer


# Custom optimization method that simply calls the BFGS optimizer from scipy
class MyOptimizer(Optimizer):
    def __init__(self):
        self.f = None
        self.f_grad = None
        self.x0 = None
        self.tol = None
        self.max_iter = None

        super(MyOptimizer, self).__init__()

    def init(self, f, f_grad, x0, tol=None, max_iter=None):  # Called by ceml: objective f, its gradient f_grad, starting point x0
        self.f = f
        self.f_grad = f_grad
        self.x0 = x0
        self.tol = tol
        self.max_iter = max_iter

    def is_grad_based(self):
        return True   # BFGS is a gradient-based method

    def __call__(self):
        optimum = minimize(fun=self.f, x0=self.x0, jac=self.f_grad, tol=self.tol, options={'maxiter': self.max_iter}, method="BFGS")
        return np.array(optimum["x"])


if __name__ == "__main__":
    # Load data
    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

    # Create and fit the model
    model = LogisticRegression(solver='lbfgs', multi_class='multinomial')   # Note that ceml requires multi_class='multinomial'
    model.fit(X_train, y_train)

    # Select a data point whose prediction is to be explained
    x = X_test[1,:]
    print("Prediction on x: {0}".format(model.predict([x])))

    # Compute a counterfactual by using our custom optimizer 'MyOptimizer'
    print("\nCompute counterfactual ....")
    print(generate_counterfactual(model, x, y_target=0, optimizer=MyOptimizer(), features_whitelist=None, regularization="l1", C=0.5))
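
The same interface can host gradient-free methods as well. Below is a sketch of a Nelder-Mead based optimizer - it assumes that ceml does not require gradients (and that f_grad may be None) whenever is_grad_based() returns False:

# Sketch of a gradient-free optimizer based on scipy's Nelder-Mead method.
class MyGradientFreeOptimizer(Optimizer):
    def __init__(self):
        self.f = None
        self.x0 = None
        self.max_iter = None

        super(MyGradientFreeOptimizer, self).__init__()

    def init(self, f, f_grad, x0, tol=None, max_iter=None):
        self.f = f    # f_grad is ignored - Nelder-Mead never uses gradients (assumption: it may even be None)
        self.x0 = x0
        self.max_iter = max_iter

    def is_grad_based(self):
        return False

    def __call__(self):
        optimum = minimize(fun=self.f, x0=self.x0, options={'maxiter': self.max_iter}, method="Nelder-Mead")
        return np.array(optimum["x"])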