Advanced

CEML can be easily extended, and all major components can be customized to fit the user's needs.

Below is a (non-exhaustive) list of some (common) use cases:

Custom regularization

Instead of using one of the predefined regularizations, we can pass a custom regularization to ceml.sklearn.models.generate_counterfactual().

All regularization implementations must be classes derived from ceml.costfunctions.costfunctions.CostFunction. In the case of scikit-learn, if we want to use a gradient-based optimization algorithm, we must derive from ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax - note that ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax is already derived from ceml.costfunctions.costfunctions.CostFunction.

The computation of the regularization itself must be implemented in the score_impl method.

A complete example of a re-implementation of the l2-regularization is given below:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import jax.numpy as npx
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import GaussianNB

from ceml.sklearn import generate_counterfactual
from ceml.backend.jax.costfunctions import CostFunctionDifferentiableJax


# Custom re-implementation of the (squared) l2-regularization.
# It is written with jax.numpy and is differentiable, so a gradient-based optimizer such as "bfgs" can be used.
class MyRegularization(CostFunctionDifferentiableJax):
    """Squared l2 distance between a candidate counterfactual and the original data point."""

    def __init__(self, x_orig):
        """Remember the original data point `x_orig` that every candidate is compared against."""
        self.x_orig = x_orig
        super().__init__()

    def score_impl(self, x):
        """Return sum((x - x_orig)**2).

        Note: This expression must be written in jax and it must be differentiable!
        """
        deviation = x - self.x_orig
        return npx.sum(npx.square(deviation))


if __name__ == "__main__":
    # Load data
    # `return_X_y` is keyword-only in recent scikit-learn versions, so it must not be passed positionally.
    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

    # Whitelist of features - list of features we can change/use when computing a counterfactual
    features_whitelist = None   # All features can be used.

    # Create and fit the model
    model = GaussianNB()
    model.fit(X_train, y_train)

    # Select data point for explaining its prediction
    x = X_test[1,:]
    print("Prediction on x: {0}".format(model.predict([x])))

    # Create custom regularization function
    regularization = MyRegularization(x)

    # Compute counterfactual - the gradient-based optimizer "bfgs" can be used because MyRegularization is differentiable
    print("\nCompute counterfactual ....")
    print(generate_counterfactual(model, x, y_target=0, features_whitelist=features_whitelist, regularization=regularization, optimizer="bfgs"))

Custom loss function

In order to use a custom loss function we have to do three things:

  1. Implement the loss function. This is the same as implementing a custom regularization - a regularization is a loss function that works on the input rather than on the output.

  2. Derive a child class from the model class and override the get_loss method to use our custom loss function.

  3. Derive a child class from the counterfactual class of the model and override the rebuild_model method to use our model from the previous step.

A complete example of using a custom loss for a linear regression model is given below:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import jax.numpy as npx
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.linear_model import Ridge


from ceml.sklearn import generate_counterfactual
from ceml.sklearn import LinearRegression, LinearRegressionCounterfactual
from ceml.backend.jax.costfunctions import CostFunctionDifferentiableJax


# Custom loss function: the fourth power of the deviation between the model output and the target value.
# Note that this loss is differentiable (it is written with jax.numpy).
class MyLoss(CostFunctionDifferentiableJax):
    """Loss |y - y_target|**4 evaluated on the output of a model.

    Parameters
    ----------
    input_to_output :
        Forwarded to the base class. Presumably a callable mapping an input to the
        model output (the wrapper below passes its own `predict`) - confirm with ceml.
    y_target :
        Requested target value of the model output (e.g. 25.0 in the example below).
    """
    def __init__(self, input_to_output, y_target):
        self.y_target = y_target

        super(MyLoss, self).__init__(input_to_output)

    def score_impl(self, y):
        # Use the target stored on the instance. A bare `y_target` would resolve to a
        # module-level global that only exists when this file is run as a script, and
        # would silently ignore the target passed to the constructor.
        return npx.abs(y - self.y_target)**4

# Derive a new class from ceml.sklearn.linearregression.LinearRegression and override the get_loss method so that our custom loss MyLoss is used
class LinearRegressionWithMyLoss(LinearRegression):
    """LinearRegression wrapper whose loss is MyLoss instead of the default loss."""

    def __init__(self, model):
        super(LinearRegressionWithMyLoss, self).__init__(model)

    def get_loss(self, y_target, pred=None):
        """Return a MyLoss instance for `y_target`.

        When `pred` is None the wrapper's own `predict` is used as the
        input-to-output mapping; otherwise `pred` is used instead.
        """
        input_to_output = self.predict if pred is None else pred
        return MyLoss(input_to_output, y_target)

# Derive a new class from ceml.sklearn.linearregression.LinearRegressionCounterfactual that uses
# our linear regression wrapper LinearRegressionWithMyLoss for computing counterfactuals
class MyLinearRegressionCounterfactual(LinearRegressionCounterfactual):
    """LinearRegressionCounterfactual that wraps the model in LinearRegressionWithMyLoss."""

    def __init__(self, model):
        super().__init__(model)

    def rebuild_model(self, model):
        """Wrap `model` (presumably the scikit-learn model given to the constructor) in LinearRegressionWithMyLoss."""
        wrapped = LinearRegressionWithMyLoss(model)
        return wrapped


if __name__ == "__main__":
    # Load data
    # NOTE: `load_boston` was deprecated in scikit-learn 1.0 and removed in 1.2 - on newer versions,
    # replace it with another regression dataset (and adjust `y_target` below accordingly).
    # `return_X_y` must be passed as a keyword argument.
    X, y = load_boston(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

    # Whitelist of features - list of features we can change/use when computing a counterfactual
    features_whitelist = None   # All features can be used.

    # Create and fit model
    model = Ridge()
    model.fit(X_train, y_train)

    # Select data point for explaining its prediction
    x = X_test[1,:]
    print("Prediction on x: {0}".format(model.predict([x])))

    # Compute counterfactual
    print("\nCompute counterfactual ....")
    y_target = 25.0
    # Since we might not be able to achieve `y_target` exactly, we tell ceml that we are happy
    # if the prediction does not deviate more than 0.5 from it.
    done = lambda z: np.abs(y_target - z) <= 0.5

    # Since we are using our own loss function, we can no longer use the standard function generate_counterfactual
    cf = MyLinearRegressionCounterfactual(model)
    print(cf.compute_counterfactual(x, y_target=y_target, features_whitelist=features_whitelist, regularization="l2", C=1.0, optimizer="bfgs", done=done))

Add a custom optimizer

We can use a custom optimization method by:

  1. Derive a new class from ceml.optim.optimizer.Optimizer and implement the custom optimization method.

  2. Create a new instance of this class and pass it as the argument for the optimizer parameter to the function ceml.sklearn.models.generate_counterfactual() (or any other function that computes a counterfactual).

A complete example of using a custom optimization method for computing counterfactuals from a logistic regression model is given below:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from scipy.optimize import minimize

from ceml.sklearn import generate_counterfactual
from ceml.optim import Optimizer


# Custom optimization method that simply delegates to the BFGS implementation in scipy
class MyOptimizer(Optimizer):
    """Gradient-based optimizer backed by scipy.optimize.minimize(method="BFGS")."""

    def __init__(self):
        # The optimization problem itself is only supplied later, through init().
        self.f, self.f_grad, self.x0 = None, None, None
        self.tol, self.max_iter = None, None

        super(MyOptimizer, self).__init__()

    def init(self, f, f_grad, x0, tol=None, max_iter=None):
        """Store the objective `f`, its gradient `f_grad`, the start point `x0` and the stopping criteria."""
        self.f, self.f_grad, self.x0 = f, f_grad, x0
        self.tol, self.max_iter = tol, max_iter

    def is_grad_based(self):
        """Report that this optimizer uses the gradient `f_grad`."""
        return True

    def __call__(self):
        """Run BFGS starting at `x0` and return the found minimizer as a numpy array."""
        result = minimize(
            fun=self.f,
            x0=self.x0,
            jac=self.f_grad,
            tol=self.tol,
            options={"maxiter": self.max_iter},
            method="BFGS",
        )
        return np.array(result["x"])


if __name__ == "__main__":
    # Load data
    # `return_X_y` is keyword-only in recent scikit-learn versions, so it must not be passed positionally.
    X, y = load_iris(return_X_y=True)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

    # Create and fit model
    model = LogisticRegression(solver='lbfgs', multi_class='multinomial')
    model.fit(X_train, y_train)

    # Select data point for explaining its prediction
    x = X_test[1,:]
    print("Prediction on x: {0}".format(model.predict([x])))

    # Compute counterfactual by using our custom optimizer 'MyOptimizer'
    print("\nCompute counterfactual ....")
    print(generate_counterfactual(model, x, y_target=0, optimizer=MyOptimizer(), features_whitelist=None, regularization="l1", C=0.5))