Performance Measures
In MLJ, loss functions, scoring rules, sensitivities, and so on, are collectively referred to as measures. These include re-exported loss functions from the LossFunctions.jl library, overloaded to behave the same way as the built-in measures.
To list all measures, run measures(). Further measures for probabilistic predictors, such as proper scoring rules, and for constructing multi-target product measures, are planned. If you'd like to see a measure added to MLJ, post a comment here.
Note for developers: The measures interface and the built-in measures described here are defined in MLJBase, but will ultimately live in a separate package.
Using built-in measures
These measures all have the common calling syntax
measure(ŷ, y)
or
measure(ŷ, y, w)
where y
iterates over observations of some target variable, and ŷ
iterates over predictions (Distribution
or Sampler
objects in the probabilistic case). Here w
is an optional vector of sample weights, or a dictionary of class weights, when these are supported by the measure.
julia> using MLJ
julia> y = [1, 2, 3, 4];
julia> ŷ = [2, 3, 3, 3];
julia> w = [1, 2, 2, 1];
julia> rms(ŷ, y) # reports an aggregate loss
0.8660254037844386
julia> l2(ŷ, y, w) # reports per observation losses
4-element Array{Int64,1}:
1
2
0
1
julia> y = coerce(["male", "female", "female"], Multiclass)
3-element CategoricalArray{String,1,UInt32}:
"male"
"female"
"female"
julia> d = UnivariateFinite(["male", "female"], [0.55, 0.45], pool=y);
julia> ŷ = [d, d, d];
julia> log_loss(ŷ, y)
3-element Array{Float64,1}:
0.7985076962177716
0.5978370007556204
0.5978370007556204
The measures rms
, l2
and log_loss
illustrated here are actually instances of measure types. For example, l2 = LPLoss(p=2)
and log_loss = LogLoss() = LogLoss(tol=eps())
. Common aliases are provided:
julia> cross_entropy
LogLoss(
tol = 2.220446049250313e-16) @362
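Measure instances with non-default parameters can be constructed and called in the same way. For instance (a minimal sketch; the parameter value and data below are illustrative):

using MLJ
y = [1, 2, 3, 4]
ŷ = [2, 3, 3, 3]
LPLoss(p=3)(ŷ, y)   # called like any other measure, e.g. `l1` or `l2`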
Traits and custom measures
Notice that l1 and l2 report per-sample evaluations, while rms reports only an aggregated result. This and other behavior can be gleaned from measure traits, which are summarized by the info
method:
julia> info(l1)
`LPLoss` - lp loss type with instances `l1`, `l2`.
(name = "LPLoss",
instances = ["l1", "l2"],
human_name = "lp loss",
target_scitype = Union{AbstractArray{Continuous,1}, AbstractArray{Count,1}},
supports_weights = true,
prediction_type = :deterministic,
orientation = :loss,
reports_each_observation = true,
aggregation = MLJBase.Mean(),
is_feature_dependent = false,
docstring = "`LPLoss` - lp loss type with instances `l1`, `l2`. ",
distribution_type = missing,
supports_class_weights = false,)
Query the doc-string for a measure using the name of its type:
julia> rms
RootMeanSquaredError() @221
julia> @doc RootMeanSquaredError # same as `?RootMeanSquaredError`
MLJBase.RootMeanSquaredError
A measure type for root mean squared error, which includes the instance(s),
rms, rmse, root_mean_squared_error.
RootMeanSquaredError()(ŷ, y)
RootMeanSquaredError()(ŷ, y, w)
Evaluate the root mean squared error on observations ŷ, given ground truth
values y. Optionally specify per-sample weights, w.
\text{root mean squared error} = \sqrt{n^{-1}∑ᵢ|yᵢ-ŷᵢ|^2} \quad \text{or} \quad \text{root mean squared error} = \sqrt{\frac{∑ᵢwᵢ|yᵢ-ŷᵢ|^2}{∑ᵢwᵢ}}
Requires scitype(y) to be a subtype of
Union{AbstractArray{ScientificTypes.Continuous,1},
AbstractArray{ScientificTypes.Count,1}}; ŷ must be a deterministic
prediction.
For more information, run info(RootMeanSquaredError).
Use measures()
to list all measures, and measures(conditions...)
to search for measures with given traits (as you would query models). The trait instances lists the actual callable instances of a given measure type (typically aliases for the default instance).
MLJBase.measures — Method
measures()
List all measures as named-tuples keyed on measure traits.
measures(filters...)
List all measures m
for which filter(m)
is true, for each filter
in filters
.
measures(matching(y))
List all measures compatible with the target y
.
measures(needle::Union{AbstractString,Regex})
List all measures with needle
in a measure's name
or docstring
.
Example
Find all classification measures supporting sample weights:
measures(m -> m.target_scitype <: AbstractVector{<:Finite} &&
m.supports_weights)
Find all classification measures where the number of classes is three:
y = categorical(1:3)
measures(matching(y))
Find all measures in the rms
family:
measures("rms")
A user-defined measure can be passed to the evaluate! method, and used elsewhere in MLJ, provided it is a function or callable object conforming to the above syntactic conventions. By default, a custom measure is understood to:
be a loss function (rather than a score)
report an aggregated value (rather than per-sample evaluations)
be feature-independent
To override this behavior, simply overload the appropriate trait, as shown in the following examples:
julia> y = [1, 2, 3, 4];
julia> ŷ = [2, 3, 3, 3];
julia> w = [1, 2, 2, 1];
julia> my_loss(ŷ, y) = maximum((ŷ - y).^2);
julia> my_loss(ŷ, y)
1
julia> my_per_sample_loss(ŷ, y) = abs.(ŷ - y);
julia> MLJ.reports_each_observation(::typeof(my_per_sample_loss)) = true;
julia> my_per_sample_loss(ŷ, y)
4-element Array{Int64,1}:
1
1
0
1
julia> my_weighted_score(ŷ, y) = 1/mean(abs.(ŷ - y));
julia> my_weighted_score(ŷ, y, w) = 1/mean(abs.((ŷ - y).^w));
julia> MLJ.supports_weights(::typeof(my_weighted_score)) = true;
julia> MLJ.orientation(::typeof(my_weighted_score)) = :score;
julia> my_weighted_score(ŷ, y)
1.3333333333333333
julia> X = (x=rand(4), penalty=[1, 2, 3, 4]);
julia> my_feature_dependent_loss(ŷ, X, y) = sum(abs.(ŷ - y) .* X.penalty)/sum(X.penalty);
julia> MLJ.is_feature_dependent(::typeof(my_feature_dependent_loss)) = true
julia> my_feature_dependent_loss(ŷ, X, y)
0.7
The possible signatures for custom measures are: measure(ŷ, y)
, measure(ŷ, y, w)
, measure(ŷ, X, y)
and measure(ŷ, X, y, w)
, each measure implementing one non-weighted version, and possibly a second weighted version.
Implementation detail: Internally, every measure is evaluated using the syntax
MLJ.value(measure, ŷ, X, y, w)
and the traits determine what can be ignored and how measure
is actually called. If w=nothing
then the non-weighted form of measure
is dispatched.
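For instance, the custom measures defined above could all be evaluated through this single entry point (a sketch, reusing ŷ, X, y and w from the examples above):

MLJ.value(my_loss, ŷ, X, y, nothing)                    # X and w ignored
MLJ.value(my_weighted_score, ŷ, X, y, w)                # w used, X ignored
MLJ.value(my_feature_dependent_loss, ŷ, X, y, nothing)  # X used, w ignored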
Using measures from LossFunctions.jl
The LossFunctions.jl package includes "distance loss" functions for Continuous
targets, and "marginal loss" functions for Finite{2}
(binary) targets. While the LossFunctions.jl interface differs from the present one (for example, binary observations must be +1 or -1), MLJ has overloaded instances of the LossFunctions.jl types to behave the same as the built-in types.
Note that the "distance losses" in the package apply to deterministic predictions, while the "marginal losses" apply to probabilistic predictions.
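For example, a distance loss such as huber_loss (listed in the table below) can be called on deterministic predictions of a Continuous target in the usual way (a minimal sketch; the data are illustrative):

using MLJ
y = [1.0, 2.0, 3.0, 4.0]
ŷ = [2.0, 3.0, 3.0, 3.0]
huber_loss(ŷ, y)   # same calling syntax as the built-in measures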
List of measures
using DataFrames
ms = measures()
types = map(m -> m.name, ms)
instances = map(m -> m.instances, ms)
t = (type=types, instances=instances)
DataFrame(t)
 | type | instances |
---|---|---|
1 | Accuracy | ["accuracy"] |
2 | AreaUnderCurve | ["area_under_curve", "auc"] |
3 | BalancedAccuracy | ["balanced_accuracy", "bacc", "bac"] |
4 | BrierLoss | ["brier_loss"] |
5 | BrierScore | ["brier_score"] |
6 | ConfusionMatrix | ["confusion_matrix", "confmat"] |
7 | FScore | ["f1score"] |
8 | FalseDiscoveryRate | ["false_discovery_rate", "falsediscovery_rate", "fdr"] |
9 | FalseNegative | ["false_negative", "falsenegative"] |
10 | FalseNegativeRate | ["false_negative_rate", "falsenegative_rate", "fnr", "miss_rate"] |
11 | FalsePositive | ["false_positive", "falsepositive"] |
12 | FalsePositiveRate | ["false_positive_rate", "falsepositive_rate", "fpr", "fallout"] |
13 | LPLoss | ["l1", "l2"] |
14 | LogCoshLoss | ["log_cosh", "log_cosh_loss"] |
15 | LogLoss | ["log_loss", "cross_entropy"] |
16 | MatthewsCorrelation | ["matthews_correlation", "mcc"] |
17 | MeanAbsoluteError | ["mae", "mav", "mean_absolute_error", "mean_absolute_value"] |
18 | MeanAbsoluteProportionalError | ["mape"] |
19 | MisclassificationRate | ["misclassification_rate", "mcr"] |
20 | MulticlassFScore | ["macro_f1score", "micro_f1score", "multiclass_f1score"] |
21 | MulticlassFalseDiscoveryRate | ["multiclass_falsediscovery_rate", "multiclass_fdr"] |
22 | MulticlassFalseNegative | ["multiclass_false_negative", "multiclass_falsenegative"] |
23 | MulticlassFalseNegativeRate | ["multiclass_false_negative_rate", "multiclass_fnr", "multiclass_miss_rate", "multiclass_falsenegative_rate"] |
24 | MulticlassFalsePositive | ["multiclass_false_positive", "multiclass_falsepositive"] |
25 | MulticlassFalsePositiveRate | ["multiclass_false_positive_rate", "multiclass_fpr", "multiclass_fallout", "multiclass_falsepositive_rate"] |
26 | MulticlassNegativePredictiveValue | ["multiclass_negative_predictive_value", "multiclass_negativepredictive_value", "multiclass_npv"] |
27 | MulticlassPrecision | ["multiclass_positive_predictive_value", "multiclass_ppv", "multiclass_positivepredictive_value", "multiclass_recall"] |
28 | MulticlassTrueNegative | ["multiclass_true_negative", "multiclass_truenegative"] |
29 | MulticlassTrueNegativeRate | ["multiclass_true_negative_rate", "multiclass_tnr", "multiclass_specificity", "multiclass_selectivity", "multiclass_truenegative_rate"] |
30 | MulticlassTruePositive | ["multiclass_true_positive", "multiclass_truepositive"] |
31 | MulticlassTruePositiveRate | ["multiclass_true_positive_rate", "multiclass_tpr", "multiclass_sensitivity", "multiclass_recall", "multiclass_hit_rate", "multiclass_truepositive_rate"] |
32 | NegativePredictiveValue | ["negative_predictive_value", "negativepredictive_value", "npv"] |
33 | Precision | ["positive_predictive_value", "ppv", "positivepredictive_value", "precision"] |
34 | RootMeanSquaredError | ["rms", "rmse", "root_mean_squared_error"] |
35 | RootMeanSquaredLogError | ["rmsl", "rmsle", "root_mean_squared_log_error"] |
36 | RootMeanSquaredLogProportionalError | ["rmslp1"] |
37 | RootMeanSquaredProportionalError | ["rmsp"] |
38 | TrueNegative | ["true_negative", "truenegative"] |
39 | TrueNegativeRate | ["true_negative_rate", "truenegative_rate", "tnr", "specificity", "selectivity"] |
40 | TruePositive | ["true_positive", "truepositive"] |
41 | TruePositiveRate | ["true_positive_rate", "truepositive_rate", "tpr", "sensitivity", "recall", "hit_rate"] |
42 | DWDMarginLoss | ["dwd_margin_loss"] |
43 | ExpLoss | ["exp_loss"] |
44 | L1HingeLoss | ["l1_hinge_loss"] |
45 | L2HingeLoss | ["l2_hinge_loss"] |
46 | L2MarginLoss | ["l2_margin_loss"] |
47 | LogitMarginLoss | ["logit_margin_loss"] |
48 | ModifiedHuberLoss | ["modified_huber_loss"] |
49 | PerceptronLoss | ["perceptron_loss"] |
50 | SigmoidLoss | ["sigmoid_loss"] |
51 | SmoothedL1HingeLoss | ["smoothed_l1_hinge_loss"] |
52 | ZeroOneLoss | ["zero_one_loss"] |
53 | HuberLoss | ["huber_loss"] |
54 | L1EpsilonInsLoss | ["l1_epsilon_ins_loss"] |
55 | L2EpsilonInsLoss | ["l2_epsilon_ins_loss"] |
56 | LPDistLoss | ["lp_dist_loss"] |
57 | LogitDistLoss | ["logit_dist_loss"] |
58 | PeriodicLoss | ["periodic_loss"] |
59 | QuantileLoss | ["quantile_loss"] |
Other performance related tools
In MLJ, one computes a confusion matrix by calling an instance of the ConfusionMatrix
measure type on the data:
MLJBase.ConfusionMatrix — Type
A measure type for confusion matrix, which includes the instance(s), confusion_matrix
, confmat
.
ConfusionMatrix()(ŷ, y)
Evaluate the default instance of ConfusionMatrix on observations ŷ
, given ground truth values y
.
If r
is the return value, then the raw confusion matrix is r.mat
, whose rows correspond to predictions, and columns to ground truth. The ordering follows that of levels(y)
.
Use ConfusionMatrix(perm=[2, 1])
to reverse the class order for binary data. For more than two classes, specify an appropriate permutation, as in ConfusionMatrix(perm=[2, 3, 1])
.
Requires scitype(y)
to be a subtype of AbstractArray{<:OrderedFactor{2}}
(binary classification where choice of "true" affects the measure); ŷ
must be a deterministic prediction.
For more information, run info(ConfusionMatrix)
.
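For example (a minimal sketch; the labels and predictions below are illustrative):

using MLJ
y = coerce(["b", "a", "b", "b", "a"], OrderedFactor)
ŷ = y[[2, 2, 3, 4, 5]]       # deterministic predictions sharing the pool of y
cm = confusion_matrix(ŷ, y)  # or: confmat(ŷ, y)
cm.mat                       # rows are predictions, columns are ground truth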
MLJBase.roc_curve — Function
fprs, tprs, ts = roc_curve(ŷ, y) = roc(ŷ, y)
Return the ROC curve for a two-class probabilistic prediction ŷ
given the ground truth y
. The true positive rates and false positive rates over a range of thresholds ts
are returned. Note that if there are k
unique scores, there are correspondingly k
thresholds and k+1
"bins" over which the FPR and TPR are constant:
[0.0 - thresh[1]]
[thresh[1] - thresh[2]]
...
[thresh[k] - 1]
consequently, tprs
and fprs
are of length k+1
if ts
is of length k
.
To draw the curve using your favorite plotting backend, do plot(fprs, tprs)
.