ceml.backend.jax
ceml.backend.jax.costfunctions
- class ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax(input_to_output=None, **kwds)
Bases:
ceml.costfunctions.costfunctions.CostFunctionDifferentiable
Base class of differentiable cost functions implemented in JAX.
- grad(mask=None)
Computes the gradient with respect to the input.
- Parameters
mask (numpy.array, optional) –
A mask that is multiplied elementwise with the gradient; it can be used to mask out some features/dimensions.
If mask is None, the gradient is not masked.
The default is None.
- Returns
A function that computes the (optionally masked) gradient for a given input.
- Return type
callable
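The masking behaviour of grad can be illustrated with a small sketch. It is plain Python: make_grad and cost are hypothetical names, and central finite differences stand in for the automatic differentiation that the JAX backend presumably uses. As documented, the result is a callable, and a zero in the mask removes the gradient of that feature.

```python
# Hypothetical helper: central finite differences stand in for JAX autodiff
# so that the sketch runs without JAX installed.
def make_grad(cost, mask=None, eps=1e-6):
    """Return a function computing the (optionally masked) gradient of cost."""
    def grad_fn(x):
        g = []
        for i in range(len(x)):
            x_plus, x_minus = list(x), list(x)
            x_plus[i] += eps
            x_minus[i] -= eps
            g.append((cost(x_plus) - cost(x_minus)) / (2 * eps))
        if mask is not None:
            # Elementwise product: a zero entry removes that feature's gradient.
            g = [gi * mi for gi, mi in zip(g, mask)]
        return g
    return grad_fn


def cost(x):
    """Toy cost: sum of squares, whose exact gradient is 2 * x."""
    return sum(xi ** 2 for xi in x)


grad = make_grad(cost, mask=[1.0, 0.0, 1.0])
print(grad([1.0, 2.0, 3.0]))  # approximately [2.0, 0.0, 6.0]
```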
- class ceml.backend.jax.costfunctions.costfunctions.DummyCost(**kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Dummy cost function - always returns zero.
- score_impl(x)
Computes the loss - always returns zero.
- class ceml.backend.jax.costfunctions.costfunctions.L1Cost(x_orig, input_to_output=None, **kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
L1 cost function.
- score_impl(x)
Computes the loss: the l1 norm of the difference between x and x_orig.
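Assuming the l1 norm is taken of the difference between the input x and x_orig, the loss reduces to the following plain-Python sketch (l1_cost is a hypothetical name; the JAX version operates on arrays rather than lists).

```python
def l1_cost(x, x_orig):
    """l1 norm of x - x_orig: sum of absolute feature-wise differences."""
    return sum(abs(a - b) for a, b in zip(x, x_orig))


print(l1_cost([1.0, 2.0, 3.0], [0.0, 4.0, 3.0]))  # 3.0
```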
- class ceml.backend.jax.costfunctions.costfunctions.L2Cost(x_orig, input_to_output=None, **kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
L2 cost function.
- score_impl(x)
Computes the loss: the l2 norm of the difference between x and x_orig.
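A corresponding sketch for the l2 case, assuming the Euclidean norm of x - x_orig (l2_cost is a hypothetical name). If the implementation uses the squared norm instead, which is smooth at x = x_orig, the square root is simply dropped.

```python
import math


def l2_cost(x, x_orig):
    """Euclidean (l2) norm of x - x_orig."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, x_orig)))


print(l2_cost([3.0, 4.0], [0.0, 0.0]))  # 5.0
```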
- class ceml.backend.jax.costfunctions.costfunctions.LMadCost(x_orig, mad, input_to_output=None, **kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Manhattan distance weighted feature-wise with the inverse median absolute deviation (MAD).
- score_impl(x)
Computes the loss.
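The MAD-weighted Manhattan distance divides each absolute feature difference by that feature's median absolute deviation. The sketch below (hypothetical names feature_mad and lmad_cost, plain Python) also shows one way to obtain the mad argument from data; how ceml expects mad to be computed is an assumption here.

```python
import statistics


def feature_mad(data):
    """Per-feature median absolute deviation of a list of samples."""
    mads = []
    for column in zip(*data):
        med = statistics.median(column)
        mads.append(statistics.median(abs(v - med) for v in column))
    return mads


def lmad_cost(x, x_orig, mad):
    """Manhattan distance with feature j weighted by 1 / mad[j]."""
    return sum(abs(a - b) / m for a, b, m in zip(x, x_orig, mad))


data = [[0.0, 0.0], [1.0, 5.0], [2.0, 10.0]]
mad = feature_mad(data)  # [1.0, 5.0]
print(lmad_cost([1.0, 10.0], [0.0, 0.0], mad))  # 3.0
```

The weighting makes a change in a widely spread feature cheaper than the same change in a tightly concentrated one. A feature with zero MAD would cause a division by zero and needs special handling.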
- class ceml.backend.jax.costfunctions.costfunctions.MinOfListDistCost(dist, samples, input_to_output=None, **kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Minimum distance to a list of data points.
- score_impl(x)
Computes the loss.
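A plain-Python sketch of the idea, assuming dist is a callable that takes two points and returns a number (min_of_list_dist_cost and manhattan are hypothetical names):

```python
def manhattan(a, b):
    """Manhattan distance, used here as an example dist callable."""
    return sum(abs(u - v) for u, v in zip(a, b))


def min_of_list_dist_cost(x, samples, dist):
    """Smallest distance from x to any of the given samples."""
    return min(dist(x, s) for s in samples)


samples = [[0.0, 0.0], [5.0, 5.0], [1.0, 2.0]]
print(min_of_list_dist_cost([1.0, 1.0], samples, manhattan))  # 1.0
```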
- class ceml.backend.jax.costfunctions.costfunctions.MinOfListDistExCost(omegas, samples, input_to_output=None, **kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Minimum distance to a list of data points.
In contrast to MinOfListDistCost, MinOfListDistExCost uses a user-defined metric matrix (distortion of the Euclidean distance).
- score_impl(x)
Computes the loss.
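A plain-Python sketch, assuming that omegas holds one square metric matrix per sample and that each matrix enters a quadratic form directly; with the identity matrix the distance reduces to the squared Euclidean distance. The helper names are hypothetical.

```python
def quadratic_dist(x, s, omega):
    """Distorted squared distance (x - s)^T omega (x - s); omega is a nested list."""
    d = [a - b for a, b in zip(x, s)]
    n = len(d)
    return sum(d[i] * omega[i][j] * d[j] for i in range(n) for j in range(n))


def min_of_list_dist_ex_cost(x, samples, omegas):
    """Smallest distorted distance from x to any sample (one omega per sample, assumed)."""
    return min(quadratic_dist(x, s, om) for s, om in zip(samples, omegas))


identity = [[1.0, 0.0], [0.0, 1.0]]
shrink_first = [[0.1, 0.0], [0.0, 1.0]]  # down-weights the first feature
samples = [[0.0, 0.0], [3.0, 0.0]]
print(min_of_list_dist_ex_cost([2.0, 0.0], samples, [identity, shrink_first]))  # approximately 0.1
```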
- class ceml.backend.jax.costfunctions.costfunctions.NegLogLikelihoodCost(input_to_output, y_target, **kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Negative-log-likelihood cost function.
- score_impl(y)
Computes the loss - negative-log-likelihood.
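A minimal sketch, assuming y is the vector of class probabilities produced by input_to_output and y_target is an integer class index (neg_log_likelihood is a hypothetical name). The loss is small when the model assigns high probability to the target class.

```python
import math


def neg_log_likelihood(y, y_target):
    """-log of the probability that y assigns to the target class."""
    return -math.log(y[y_target])


print(neg_log_likelihood([0.2, 0.5, 0.3], 1))  # 0.693... (= log 2)
```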
- class ceml.backend.jax.costfunctions.costfunctions.RegularizedCost(penalize_input, penalize_output, C=1.0, **kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Regularized cost function.
- score_impl(x)
Computes the loss.
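A sketch of how the two penalties might be combined, assuming C weights the input penalty and the two terms are added; model, input_penalty and output_penalty are hypothetical stand-ins for penalize_input and penalize_output.

```python
def regularized_cost(x, penalize_input, penalize_output, C=1.0):
    """Assumed form: C-weighted input penalty plus output penalty."""
    return C * penalize_input(x) + penalize_output(x)


x_orig = [0.0, 0.0]
y_target = 1.0


def model(x):
    """Hypothetical linear model used only for illustration."""
    return x[0] + x[1]


def input_penalty(x):
    # l1 distance to the original input
    return sum(abs(a - b) for a, b in zip(x, x_orig))


def output_penalty(x):
    # Squared error between the model output and the target
    return (model(x) - y_target) ** 2


print(regularized_cost([0.5, 0.5], input_penalty, output_penalty, C=2.0))  # 2.0
```

Here [0.5, 0.5] reaches the target exactly, so only the input penalty remains; increasing C makes such departures from x_orig more expensive.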
- class ceml.backend.jax.costfunctions.costfunctions.SquaredError(input_to_output, y_target, **kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Squared error cost function.
- score_impl(y)
Computes the loss - squared error.
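A minimal sketch, assuming the squared differences are summed over the output dimensions (squared_error is a hypothetical name):

```python
def squared_error(y, y_target):
    """Sum of squared differences between prediction y and y_target."""
    return sum((a - b) ** 2 for a, b in zip(y, y_target))


print(squared_error([1.0, 2.0], [0.0, 4.0]))  # 5.0
```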
- class ceml.backend.jax.costfunctions.costfunctions.TopKMinOfListDistCost(dist, samples, k, input_to_output=None, **kwds)
Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Computes the sum of the distances to the k closest samples.
- score_impl(x)
Computes the loss.
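A plain-Python sketch, assuming dist is a two-argument distance callable as for MinOfListDistCost; heapq.nsmallest picks the k smallest distances. With k = 1 this reduces to the minimum distance. The function names are hypothetical.

```python
import heapq


def manhattan(a, b):
    """Manhattan distance, used here as an example dist callable."""
    return sum(abs(u - v) for u, v in zip(a, b))


def top_k_min_of_list_dist_cost(x, samples, k, dist):
    """Sum of the distances from x to its k closest samples."""
    return sum(heapq.nsmallest(k, (dist(x, s) for s in samples)))


samples = [[0.0, 0.0], [5.0, 5.0], [1.0, 2.0], [2.0, 2.0]]
print(top_k_min_of_list_dist_cost([1.0, 1.0], samples, 2, manhattan))  # 3.0
```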
ceml.backend.jax.preprocessing
- class ceml.backend.jax.preprocessing.AffinePreprocessing(A, b, **kwds)
Bases:
object
Wrapper for an affine mapping (preprocessing).
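An affine mapping has the form A x + b. A plain-Python sketch, assuming A is given row by row and b has one entry per output (affine is a hypothetical name):

```python
def affine(x, A, b):
    """Affine map A @ x + b, with A given as a list of rows."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) + b_i for row, b_i in zip(A, b)]


print(affine([1.0, 2.0], [[2.0, 0.0], [1.0, 1.0]], [0.5, -1.0]))  # [2.5, 2.0]
```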
- class ceml.backend.jax.preprocessing.MinMaxScaler(min_, scale, **kwds)
Bases:
ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing
Wrapper for the min max scaler.
- predict(x)
Computes the forward pass.
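A sketch of the forward pass, assuming the wrapper follows scikit-learn's MinMaxScaler convention, where the fitted min_ and scale_ attributes give the transform x * scale_ + min_ (min_max_scale is a hypothetical name). For a feature with training range [lo, hi] mapped to [0, 1], scale = 1 / (hi - lo) and min_ = -lo * scale.

```python
def min_max_scale(x, min_, scale):
    """Feature-wise x * scale + min_ (scikit-learn convention, assumed)."""
    return [xi * s + m for xi, s, m in zip(x, scale, min_)]


# Feature ranges [2, 6] and [0, 8] mapped to [0, 1]
scale = [0.25, 0.125]
min_ = [-0.5, 0.0]
print(min_max_scale([4.0, 2.0], min_, scale))  # [0.5, 0.25]
```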
- class ceml.backend.jax.preprocessing.Model(**kwds)
Bases:
abc.ABC
Base class of a model.
Note
The class Model cannot be instantiated because it contains an abstract method.
- abstract predict(x)
Predict the output of a given input.
Abstract method for computing a prediction.
Note
All derived classes must implement this method.
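The abstract-method contract can be illustrated with Python's abc module. This is a simplified stand-in, not the ceml class itself (which also accepts **kwds); LinearModel is a hypothetical subclass.

```python
from abc import ABC, abstractmethod


class Model(ABC):
    """Simplified stand-in for a model base class with an abstract predict."""

    @abstractmethod
    def predict(self, x):
        """Predict the output of a given input."""


class LinearModel(Model):
    """Hypothetical subclass used only for illustration."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def predict(self, x):
        return sum(wi * xi for wi, xi in zip(self.w, x)) + self.b


try:
    Model()
    base_instantiable = True
except TypeError:
    # Instantiating a class with an unimplemented abstract method fails.
    base_instantiable = False

print(base_instantiable)                               # False
print(LinearModel([1.0, 2.0], 0.5).predict([1.0, 1.0]))  # 3.5
```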
- class ceml.backend.jax.preprocessing.Normalizer(**kwds)
Bases:
ceml.model.model.Model
Wrapper for the normalizer.
- predict(x)
Computes the forward pass.
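A sketch of a per-sample normalizer, assuming the wrapper mirrors scikit-learn's Normalizer with its default l2 norm; the constructor takes no fitted parameters, which fits such a stateless transform. As in scikit-learn, an all-zero sample is left unchanged here. The name normalize is hypothetical.

```python
import math


def normalize(x):
    """Scale a sample to unit Euclidean (l2) norm; zero vectors are returned unchanged."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm for v in x] if norm > 0 else list(x)


print(normalize([3.0, 4.0]))  # [0.6, 0.8]
```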
- class ceml.backend.jax.preprocessing.PCA(w, **kwds)
Bases:
ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing
Wrapper for PCA (principal component analysis).
- predict(x)
Computes the forward pass.
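A sketch of the projection step, assuming w holds the principal axes as rows (as in scikit-learn's components_) and that the input is already centred; scikit-learn subtracts mean_ before projecting, and whether this wrapper folds that shift into its affine map is not stated here. pca_project is a hypothetical name.

```python
def pca_project(x, w):
    """Project x onto the rows of w (assumed: principal axes as rows)."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


w = [[0.6, 0.8], [-0.8, 0.6]]  # two orthonormal axes
print(pca_project([3.0, 4.0], w))  # approximately [5.0, 0.0]
```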
- class ceml.backend.jax.preprocessing.PolynomialFeatures(powers, **kwds)
Bases:
ceml.model.model.Model
Wrapper for polynomial feature transformation.
- predict(x)
Computes the forward pass.
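A sketch assuming powers follows scikit-learn's powers_ layout: one row per output feature, one exponent per input feature, so output i is the product over j of x_j ** powers[i][j]. The product is formed with functools.reduce, which this module also exposes. polynomial_features is a hypothetical name; the rows below encode a degree-2 expansion with a bias term.

```python
import operator
from functools import reduce


def polynomial_features(x, powers):
    """Each output is the product of the inputs raised to one row of powers."""
    return [reduce(operator.mul, (xj ** p for xj, p in zip(x, row)), 1.0) for row in powers]


powers = [[0, 0], [1, 0], [0, 1], [2, 0], [1, 1], [0, 2]]
print(polynomial_features([2.0, 3.0], powers))  # [1.0, 2.0, 3.0, 4.0, 6.0, 9.0]
```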
- class ceml.backend.jax.preprocessing.StandardScaler(mu, sigma, **kwds)
Bases:
ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing
Wrapper for the standard scaler.
- predict(x)
Computes the forward pass.
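A sketch of standardisation, assuming mu and sigma are the per-feature mean and standard deviation so that each feature becomes (x - mu) / sigma (standard_scale is a hypothetical name). Written as an affine map this is A = diag(1 / sigma) and b = -mu / sigma, consistent with the AffinePreprocessing base class.

```python
def standard_scale(x, mu, sigma):
    """Feature-wise (x - mu) / sigma."""
    return [(xi - m) / s for xi, m, s in zip(x, mu, sigma)]


print(standard_scale([3.0, 10.0], [1.0, 4.0], [2.0, 3.0]))  # [1.0, 2.0]
```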
- ceml.backend.jax.preprocessing.reduce(function, sequence[, initial]) → value
Apply a function of two arguments cumulatively to the items of a sequence, from left to right, so as to reduce the sequence to a single value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates ((((1+2)+3)+4)+5). If initial is present, it is placed before the items of the sequence in the calculation, and serves as a default when the sequence is empty.
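This entry matches the docstring of Python's standard-library functools.reduce. The short example below reproduces the documented example and shows the two roles of initial: it is placed before the first item, and it is returned for an empty sequence.

```python
from functools import reduce

total = reduce(lambda x, y: x + y, [1, 2, 3, 4, 5])   # ((((1+2)+3)+4)+5)
empty_total = reduce(lambda x, y: x + y, [], 0)       # initial is the default for an empty sequence
product = reduce(lambda x, y: x * y, [2, 3, 4], 10)   # ((10*2)*3)*4

try:
    reduce(lambda x, y: x + y, [])
except TypeError:
    print("empty sequence with no initial value raises TypeError")

print(total, empty_total, product)  # 15 0 240
```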