ceml.backend.jax

ceml.backend.jax.costfunctions

class ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax(input_to_output=None, **kwds)

Bases: ceml.costfunctions.costfunctions.CostFunctionDifferentiable

Base class of differentiable cost functions implemented in JAX.

grad(mask=None)

Computes the gradient with respect to the input.

Parameters

mask (numpy.array, optional) –

A mask that is multiplied element-wise with the gradient. It can be used to mask out some features/dimensions.

If mask is None, the gradient is not masked.

The default is None.

Returns

A function that computes the (optionally masked) gradient.

Return type

callable
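A minimal sketch of how a masked gradient function can be built with `jax.grad`, assuming the cost is a plain function of the input `x`. The helper `masked_grad` and the example cost are hypothetical illustrations, not ceml's internal implementation.

```python
import jax
import jax.numpy as jnp

def masked_grad(cost, mask=None):
    """Return a callable computing the (optionally masked) gradient of `cost`."""
    grad_fn = jax.grad(cost)
    if mask is None:
        return grad_fn
    # Multiply the gradient element-wise with the mask to freeze some features.
    return lambda x: grad_fn(x) * mask

x_orig = jnp.array([1.0, 2.0, 3.0])
cost = lambda x: jnp.sum((x - x_orig) ** 2)  # hypothetical example cost

g_full = masked_grad(cost)
g_masked = masked_grad(cost, mask=jnp.array([1.0, 0.0, 1.0]))
full = g_full(jnp.zeros(3))      # 2 * (x - x_orig) at x = 0
masked = g_masked(jnp.zeros(3))  # same, but the second feature is zeroed out
```

Returning a callable rather than a value lets an optimizer evaluate the gradient repeatedly at different points.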

class ceml.backend.jax.costfunctions.costfunctions.DummyCost(**kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

Dummy cost function that always returns zero.

score_impl(x)

Computes the loss, which is always zero.

class ceml.backend.jax.costfunctions.costfunctions.L1Cost(x_orig, input_to_output=None, **kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

L1 cost function.

score_impl(x)

Computes the loss: the L1 norm of the difference between x and x_orig.
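A minimal sketch of the L1 cost as a plain JAX function, assuming the loss is the L1 norm of `x - x_orig`. The name `l1_cost` is hypothetical. Away from `x == x_orig`, `jax.grad` yields the sign of each feature difference.

```python
import jax
import jax.numpy as jnp

def l1_cost(x, x_orig):
    # Sum of absolute feature-wise differences to the original input.
    return jnp.sum(jnp.abs(x - x_orig))

x_orig = jnp.array([1.0, -2.0, 0.5])
x = jnp.array([0.0, 1.0, 1.5])
value = l1_cost(x, x_orig)           # |-1| + |3| + |1|
grad = jax.grad(l1_cost)(x, x_orig)  # sign(x - x_orig) element-wise
```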

class ceml.backend.jax.costfunctions.costfunctions.L2Cost(x_orig, input_to_output=None, **kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

L2 cost function.

score_impl(x)

Computes the loss: the L2 norm of the difference between x and x_orig.
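A minimal sketch of an L2 cost in JAX. It assumes the squared Euclidean distance to `x_orig`, which is differentiable everywhere, whereas the plain norm has an undefined gradient at `x == x_orig`. Whether ceml squares the norm is not stated here. Both variants share the same minimizer. The name `l2_cost` is hypothetical.

```python
import jax
import jax.numpy as jnp

def l2_cost(x, x_orig):
    # Squared Euclidean distance to the original input (assumption: squared).
    return jnp.sum(jnp.square(x - x_orig))

x_orig = jnp.array([1.0, 2.0])
x = jnp.array([4.0, 6.0])
value = l2_cost(x, x_orig)           # 3**2 + 4**2
grad = jax.grad(l2_cost)(x, x_orig)  # 2 * (x - x_orig)
```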

class ceml.backend.jax.costfunctions.costfunctions.LMadCost(x_orig, mad, input_to_output=None, **kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

Manhattan distance weighted feature-wise with the inverse median absolute deviation (MAD).

score_impl(x)

Computes the loss.
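A minimal sketch of the MAD-weighted Manhattan distance, plus one way to estimate the per-feature MAD from data. The function name `lmad_cost` and the sample data are hypothetical. A feature with zero MAD would cause a division by zero, so such features may need special handling.

```python
import jax.numpy as jnp

def lmad_cost(x, x_orig, mad):
    # Manhattan distance in which feature j is weighted by 1 / mad[j].
    return jnp.sum(jnp.abs(x - x_orig) / mad)

# Median absolute deviation per feature (column) of a hypothetical data set.
X = jnp.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0], [5.0, 10.0]])
mad_est = jnp.median(jnp.abs(X - jnp.median(X, axis=0)), axis=0)

value = lmad_cost(jnp.array([1.0, 1.0]), jnp.array([0.0, 0.0]), jnp.array([2.0, 0.5]))
```

Weighting by the inverse MAD makes a unit change in a feature with a small spread more expensive than the same change in a widely spread feature.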

class ceml.backend.jax.costfunctions.costfunctions.MinOfListDistCost(dist, samples, input_to_output=None, **kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

Minimum distance to a list of data points.

score_impl(x)

Computes the loss.
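A minimal sketch of the minimum-distance cost, assuming `dist` is a callable `dist(x, sample)` returning a scalar. The function name and the squared-Euclidean example distance are hypothetical. The gradient of `jnp.min` flows through the closest sample (apart from ties).

```python
import jax.numpy as jnp

def min_of_list_dist_cost(x, dist, samples):
    # Evaluate the distance to every sample and keep the smallest one.
    return jnp.min(jnp.array([dist(x, s) for s in samples]))

sq_dist = lambda a, b: jnp.sum(jnp.square(a - b))  # hypothetical distance
samples = [jnp.array([0.0, 0.0]), jnp.array([3.0, 3.0])]
value = min_of_list_dist_cost(jnp.array([2.0, 2.0]), sq_dist, samples)  # min(8, 2)
```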

class ceml.backend.jax.costfunctions.costfunctions.MinOfListDistExCost(omegas, samples, input_to_output=None, **kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

Minimum distance to a list of data points.

In contrast to MinOfListDistCost, MinOfListDistExCost uses a user-defined metric matrix (a distortion of the Euclidean distance).

score_impl(x)

Computes the loss.
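A minimal sketch of a minimum distance under metric matrices. Because the constructor argument is called `omegas`, the sketch assumes one matrix per sample and the distance (x - s)^T Omega (x - s). Both of these are assumptions, and the function name is hypothetical. With Omega equal to the identity, this reduces to the squared Euclidean distance.

```python
import jax.numpy as jnp

def min_of_list_dist_ex_cost(x, omegas, samples):
    # Distance to sample s: (x - s)^T Omega (x - s), one Omega per sample (assumed).
    dists = [jnp.dot(x - s, omega @ (x - s)) for omega, s in zip(omegas, samples)]
    return jnp.min(jnp.array(dists))

samples = [jnp.array([0.0, 0.0]), jnp.array([2.0, 0.0])]
omegas = [jnp.eye(2), jnp.diag(jnp.array([10.0, 1.0]))]
value = min_of_list_dist_ex_cost(jnp.array([1.0, 0.0]), omegas, samples)
```

Under the plain Euclidean distance, both samples are equally far from `[1, 0]`. The stretched metric of the second sample makes the first one the nearest.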

class ceml.backend.jax.costfunctions.costfunctions.NegLogLikelihoodCost(input_to_output, y_target, **kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

Negative log-likelihood cost function.

score_impl(y)

Computes the loss: the negative log-likelihood.
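A minimal sketch of the negative log-likelihood of a predicted probability vector `y`. It assumes `y_target` is one-hot encoded, which is not confirmed by this reference. The function name is hypothetical. A probability of exactly zero would produce `nan` here, so real code may clip `y` away from zero.

```python
import jax.numpy as jnp

def neg_log_likelihood(y, y_target):
    # y: predicted class probabilities; y_target: one-hot target (assumed encoding).
    return -jnp.sum(y_target * jnp.log(y))

y = jnp.array([0.7, 0.2, 0.1])
y_target = jnp.array([0.0, 1.0, 0.0])
value = neg_log_likelihood(y, y_target)  # -log(0.2)
```

Since score_impl receives the output `y`, the full cost of an input is presumably this loss composed with `input_to_output`.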

class ceml.backend.jax.costfunctions.costfunctions.RegularizedCost(penalize_input, penalize_output, C=1.0, **kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

Regularized cost function.

score_impl(x)

Computes the loss.
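A minimal sketch of a regularized cost. It assumes the loss is `C * penalize_input(x) + penalize_output(x)`, which is not confirmed by this reference. The linear model and both penalties are hypothetical: one keeps `x` close to `x_orig`, and the other pushes the model output toward 1.0.

```python
import jax
import jax.numpy as jnp

def regularized_cost(x, penalize_input, penalize_output, C=1.0):
    # Assumption: the input penalty is weighted by C and added to the output penalty.
    return C * penalize_input(x) + penalize_output(x)

x_orig = jnp.array([0.0, 0.0])
w = jnp.array([1.0, 1.0])  # hypothetical linear model f(x) = w . x
penalize_input = lambda x: jnp.sum(jnp.abs(x - x_orig))  # stay close to x_orig
penalize_output = lambda x: (jnp.dot(w, x) - 1.0) ** 2   # reach the output 1.0

x = jnp.array([0.5, 0.5])
value = regularized_cost(x, penalize_input, penalize_output, C=0.1)
grad = jax.grad(lambda z: regularized_cost(z, penalize_input, penalize_output, C=0.1))(x)
```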

class ceml.backend.jax.costfunctions.costfunctions.SquaredError(input_to_output, y_target, **kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

Squared error cost function.

score_impl(y)

Computes the loss: the squared error.
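A minimal sketch of the squared error, composed with a hypothetical model so that the gradient is taken with respect to the input. Summing over output dimensions is an assumption; a mean would only rescale the loss.

```python
import jax
import jax.numpy as jnp

def squared_error(y, y_target):
    # Sum of squared differences between prediction and target.
    return jnp.sum(jnp.square(y - y_target))

model = lambda x: 2.0 * x  # hypothetical input_to_output mapping
y_target = jnp.array([2.0])
cost = lambda x: squared_error(model(x), y_target)

x = jnp.array([0.0])
value = cost(x)          # (0 - 2)**2
grad = jax.grad(cost)(x)  # d/dx (2x - 2)**2 = 8x - 8
```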

class ceml.backend.jax.costfunctions.costfunctions.TopKMinOfListDistCost(dist, samples, k, input_to_output=None, **kwds)

Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax

Computes the sum of the distances to the k closest samples.

score_impl(x)

Computes the loss.
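A minimal sketch of summing the distances to the k closest samples, assuming `dist(x, sample)` returns a scalar. The function name and the example distance are hypothetical. With `k=1`, this reduces to the minimum distance of MinOfListDistCost.

```python
import jax.numpy as jnp

def top_k_min_of_list_dist_cost(x, dist, samples, k):
    d = jnp.array([dist(x, s) for s in samples])
    # Sort ascending and sum the k smallest distances.
    return jnp.sum(jnp.sort(d)[:k])

sq_dist = lambda a, b: jnp.sum(jnp.square(a - b))  # hypothetical distance
samples = [jnp.array([0.0]), jnp.array([1.0]), jnp.array([5.0])]
value = top_k_min_of_list_dist_cost(jnp.array([0.5]), sq_dist, samples, k=2)  # 0.25 + 0.25
```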

ceml.backend.jax.preprocessing

class ceml.backend.jax.preprocessing.AffinePreprocessing(A, b, **kwds)

Bases: object

Wrapper for an affine mapping (preprocessing) of the form Ax + b.
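A minimal sketch of the affine map x ↦ Ax + b applied to a single input vector, with hypothetical values for `A` and `b`. It also shows that chaining two affine maps yields another affine map, which is one reason an affine representation of preprocessing steps is convenient. The helper names are hypothetical.

```python
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 0.5]])
b = jnp.array([1.0, -1.0])

def affine(x):
    # x -> A x + b
    return A @ x + b

def compose(A1, b1, A2, b2):
    # Applying (A1, b1) and then (A2, b2) is again affine: A2 A1 x + (A2 b1 + b2).
    return A2 @ A1, A2 @ b1 + b2

x = jnp.array([1.0, 4.0])
y = affine(x)
A_c, b_c = compose(A, b, A, b)
```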

class ceml.backend.jax.preprocessing.MinMaxScaler(min_, scale, **kwds)

Bases: ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing

Wrapper for the min max scaler.

predict(x)

Computes the forward pass.
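A minimal sketch of a min-max scaling forward pass. It assumes `min_` and `scale` correspond to scikit-learn's fitted `MinMaxScaler.min_` and `scale_` attributes, for which the transform is `x * scale_ + min_`. The data bounds below are hypothetical.

```python
import jax.numpy as jnp

# Parameters as scikit-learn's MinMaxScaler would fit them for the feature
# range [0, 1]: scale_ = 1 / (max - min), min_ = -min * scale_.
data_min = jnp.array([0.0, 10.0])
data_max = jnp.array([4.0, 20.0])
scale = 1.0 / (data_max - data_min)
min_ = -data_min * scale

def min_max_forward(x):
    # Affine map: element-wise scaling followed by a shift.
    return x * scale + min_

y = min_max_forward(jnp.array([2.0, 20.0]))
```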

class ceml.backend.jax.preprocessing.Model(**kwds)

Bases: abc.ABC

Base class of a model.

Note

The class Model cannot be instantiated because it contains an abstract method.

abstract predict(x)

Predict the output of a given input.

Abstract method for computing a prediction.

Note

All derived classes must implement this method.
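The two notes above follow from Python's `abc` machinery. The minimal sketch below uses a hypothetical stand-in class, not ceml's actual `Model`: instantiating the abstract base raises `TypeError`, and a subclass becomes usable once it implements `predict`.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp

class ModelSketch(ABC):
    """Hypothetical stand-in for an abstract model base class."""

    @abstractmethod
    def predict(self, x):
        """Predict the output of a given input."""

class Doubler(ModelSketch):
    def predict(self, x):
        # A concrete model only has to implement predict.
        return 2.0 * x

try:
    ModelSketch()
    instantiable = True
except TypeError:
    instantiable = False  # abstract classes cannot be instantiated

y = Doubler().predict(jnp.array([1.0, 2.0]))
```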

class ceml.backend.jax.preprocessing.Normalizer(**kwds)

Bases: ceml.model.model.Model

Wrapper for the normalizer.

predict(x)

Computes the forward pass.
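A minimal sketch of a normalizer forward pass. It assumes L2 normalization of the input vector, matching scikit-learn's default `norm="l2"`. The constructor above takes no parameters, so the actual norm is not stated here. The mapping is nonlinear and undefined at x = 0, which is consistent with Normalizer not deriving from AffinePreprocessing.

```python
import jax.numpy as jnp

def normalize(x):
    # Assumption: scale the input to unit L2 norm.
    return x / jnp.linalg.norm(x)

y = normalize(jnp.array([3.0, 4.0]))
```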

class ceml.backend.jax.preprocessing.PCA(w, **kwds)

Bases: ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing

Wrapper for PCA (principal component analysis).

predict(x)

Computes the forward pass.
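A minimal sketch of a PCA forward pass as a linear projection. It assumes each row of a hypothetical matrix `w` is one principal component, giving the affine form with A = w and b = 0. The sketch omits any mean-centering, and whether ceml's wrapper applies one is not stated here.

```python
import jax.numpy as jnp

# Hypothetical projection matrix: each row is one principal component.
w = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 0.0]])

def pca_forward(x):
    # Project x onto the components.
    return w @ x

z = pca_forward(jnp.array([5.0, -2.0, 7.0]))
```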

class ceml.backend.jax.preprocessing.PolynomialFeatures(powers, **kwds)

Bases: ceml.model.model.Model

Wrapper for polynomial feature transformation.

predict(x)

Computes the forward pass.
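A minimal sketch of a polynomial feature transformation. It assumes `powers` has the layout of scikit-learn's `PolynomialFeatures.powers_`, where `powers[i, j]` is the exponent of input feature j in output feature i. The matrix below is the degree-2 layout for two inputs.

```python
import jax.numpy as jnp

# Output features: 1, a, b, a^2, a*b, b^2 for input x = (a, b).
powers = jnp.array([[0, 0], [1, 0], [0, 1], [2, 0], [1, 1], [0, 2]])

def poly_forward(x):
    # Each output feature is the product of the inputs raised to their powers.
    return jnp.prod(jnp.power(x, powers), axis=1)

y = poly_forward(jnp.array([2.0, 3.0]))
```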

class ceml.backend.jax.preprocessing.StandardScaler(mu, sigma, **kwds)

Bases: ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing

Wrapper for the standard scaler.

predict(x)

Computes the forward pass.
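A minimal sketch of standard scaling with hypothetical `mu` and `sigma`. It also shows the equivalent affine form A = diag(1 / sigma), b = -mu / sigma, which matches StandardScaler deriving from AffinePreprocessing.

```python
import jax.numpy as jnp

mu = jnp.array([1.0, 10.0])
sigma = jnp.array([2.0, 5.0])

def standard_forward(x):
    # Center each feature and divide by its standard deviation.
    return (x - mu) / sigma

# Equivalent affine parameters.
A = jnp.diag(1.0 / sigma)
b = -mu / sigma

x = jnp.array([3.0, 0.0])
y = standard_forward(x)
```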

ceml.backend.jax.preprocessing.reduce(function, sequence[, initial]) → value

Apply a function of two arguments cumulatively to the items of a sequence, from left to right, so as to reduce the sequence to a single value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates ((((1+2)+3)+4)+5). If initial is present, it is placed before the items of the sequence in the calculation, and serves as a default when the sequence is empty.