ceml.backend.jax
ceml.backend.jax.costfunctions
class ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax(input_to_output=None, **kwds)
Bases: ceml.costfunctions.costfunctions.CostFunctionDifferentiable
Base class of differentiable cost functions implemented in JAX.
grad(mask=None)
Computes the gradient with respect to the input.
Parameters
mask (numpy.array, optional): A mask that is multiplied elementwise with the gradient; it can be used to mask out some features/dimensions. If mask is None, the gradient is not masked. The default is None.
Returns
The gradient.
Return type
callable
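
The masking described for grad can be sketched in plain JAX as follows. This is a minimal illustration, not ceml's implementation: the helpers `squared_distance` and `masked_grad` are hypothetical names, and the sketch assumes the returned callable simply multiplies the gradient elementwise by the mask.

```python
import jax
import jax.numpy as jnp


def squared_distance(x, x_orig):
    # Illustrative differentiable cost: squared Euclidean distance to x_orig.
    return jnp.sum((x - x_orig) ** 2)


def masked_grad(cost, mask=None):
    # Hypothetical helper: returns a callable computing d cost / d x.
    grad_fn = jax.grad(cost)
    if mask is None:
        return grad_fn  # no masking
    # Multiply the gradient elementwise by the mask.
    return lambda x, *args: mask * grad_fn(x, *args)


x_orig = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([2.0, 4.0, 5.0])
mask = jnp.array([1.0, 0.0, 1.0])  # zero out the second feature

full = masked_grad(squared_distance)(x, x_orig)          # 2 * (x - x_orig) = [2, 4, 4]
masked = masked_grad(squared_distance, mask)(x, x_orig)  # [2, 0, 4]
```

A mask entry of zero removes that feature's gradient component, so a gradient-based optimizer leaves the feature unchanged.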

class ceml.backend.jax.costfunctions.costfunctions.DummyCost(**kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Dummy cost function that always returns zero.
score_impl(x)
Computes the loss, which is always zero.

class ceml.backend.jax.costfunctions.costfunctions.L1Cost(x_orig, input_to_output=None, **kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
L1 cost function.
score_impl(x)
Computes the loss: the L1 norm.

class ceml.backend.jax.costfunctions.costfunctions.L2Cost(x_orig, input_to_output=None, **kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
L2 cost function.
score_impl(x)
Computes the loss: the L2 norm.
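
As a rough illustration of the L1 and L2 costs, the sketch below computes both norms of the difference between x and x_orig. It assumes the loss is the norm of `x - x_orig`; this page does not state whether ceml squares the L2 distance, so treat the function bodies as assumptions rather than the library's code.

```python
import jax.numpy as jnp


def l1_cost(x, x_orig):
    # Sum of absolute coordinate differences.
    return jnp.sum(jnp.abs(x - x_orig))


def l2_cost(x, x_orig):
    # Euclidean length of the difference vector (not squared; an assumption).
    return jnp.sqrt(jnp.sum((x - x_orig) ** 2))


x_orig = jnp.array([0.0, 0.0])
x = jnp.array([3.0, 4.0])

l1_value = l1_cost(x, x_orig)  # 3 + 4 = 7
l2_value = l2_cost(x, x_orig)  # sqrt(9 + 16) = 5
```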

class ceml.backend.jax.costfunctions.costfunctions.LMadCost(x_orig, mad, input_to_output=None, **kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Manhattan distance weighted feature-wise with the inverse median absolute deviation (MAD).
score_impl(x)
Computes the loss.
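
The MAD-weighted Manhattan distance can be written as the sum over features of |x_i - x_orig_i| / mad_i. The following sketch assumes `mad` is a vector with one strictly positive entry per feature; the function name `lmad_cost` is hypothetical.

```python
import jax.numpy as jnp


def lmad_cost(x, x_orig, mad):
    # Each absolute difference is divided by that feature's MAD before summing.
    return jnp.sum(jnp.abs(x - x_orig) / mad)


x_orig = jnp.array([1.0, 2.0])
mad = jnp.array([0.5, 2.0])
x = jnp.array([2.0, 6.0])

value = lmad_cost(x, x_orig, mad)  # 1 / 0.5 + 4 / 2 = 4
```

Features with a small MAD vary little in the data, so changing them is penalized more heavily.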

class ceml.backend.jax.costfunctions.costfunctions.MinOfListDistCost(dist, samples, input_to_output=None, **kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Minimum distance to a list of data points.
score_impl(x)
Computes the loss.
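
A minimal sketch of "minimum distance to a list of data points", assuming `dist` is a callable taking two vectors and `samples` is a list of vectors. The Euclidean `dist` and the helper name `min_of_list_dist` are illustrative choices, not taken from ceml.

```python
import jax.numpy as jnp


def euclidean(x, s):
    return jnp.sqrt(jnp.sum((x - s) ** 2))


def min_of_list_dist(x, samples, dist):
    # Distance from x to every sample, then the smallest one.
    return jnp.min(jnp.stack([dist(x, s) for s in samples]))


samples = [jnp.array([0.0, 0.0]), jnp.array([3.0, 4.0]), jnp.array([1.0, 1.0])]
x = jnp.array([1.0, 2.0])

closest = min_of_list_dist(x, samples, euclidean)  # nearest sample is [1, 1], distance 1
```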

class ceml.backend.jax.costfunctions.costfunctions.MinOfListDistExCost(omegas, samples, input_to_output=None, **kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Minimum distance to a list of data points.
In contrast to MinOfListDistCost, MinOfListDistExCost uses a user-defined metric matrix (a distortion of the Euclidean distance).
score_impl(x)
Computes the loss.
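
One way to read "user-defined metric matrix" is a quadratic form (x - s)^T Omega (x - s). The sketch below assumes one matrix in `omegas` per entry in `samples`; that pairing, and the helper names, are assumptions and are not confirmed by this page.

```python
import jax.numpy as jnp


def metric_dist(x, s, omega):
    # Quadratic form (x - s)^T omega (x - s); omega = identity gives the
    # squared Euclidean distance.
    d = x - s
    return jnp.dot(d, jnp.dot(omega, d))


def min_of_list_dist_ex(x, samples, omegas):
    # Assumes samples[i] is paired with omegas[i].
    return jnp.min(jnp.stack([metric_dist(x, s, o) for s, o in zip(samples, omegas)]))


samples = [jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])]
omegas = [jnp.eye(2), jnp.diag(jnp.array([1.0, 0.25]))]
x = jnp.array([1.0, 0.0])

closest = min_of_list_dist_ex(x, samples, omegas)  # min(1.0, 0.25) = 0.25
```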

class ceml.backend.jax.costfunctions.costfunctions.NegLogLikelihoodCost(input_to_output, y_target, **kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Negative-log-likelihood cost function.
score_impl(y)
Computes the loss: the negative log-likelihood.
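
A small sketch of a negative log-likelihood loss, assuming `y` is a vector of class probabilities (the output of `input_to_output`) and `y_target` is an integer class index. Both assumptions are illustrative; the page itself does not specify the expected shapes.

```python
import jax.numpy as jnp


def neg_log_likelihood(y, y_target):
    # Negative log of the probability assigned to the target class.
    return -jnp.log(y[y_target])


y = jnp.array([0.1, 0.7, 0.2])   # hypothetical class probabilities
loss = neg_log_likelihood(y, 1)  # -log(0.7)
```

The loss approaches zero as the probability of the target class approaches one.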

class ceml.backend.jax.costfunctions.costfunctions.RegularizedCost(penalize_input, penalize_output, C=1.0, **kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Regularized cost function.
score_impl(x)
Computes the loss.
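
A regularized cost typically combines an input penalty and an output penalty. The sketch below assumes the combination `C * penalize_input(x) + penalize_output(x)`; where `C` is applied is an assumption, and the linear model `w` and both penalty functions are hypothetical.

```python
import jax.numpy as jnp

x_orig = jnp.array([1.0, 1.0])
w = jnp.array([1.0, 2.0])  # hypothetical linear model weights
y_target = 5.0


def penalize_input(x):
    # Distance of the candidate from the original input (L1 here).
    return jnp.sum(jnp.abs(x - x_orig))


def penalize_output(x):
    # Squared deviation of the model output from the requested target.
    return (jnp.dot(w, x) - y_target) ** 2


def regularized_cost(x, C=1.0):
    # Assumed combination: C scales the input penalty.
    return C * penalize_input(x) + penalize_output(x)


x = jnp.array([2.0, 1.0])
total = regularized_cost(x, C=0.5)  # 0.5 * 1 + (4 - 5)^2 = 1.5
```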

class ceml.backend.jax.costfunctions.costfunctions.SquaredError(input_to_output, y_target, **kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Squared error cost function.
score_impl(y)
Computes the loss: the squared error.
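
The sketch below shows how a squared error on the model output can be differentiated with respect to the input. It assumes the cost is evaluated as the loss applied to `input_to_output(x)`; the linear model and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

w = jnp.array([1.0, 2.0])  # hypothetical linear model weights


def input_to_output(x):
    # Maps an input to a model output.
    return jnp.dot(w, x)


def squared_error(y, y_target):
    return jnp.sum((y - y_target) ** 2)


def cost(x, y_target):
    # Assumed composition: the loss is applied to the model output.
    return squared_error(input_to_output(x), y_target)


x = jnp.array([2.0, 1.0])
value = cost(x, 5.0)               # (4 - 5)^2 = 1
gradient = jax.grad(cost)(x, 5.0)  # 2 * (4 - 5) * w = [-2, -4]
```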

class ceml.backend.jax.costfunctions.costfunctions.TopKMinOfListDistCost(dist, samples, k, input_to_output=None, **kwds)
Bases: ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Computes the sum of the distances to the k closest samples.
score_impl(x)
Computes the loss.
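
The "sum of the distances to the k closest samples" can be sketched by sorting the distances and summing the first k. The example fixes the Euclidean distance for brevity, whereas the class takes a `dist` argument; the helper name `top_k_min_dist` is hypothetical.

```python
import jax.numpy as jnp


def top_k_min_dist(x, samples, k):
    # Euclidean distance from x to each row of samples.
    dists = jnp.sqrt(jnp.sum((samples - x) ** 2, axis=1))
    # Sort ascending and sum the k smallest distances.
    return jnp.sum(jnp.sort(dists)[:k])


samples = jnp.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0], [1.0, 3.0]])
x = jnp.array([1.0, 2.0])

score = top_k_min_dist(x, samples, k=2)  # two nearest samples are at distance 1 each
```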

ceml.backend.jax.preprocessing
class ceml.backend.jax.preprocessing.AffinePreprocessing(A, b, **kwds)
Bases: object
Wrapper for an affine mapping (preprocessing).

class ceml.backend.jax.preprocessing.MinMaxScaler(min_, scale, **kwds)
Bases: ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing
Wrapper for the min max scaler.
predict(x)
Computes the forward pass.
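
Since MinMaxScaler derives from AffinePreprocessing, its forward pass can be viewed as an affine map A x + b. The sketch below assumes `min_` and `scale` follow the convention of scikit-learn's `min_` and `scale_` attributes (output = x * scale + min_); this is an assumption about the wrapper, not a statement of its code.

```python
import jax.numpy as jnp

# Hypothetical per-feature data range, mapped to [0, 1].
data_min = jnp.array([0.0, 10.0])
data_max = jnp.array([10.0, 20.0])
scale = 1.0 / (data_max - data_min)
min_ = -data_min * scale


def min_max_scale(x):
    # Assumed forward pass of the scaler.
    return x * scale + min_


# The same transformation written as an affine map A x + b.
A = jnp.diag(scale)
b = min_

x = jnp.array([5.0, 15.0])
scaled = min_max_scale(x)  # both features sit at the middle of their range: [0.5, 0.5]
affine = A @ x + b         # same result via the affine form
```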

class ceml.backend.jax.preprocessing.Model(**kwds)
Bases: abc.ABC
Base class of a model.
Note
The class Model cannot be instantiated because it contains an abstract method.
abstract predict(x)
Predict the output of a given input.
Abstract method for computing a prediction.
Note
All derived classes must implement this method.

class ceml.backend.jax.preprocessing.Normalizer(**kwds)
Bases: ceml.model.model.Model
Wrapper for the normalizer.
predict(x)
Computes the forward pass.
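
A sketch of what a normalizer's forward pass might compute, assuming it rescales each input to unit L2 norm (the default behaviour of scikit-learn's Normalizer). The choice of norm is an assumption; this page does not state it.

```python
import jax.numpy as jnp


def normalize(x):
    # Divide the vector by its Euclidean length (assumes x is not all zeros).
    return x / jnp.sqrt(jnp.sum(x ** 2))


x = jnp.array([3.0, 4.0])
unit = normalize(x)  # [0.6, 0.8]
```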

class ceml.backend.jax.preprocessing.PCA(w, **kwds)
Bases: ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing
Wrapper for PCA (principal component analysis).
predict(x)
Computes the forward pass.
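
A PCA projection is a linear map onto the principal components. The sketch below assumes `w` holds one component per row, with shape (n_components, n_features), and that any mean-centering happens elsewhere, since the constructor takes only `w`. Both points are assumptions.

```python
import jax.numpy as jnp

# Hypothetical projection matrix: 2 components, 3 input features.
w = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 0.6, 0.8]])


def pca_project(x):
    # Project x onto the rows of w.
    return jnp.dot(w, x)


x = jnp.array([2.0, 5.0, 10.0])
projected = pca_project(x)  # [2, 0.6 * 5 + 0.8 * 10] = [2, 11]
```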

class ceml.backend.jax.preprocessing.PolynomialFeatures(powers, **kwds)
Bases: ceml.model.model.Model
Wrapper for polynomial feature transformation.
predict(x)
Computes the forward pass.
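
A polynomial feature transformation can be described by an exponent matrix. The sketch below assumes `powers` has shape (n_output_features, n_input_features), as in scikit-learn's `powers_` attribute, so that output j is the product over i of x_i raised to powers[j, i]. That layout is an assumption.

```python
import jax.numpy as jnp

# Exponents for all monomials of degree <= 2 in two variables:
# 1, x0, x1, x0^2, x0*x1, x1^2
powers = jnp.array([[0, 0], [1, 0], [0, 1], [2, 0], [1, 1], [0, 2]])


def polynomial_features(x):
    # Broadcast x against each row of exponents, then multiply across features.
    return jnp.prod(x ** powers, axis=1)


x = jnp.array([2.0, 3.0])
features = polynomial_features(x)  # [1, 2, 3, 4, 6, 9]
```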

class ceml.backend.jax.preprocessing.StandardScaler(mu, sigma, **kwds)
Bases: ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing
Wrapper for the standard scaler.
predict(x)
Computes the forward pass.
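
Standard scaling subtracts a per-feature mean and divides by a per-feature standard deviation. The sketch assumes `mu` and `sigma` are those two vectors and that the forward pass is (x - mu) / sigma; the function name is hypothetical.

```python
import jax.numpy as jnp

mu = jnp.array([1.0, 2.0])     # hypothetical feature means
sigma = jnp.array([2.0, 4.0])  # hypothetical feature standard deviations


def standard_scale(x):
    # Center each feature, then rescale to unit variance.
    return (x - mu) / sigma


x = jnp.array([3.0, 10.0])
scaled = standard_scale(x)  # [(3 - 1) / 2, (10 - 2) / 4] = [1, 2]
```

Written as an affine map, this corresponds to A = diag(1 / sigma) and b = -mu / sigma.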

ceml.backend.jax.preprocessing.reduce(function, sequence[, initial]) → value
Apply a function of two arguments cumulatively to the items of a sequence, from left to right, so as to reduce the sequence to a single value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates ((((1+2)+3)+4)+5). If initial is present, it is placed before the items of the sequence in the calculation, and serves as a default when the sequence is empty.
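
The description above matches the docstring of Python's functools.reduce, so the sketch below uses functools.reduce directly rather than the name exported here. The final example, chaining preprocessing steps, is only an illustration of how such a helper could be used; the `steps` list is hypothetical.

```python
from functools import reduce
from operator import add

total = reduce(lambda x, y: x + y, [1, 2, 3, 4, 5])  # ((((1+2)+3)+4)+5) = 15
with_initial = reduce(add, [1, 2, 3], 10)            # (((10+1)+2)+3) = 16
empty_default = reduce(add, [], 0)                   # empty sequence: returns initial

# Illustrative use: apply a chain of preprocessing steps in order.
steps = [lambda v: v - 1.0, lambda v: v * 2.0]
piped = reduce(lambda acc, step: step(acc), steps, 3.0)  # (3 - 1) * 2 = 4.0
```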