ceml.backend.jax
ceml.backend.jax.costfunctions
-
class
ceml.backend.jax.costfunctions.costfunctions.
CostFunctionDifferentiableJax
(input_to_output=None, **kwds)¶ Bases:
ceml.costfunctions.costfunctions.CostFunctionDifferentiable
Base class of differentiable cost functions implemented in jax.
-
grad
(mask=None)¶ Computes the gradient with respect to the input.
- Parameters
mask (numpy.array, optional) –
A mask that is multiplied elementwise with the gradient. It can be used to mask out some features/dimensions.
If mask is None, the gradient is not masked.
The default is None.
- Returns
The gradient.
- Return type
callable
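The masking behaviour described above can be illustrated with plain JAX. This is a minimal sketch, not the library's implementation: the helper `masked_grad` and the example cost are hypothetical names introduced only for illustration.

```python
import jax
import jax.numpy as jnp

def masked_grad(cost, mask=None):
    """Return a callable computing the (optionally masked) gradient of `cost`."""
    g = jax.grad(cost)
    if mask is None:
        return g  # no mask: plain gradient
    mask = jnp.asarray(mask)
    return lambda x: mask * g(x)  # elementwise masking of the gradient

# Hypothetical cost: squared Euclidean distance to a fixed point.
x_orig = jnp.array([1.0, 2.0, 3.0])
cost = lambda x: jnp.sum((x - x_orig) ** 2)

x = jnp.array([0.0, 0.0, 0.0])
full = masked_grad(cost)(x)                             # all features contribute
partial = masked_grad(cost, mask=[1.0, 0.0, 1.0])(x)    # second feature is masked
```

A zero entry in the mask removes the corresponding feature from the gradient, so a gradient-based optimizer would leave that feature unchanged.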
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
DummyCost
(**kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Dummy cost function - always returns zero.
-
score_impl
(x)¶ Computes the loss - always returns zero.
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
L1Cost
(x_orig, input_to_output=None, **kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
L1 cost function.
-
score_impl
(x)¶ Computes the loss - l1 norm.
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
L2Cost
(x_orig, input_to_output=None, **kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
L2 cost function.
-
score_impl
(x)¶ Computes the loss - l2 norm.
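As a rough illustration of the two norm-based costs above, the following sketch computes the l1 and l2 distance between an input `x` and a reference point `x_orig`. The function names are hypothetical, and whether the library uses the plain or the squared l2 norm is not stated here, so the plain norm is an assumption.

```python
import jax.numpy as jnp

def l1_cost(x, x_orig):
    # Sum of absolute feature-wise differences.
    return jnp.sum(jnp.abs(x - x_orig))

def l2_cost(x, x_orig):
    # Euclidean distance (assumed to be the plain, unsquared norm).
    return jnp.sqrt(jnp.sum((x - x_orig) ** 2))

x_orig = jnp.array([0.0, 0.0])
x = jnp.array([3.0, -4.0])
l1_value = l1_cost(x, x_orig)
l2_value = l2_cost(x, x_orig)
```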
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
LMadCost
(x_orig, mad, input_to_output=None, **kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Manhattan distance weighted feature-wise with the inverse median absolute deviation (MAD).
-
score_impl
(x)¶ Computes the loss.
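A minimal sketch of a MAD-weighted Manhattan distance, assuming `mad` holds one median absolute deviation per feature. The function name `lmad_cost` is hypothetical.

```python
import jax.numpy as jnp

def lmad_cost(x, x_orig, mad):
    # Each absolute difference is divided by the feature's MAD,
    # so features with a large spread are penalized less.
    return jnp.sum(jnp.abs(x - x_orig) / mad)

x_orig = jnp.array([1.0, 1.0])
mad = jnp.array([0.5, 3.0])
x = jnp.array([2.0, 4.0])
lmad_value = lmad_cost(x, x_orig, mad)  # 1/0.5 + 3/3
```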
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
MinOfListDistCost
(dist, samples, input_to_output=None, **kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Minimum distance to a list of data points.
-
score_impl
(x)¶ Computes the loss.
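The idea of a minimum distance to a list of data points can be sketched as follows, assuming `dist` is a callable taking two points. The helper name and the Euclidean `dist` are illustrative choices, not taken from the library.

```python
import jax.numpy as jnp

def min_of_list_dist(x, samples, dist):
    # Evaluate the distance to every sample and keep the smallest one.
    return jnp.min(jnp.stack([dist(x, s) for s in samples]))

dist = lambda a, b: jnp.linalg.norm(a - b)  # illustrative choice of metric
samples = [jnp.array([0.0, 0.0]), jnp.array([3.0, 4.0]), jnp.array([1.0, 1.0])]
x = jnp.array([1.0, 2.0])
min_dist = min_of_list_dist(x, samples, dist)  # closest sample is [1, 1]
```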
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
MinOfListDistExCost
(omegas, samples, input_to_output=None, **kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Minimum distance to a list of data points.
In contrast to MinOfListDistCost, MinOfListDistExCost uses a user-defined metric matrix (distortion of the Euclidean distance).
-
score_impl
(x)¶ Computes the loss.
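A sketch of a distance distorted by a metric matrix, under the assumption that `omegas` contains one matrix per sample and that the distance is the quadratic form (x - s)ᵀ Ω (x - s). Both the helper name and this exact form are assumptions.

```python
import jax.numpy as jnp

def min_of_list_dist_ex(x, samples, omegas):
    # Quadratic form d^T Omega d for each (sample, matrix) pair.
    dists = [(x - s) @ omega @ (x - s) for s, omega in zip(samples, omegas)]
    return jnp.min(jnp.stack(dists))

samples = [jnp.array([0.0, 0.0]), jnp.array([2.0, 0.0])]
omegas = [jnp.eye(2), jnp.diag(jnp.array([0.25, 1.0]))]
x = jnp.array([1.0, 0.0])
min_dist_ex = min_of_list_dist_ex(x, samples, omegas)
```

With the identity matrix the quadratic form reduces to the squared Euclidean distance. The second matrix shrinks the first axis, which makes the second sample the closer one in this example.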
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
NegLogLikelihoodCost
(input_to_output, y_target, **kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Negative-log-likelihood cost function.
-
score_impl
(y)¶ Computes the loss - negative-log-likelihood.
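A minimal sketch of a negative log-likelihood on a model output `y`, assuming `y` is a vector of class probabilities and `y_target` is a class index. The softmax model `input_to_output` below is hypothetical.

```python
import jax
import jax.numpy as jnp

def neg_log_likelihood(y, y_target):
    # y: class probabilities, y_target: index of the requested class.
    return -jnp.log(y[y_target])

# Hypothetical model mapping an input to class probabilities.
W = jnp.eye(3)
input_to_output = lambda x: jax.nn.softmax(W @ x)

x = jnp.zeros(3)  # all logits equal, so the output is uniform
nll = neg_log_likelihood(input_to_output(x), y_target=1)  # equals log(3)
```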
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
RegularizedCost
(penalize_input, penalize_output, C=1.0, **kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Regularized cost function.
-
score_impl
(x)¶ Computes the loss.
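One plausible reading of the constructor arguments is a weighted sum of an input penalty and an output penalty. The sketch below assumes that `C` scales the input penalty; the library may place the weight differently. All names besides the constructor arguments are hypothetical.

```python
import jax
import jax.numpy as jnp

def regularized_cost(penalize_input, penalize_output, C=1.0):
    # Assumed combination: C * input penalty + output penalty.
    return lambda x: C * penalize_input(x) + penalize_output(x)

x_orig = jnp.array([0.0, 0.0])
model = lambda x: jnp.sum(x)                              # hypothetical model
penalize_input = lambda x: jnp.sum(jnp.abs(x - x_orig))   # stay close to x_orig
penalize_output = lambda x: (model(x) - 3.0) ** 2         # reach the target output 3

cost = regularized_cost(penalize_input, penalize_output, C=0.5)
x = jnp.array([1.0, 1.0])
value = cost(x)
gradient = jax.grad(cost)(x)
```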
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
SquaredError
(input_to_output, y_target, **kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Squared error cost function.
-
score_impl
(y)¶ Computes the loss - squared error.
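The sketch below shows a squared error on the model output and how JAX differentiates it with respect to the input through `input_to_output`. The linear model is hypothetical and only serves the illustration.

```python
import jax
import jax.numpy as jnp

def squared_error(y, y_target):
    return (y - y_target) ** 2

# Hypothetical model: a linear map from the input to a scalar output.
w = jnp.array([2.0, 1.0])
input_to_output = lambda x: jnp.dot(w, x)

y_target = 5.0
cost = lambda x: squared_error(input_to_output(x), y_target)

x = jnp.array([1.0, 1.0])
err = cost(x)                 # (3 - 5)^2
grad_x = jax.grad(cost)(x)    # 2 * (3 - 5) * w
```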
-
-
class
ceml.backend.jax.costfunctions.costfunctions.
TopKMinOfListDistCost
(dist, samples, k, input_to_output=None, **kwds)¶ Bases:
ceml.backend.jax.costfunctions.costfunctions.CostFunctionDifferentiableJax
Computes the sum of the distances to the k closest samples.
-
score_impl
(x)¶ Computes the loss.
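Summing the distances to the k closest samples can be sketched by sorting the distances and keeping the first k. The helper name and the Euclidean `dist` are assumptions made for this example.

```python
import jax.numpy as jnp

def top_k_min_dist(x, samples, dist, k):
    d = jnp.stack([dist(x, s) for s in samples])
    # Sort ascending and sum the k smallest distances.
    return jnp.sum(jnp.sort(d)[:k])

dist = lambda a, b: jnp.linalg.norm(a - b)
samples = [jnp.array([0.0, 0.0]), jnp.array([3.0, 0.0]), jnp.array([1.0, 4.0])]
x = jnp.array([1.0, 0.0])
top2 = top_k_min_dist(x, samples, dist, k=2)  # distances are 1, 2 and 4
```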
-
ceml.backend.jax.preprocessing
-
class
ceml.backend.jax.preprocessing.
AffinePreprocessing
(A, b, **kwds)¶ Bases:
object
Wrapper for an affine mapping (preprocessing).
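An affine mapping has the form x ↦ A x + b. The following sketch shows this with plain JAX arrays; the function name `affine` is hypothetical and the class itself may expose the mapping differently.

```python
import jax.numpy as jnp

def affine(x, A, b):
    # Linear part A @ x followed by the offset b.
    return A @ x + b

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
b = jnp.array([1.0, -1.0])
x = jnp.array([1.0, 1.0])
affine_out = affine(x, A, b)
```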
-
class
ceml.backend.jax.preprocessing.
MinMaxScaler
(min_, scale, **kwds)¶ Bases:
ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing
Wrapper for the min max scaler.
-
predict
(x)¶ Computes the forward pass.
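A sketch of a min-max scaling forward pass, assuming `min_` and `scale` follow the convention of scikit-learn's fitted `MinMaxScaler` attributes (`x * scale + min_`). That this wrapper uses the same convention is an assumption.

```python
import jax.numpy as jnp

def min_max_forward(x, min_, scale):
    # Feature-wise scaling followed by a feature-wise offset.
    return x * scale + min_

min_ = jnp.array([-0.5, 0.0])
scale = jnp.array([0.5, 0.1])
x = jnp.array([2.0, 10.0])
scaled = min_max_forward(x, min_, scale)
```

Under this convention the scaler is an affine mapping with a diagonal matrix, which would be consistent with the class deriving from AffinePreprocessing.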
-
-
class
ceml.backend.jax.preprocessing.
Model
(**kwds)¶ Bases:
abc.ABC
Base class of a model.
Note
The class Model cannot be instantiated because it contains an abstract method.
-
abstract
predict
(x)¶ Predict the output of a given input.
Abstract method for computing a prediction.
Note
All derived classes must implement this method.
-
-
class
ceml.backend.jax.preprocessing.
Normalizer
(**kwds)¶ Bases:
ceml.model.model.Model
Wrapper for the normalizer.
-
predict
(x)¶ Computes the forward pass.
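A minimal sketch of a normalizer forward pass, assuming the input is rescaled to unit l2 norm (the default of scikit-learn's `Normalizer`). The choice of norm is an assumption.

```python
import jax.numpy as jnp

def normalize(x):
    # Rescale the sample so that its Euclidean norm is 1.
    return x / jnp.linalg.norm(x)

x = jnp.array([3.0, 4.0])
unit = normalize(x)
```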
-
-
class
ceml.backend.jax.preprocessing.
PCA
(w, **kwds)¶ Bases:
ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing
Wrapper for PCA - Principal component analysis.
-
predict
(x)¶ Computes the forward pass.
-
-
class
ceml.backend.jax.preprocessing.
PolynomialFeatures
(powers, **kwds)¶ Bases:
ceml.model.model.Model
Wrapper for polynomial feature transformation.
-
predict
(x)¶ Computes the forward pass.
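A sketch of a polynomial feature transformation, assuming `powers` is a matrix with one row per output feature and one column per input feature, as in the `powers_` attribute of scikit-learn's `PolynomialFeatures`. The function name is hypothetical.

```python
import jax.numpy as jnp

def polynomial_features(x, powers):
    # Output feature j is the product over i of x[i] ** powers[j, i].
    return jnp.prod(x ** powers, axis=1)

# Exponents for all monomials of degree <= 2 in two variables.
powers = jnp.array([[0, 0], [1, 0], [0, 1], [2, 0], [1, 1], [0, 2]])
x = jnp.array([2.0, 3.0])
features = polynomial_features(x, powers)  # 1, x0, x1, x0^2, x0*x1, x1^2
```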
-
-
class
ceml.backend.jax.preprocessing.
StandardScaler
(mu, sigma, **kwds)¶ Bases:
ceml.model.model.Model, ceml.backend.jax.preprocessing.affine_preprocessing.AffinePreprocessing
Wrapper for the standard scaler.
-
predict
(x)¶ Computes the forward pass.
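Standard scaling subtracts a mean and divides by a standard deviation per feature. The sketch below also writes the same operation as an affine mapping, assuming `mu` and `sigma` are per-feature vectors. The function name is hypothetical.

```python
import jax.numpy as jnp

def standard_scale(x, mu, sigma):
    return (x - mu) / sigma

mu = jnp.array([1.0, 4.0])
sigma = jnp.array([2.0, 3.0])
x = jnp.array([3.0, 10.0])
standardized = standard_scale(x, mu, sigma)

# Equivalent affine form: A = diag(1 / sigma), b = -mu / sigma.
A = jnp.diag(1.0 / sigma)
b = -mu / sigma
standardized_affine = A @ x + b
```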
-
-
ceml.backend.jax.preprocessing.
reduce
(function, sequence[, initial]) → value¶ Apply a function of two arguments cumulatively to the items of a sequence, from left to right, so as to reduce the sequence to a single value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates ((((1+2)+3)+4)+5). If initial is present, it is placed before the items of the sequence in the calculation, and serves as a default when the sequence is empty.
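The signature above matches Python's `functools.reduce`. The short example below reproduces the sum from the description and, as a hypothetical use case, chains a list of preprocessing steps into a single forward pass.

```python
from functools import reduce

# The example from the description: ((((1+2)+3)+4)+5).
total = reduce(lambda x, y: x + y, [1, 2, 3, 4, 5])

# With an empty sequence, the initial value is returned.
empty = reduce(lambda x, y: x + y, [], 0)

# Hypothetical preprocessing steps applied from left to right.
steps = [lambda v: v + 1.0, lambda v: v * 2.0]
pipeline_out = reduce(lambda acc, step: step(acc), steps, 3.0)  # (3 + 1) * 2
```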