ceml.torch

ceml.torch.counterfactual

class ceml.torch.counterfactual.TorchCounterfactual(model, device=torch.device, **kwds)

Bases: ceml.model.counterfactual.Counterfactual

Class for computing a counterfactual of a PyTorch model.

See parent class ceml.model.counterfactual.Counterfactual.

Parameters
  • model (instance of torch.nn.Module and ceml.model.model.ModelWithLoss) – The PyTorch model that is used for computing counterfactuals. The model has to be wrapped inside a class that is derived from the classes torch.nn.Module and ceml.model.model.ModelWithLoss.

  • device (torch.device) –

    Specifies the hardware device (e.g. CPU or GPU) on which computations are performed.

    The default is torch.device("cpu").

Raises

TypeError – If model is not an instance of both torch.nn.Module and ceml.model.model.ModelWithLoss.
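The type requirement on model can be pictured with stand-in classes. In the sketch below, Module and ModelWithLoss are placeholders for torch.nn.Module and ceml.model.model.ModelWithLoss, and check_model is a hypothetical helper that only mirrors the documented TypeError condition; it is not the library's implementation.

```python
class Module:
    """Stand-in for torch.nn.Module (assumption, keeps the sketch dependency-free)."""


class ModelWithLoss:
    """Stand-in for ceml.model.model.ModelWithLoss (assumption)."""


class WrappedModel(Module, ModelWithLoss):
    """A wrapper deriving from both base classes, as the constructor requires."""


class PlainModel(Module):
    """Derives from only one of the two base classes, so it should be rejected."""


def check_model(model):
    # Hypothetical helper mirroring the documented TypeError condition.
    if not (isinstance(model, Module) and isinstance(model, ModelWithLoss)):
        raise TypeError("model must derive from both Module and ModelWithLoss")
    return model


check_model(WrappedModel())  # accepted
try:
    check_model(PlainModel())
except TypeError as err:
    print("rejected:", err)
```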

compute_counterfactual(x, y_target, features_whitelist=None, regularization=None, C=1.0, optimizer='nelder-mead', optimizer_args=None, return_as_dict=True, done=None)

Computes a counterfactual of a given input x.

Parameters
  • x (numpy.ndarray) – The input x whose prediction has to be explained.

  • y_target (int or float) – The requested prediction of the counterfactual.

  • features_whitelist (list(int), optional) –

    List of feature indices (dimensions of the input space) that can be used when computing the counterfactual.

    If features_whitelist is None, all features can be used.

    The default is None.

  • regularization (str or callable, optional) –

    Regularizer of the counterfactual. Penalty for deviating from the original input x.

    Supported values:

    • l1: Penalizes the absolute deviation.

    • l2: Penalizes the squared deviation.

    You can use your own custom penalty function by setting regularization to a callable that can be called on a potential counterfactual and returns a scalar.

    If regularization is None, no regularization is used.

    The default is None.

  • C (float or list(float), optional) –

    The regularization strength. If C is a list, all values in C are tried and as soon as a counterfactual is found, this counterfactual is returned and no other values of C are tried.

    C is ignored if no regularization is used (regularization=None).

    The default is 1.0.

  • optimizer (str or class that is derived from torch.optim.Optimizer, optional) –

    Name/Identifier of the optimizer that is used for computing the counterfactual. See ceml.optim.optimizer.desc_to_optim() for details.

    Alternatively, any PyTorch optimizer can be used; in that case, optimizer must be a class derived from torch.optim.Optimizer.

    The default is "nelder-mead".

  • optimizer_args (dict, optional) –

    Dictionary containing additional parameters for the optimization algorithm.

    Supported parameters (keys):

    • args: Arguments of the optimization algorithm (e.g. learning rate, momentum, …)

    • lr_scheduler: Learning rate scheduler (see torch.optim.lr_scheduler)

    • lr_scheduler_args: Arguments of the learning rate scheduler

    • tol: Tolerance for termination

    • max_iter: Maximum number of iterations

    If optimizer_args is None or if some parameters are missing, default values are used.

    The default is None.

    Note

    The parameters tol and max_iter are passed to all optimization algorithms, whereas the other parameters are passed only to PyTorch optimizers.

  • return_as_dict (boolean, optional) –

    If True, the counterfactual, its prediction and the required changes to the input are returned as a dictionary. If False, the results are returned as a triple.

    The default is True.

  • done (callable, optional) –

    A callable that returns True if a counterfactual with a given output/prediction is accepted and False otherwise.

    If done is None, the output/prediction of the counterfactual must match y_target exactly.

    The default is None.

    Note

    In case of a regression it might not always be possible to achieve a given output/prediction exactly.
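For regression targets, a tolerance-based acceptance test is a natural choice for done. The sketch below assumes that the callable receives the prediction of a candidate counterfactual as a single number; make_done and the tolerance of 0.5 are illustrative names and values, not part of the library.

```python
def make_done(y_target, tol):
    """Build an acceptance test: the prediction must lie within tol of y_target."""

    def done(y_pred):
        return abs(y_pred - y_target) <= tol

    return done


done = make_done(y_target=25.0, tol=0.5)
print(done(25.25))  # close enough to the target, accepted
print(done(26.0))   # too far from the target, rejected
```

A callable built this way would be passed as done=done in place of the default exact-match check.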

Returns

A dictionary where the counterfactual is stored in 'x_cf', its prediction in 'y_cf' and the changes to the original input in 'delta'.

If return_as_dict is False, a triple (x_cf, y_cf, delta) is returned instead.

Return type

dict or triple

Raises

Exception – If no counterfactual was found.
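The interplay between a list-valued C and the exception raised on failure can be sketched as a simple loop. This is a minimal illustration under stated assumptions: first_successful_c and toy_solve are hypothetical names, and the convention that a failed run returns None is chosen for the sketch only.

```python
def first_successful_c(C, solve):
    """Try each regularization strength in order and stop at the first success."""
    candidates = C if isinstance(C, list) else [C]
    for c in candidates:
        result = solve(c)  # hypothetical solver: returns None when it fails
        if result is not None:
            return c, result
    # Mirrors the documented behaviour of raising when nothing was found.
    raise Exception("No counterfactual found")


tried = []


def toy_solve(c):
    """Toy stand-in for one optimization run; succeeds only for small c."""
    tried.append(c)
    return {"x_cf": [0.0]} if c <= 0.1 else None


c_used, result = first_successful_c([1.0, 0.1, 0.01], toy_solve)
print(c_used, tried)  # the last value, 0.01, is never tried
```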

ceml.torch.counterfactual.generate_counterfactual(model, x, y_target, device=torch.device, features_whitelist=None, regularization=None, C=1.0, optimizer='nelder-mead', optimizer_args=None, return_as_dict=True, done=None)

Computes a counterfactual of a given input x.

Parameters
  • model (instance of torch.nn.Module and ceml.model.model.ModelWithLoss) – The PyTorch model that is used for computing the counterfactual.

  • x (numpy.ndarray) – The input x whose prediction has to be explained.

  • y_target (int or float) – The requested prediction of the counterfactual.

  • device (torch.device) –

    Specifies the hardware device (e.g. CPU or GPU) on which computations are performed.

    The default is torch.device("cpu").

  • features_whitelist (list(int), optional) –

    List of feature indices (dimensions of the input space) that can be used when computing the counterfactual.

    If features_whitelist is None, all features can be used.

    The default is None.

  • regularization (str or callable, optional) –

    Regularizer of the counterfactual. Penalty for deviating from the original input x.

    Supported values:

    • l1: Penalizes the absolute deviation.

    • l2: Penalizes the squared deviation.

    You can use your own custom penalty function by setting regularization to a callable that can be called on a potential counterfactual and returns a scalar.

    If regularization is None, no regularization is used.

    The default is None.

  • C (float or list(float), optional) –

    The regularization strength. If C is a list, all values in C are tried and as soon as a counterfactual is found, this counterfactual is returned and no other values of C are tried.

    If no regularization is used (regularization=None), C is ignored.

    The default is 1.0.

  • optimizer (str or class that is derived from torch.optim.Optimizer, optional) –

    Name/Identifier of the optimizer that is used for computing the counterfactual. See ceml.optim.optimizer.desc_to_optim() for details.

    Alternatively, any PyTorch optimizer can be used; in that case, optimizer must be a class derived from torch.optim.Optimizer.

    The default is "nelder-mead".

  • optimizer_args (dict, optional) –

    Dictionary containing additional parameters for the optimization algorithm.

    Supported parameters (keys):

    • args: Arguments of the optimization algorithm (e.g. learning rate, momentum, …)

    • lr_scheduler: Learning rate scheduler (see torch.optim.lr_scheduler)

    • lr_scheduler_args: Arguments of the learning rate scheduler

    • tol: Tolerance for termination

    • max_iter: Maximum number of iterations

    If optimizer_args is None or if some parameters are missing, default values are used.

    The default is None.

    Note

    The parameters tol and max_iter are passed to all optimization algorithms, whereas the other parameters are passed only to PyTorch optimizers.

  • return_as_dict (boolean, optional) –

    If True, the counterfactual, its prediction and the required changes to the input are returned as a dictionary. If False, the results are returned as a triple.

    The default is True.

  • done (callable, optional) –

    A callable that returns True if a counterfactual with a given output/prediction is accepted and False otherwise.

    If done is None, the output/prediction of the counterfactual must match y_target exactly.

    The default is None.

    Note

    In case of a regression it might not always be possible to achieve a given output/prediction exactly.
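The fallback to default values for missing optimizer_args keys can be pictured as a dictionary merge. The sketch below is an assumption about the general mechanism only: resolve_optimizer_args is a hypothetical helper, and the values in ASSUMED_DEFAULTS are placeholders, since the actual defaults are not stated here.

```python
# Placeholder defaults for illustration only; the library's real defaults may differ.
ASSUMED_DEFAULTS = {
    "args": {},
    "lr_scheduler": None,
    "lr_scheduler_args": {},
    "tol": 1e-3,
    "max_iter": 1000,
}


def resolve_optimizer_args(optimizer_args=None):
    """Fill in missing keys with the assumed defaults."""
    resolved = dict(ASSUMED_DEFAULTS)
    resolved.update(optimizer_args or {})
    return resolved


# Only some keys are given; the rest fall back to the assumed defaults.
resolved = resolve_optimizer_args({"args": {"lr": 0.01}, "max_iter": 200})
print(resolved["tol"], resolved["max_iter"], resolved["args"])
```

According to the note above, only tol and max_iter would reach a non-PyTorch optimizer such as "nelder-mead"; the args and lr_scheduler entries matter only when a torch.optim.Optimizer subclass is used.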

Returns

A dictionary where the counterfactual is stored in 'x_cf', its prediction in 'y_cf' and the changes to the original input in 'delta'.

If return_as_dict is False, a triple (x_cf, y_cf, delta) is returned instead.

Return type

dict or triple
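The two return shapes can be consumed as follows. pack_result is a hypothetical helper that only mimics the documented shapes, and the numbers are placeholders rather than real output.

```python
def pack_result(x_cf, y_cf, delta, return_as_dict=True):
    """Hypothetical helper that mimics the two documented return shapes."""
    if return_as_dict:
        return {"x_cf": x_cf, "y_cf": y_cf, "delta": delta}
    return x_cf, y_cf, delta


# Placeholder values, not real output.
as_dict = pack_result([5.0, 3.0], 1, [0.5, 0.0])
print(as_dict["x_cf"], as_dict["y_cf"], as_dict["delta"])

# With return_as_dict=False the same values are unpacked positionally.
x_cf, y_cf, delta = pack_result([5.0, 3.0], 1, [0.5, 0.0], return_as_dict=False)
print(x_cf, y_cf, delta)
```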

ceml.torch.utils

ceml.torch.utils.build_regularization_loss(regularization, x, input_wrapper=None)

Builds a regularization loss.

Parameters
  • regularization (str, callable or None) –

    A description of the regularization, a callable regularization, or None if no regularization is desired. For a custom regularization, it is recommended (but not mandatory) to implement it as a class derived from ceml.costfunctions.costfunctions.CostFunction, or from ceml.costfunctions.costfunctions.DifferentiableCostFunction if the cost function is differentiable.

    See ceml.torch.utils.desc_to_regcost() for a list of supported descriptions.

    If no regularization is requested, an instance of ceml.backend.torch.costfunctions.costfunctions.DummyCost is returned. This cost function always outputs zero, no matter what the input is.

  • x (numpy.ndarray) – The original input from which we do not want to deviate much.

  • input_wrapper (callable, optional) –

    Converts the input (e.g. if we want to exclude some features/dimensions, we might have to include these missing features before applying any function to it).

    If input_wrapper is None, the input is passed through without any modification.

    The default is None.

Returns

An instance of ceml.costfunctions.costfunctions.CostFunction or the user-defined callable regularization.

Return type

callable

Raises

TypeError – If regularization has an invalid type.
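The dispatch described above (description string, callable or None) can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: build_regularization_sketch, DummyCost and reinsert_frozen are stand-in names, plain lists replace tensors, and applying input_wrapper before the penalty is an assumption about how the wrapper is meant to be used.

```python
class DummyCost:
    """Stand-in for the documented dummy cost: always returns zero."""

    def __call__(self, x):
        return 0.0


def build_regularization_sketch(regularization, x_orig, input_wrapper=None):
    """Dispatch on the type of `regularization` (illustrative only)."""
    wrap = input_wrapper if input_wrapper is not None else (lambda z: z)
    if regularization is None:
        return DummyCost()
    if isinstance(regularization, str):
        if regularization == "l1":  # absolute deviation from x_orig
            return lambda z: sum(abs(a - b) for a, b in zip(x_orig, wrap(z)))
        if regularization == "l2":  # squared deviation from x_orig
            return lambda z: sum((a - b) ** 2 for a, b in zip(x_orig, wrap(z)))
        raise ValueError(f"Unknown regularization: {regularization!r}")
    if callable(regularization):
        return regularization  # a user-defined penalty is returned unchanged
    raise TypeError("regularization must be a str, a callable or None")


x_orig = [1.0, 2.0, 3.0]


def reinsert_frozen(z):
    """Feature 1 is frozen at its original value; z holds features 0 and 2."""
    return [z[0], x_orig[1], z[1]]


l1 = build_regularization_sketch("l1", x_orig, input_wrapper=reinsert_frozen)
l2 = build_regularization_sketch("l2", x_orig, input_wrapper=reinsert_frozen)
print(l1([2.0, 5.0]), l2([2.0, 5.0]))  # 3.0 5.0
```

Here the reduced candidate [2.0, 5.0] is expanded to [2.0, 2.0, 5.0] before the penalty is computed, so the frozen feature contributes nothing to either penalty.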

ceml.torch.utils.desc_to_dist(desc)

Converts a description of a distance metric into a torch function.

Supported descriptions:

  • l1: l1-norm

  • l2: l2-norm

Parameters

desc (str) – Description of the distance metric.

Returns

The distance function implemented as a torch function.

Return type

callable

Raises

ValueError – If desc contains an invalid description.

ceml.torch.utils.desc_to_regcost(desc, x, input_wrapper)

Converts a description of a regularization into a torch function.

Supported descriptions:

  • l1: l1-regularization

  • l2: l2-regularization

Parameters
  • desc (str) – Description of the regularization.

  • x (numpy.ndarray) – The original input from which we do not want to deviate much.

  • input_wrapper (callable) –

    Converts the input (e.g. if we want to exclude some features/dimensions, we might have to include these missing features before applying any function to it).

    This parameter is currently ignored.

Returns

The regularization function implemented as a torch function.

Return type

callable

Raises

ValueError – If desc contains an invalid description.