Simple User Defined Models

To quickly implement a new supervised model in MLJ, it suffices to:

  • Define a mutable struct to store hyperparameters. This is either a subtype of Probabilistic or Deterministic, depending on whether probabilistic or ordinary point predictions are intended. This struct is the model.

  • Define a fit method, dispatched on the model, returning the learned parameters (the fitresult) together with a cache and a report, either of which may be nothing.

  • Define a predict method, dispatched on the model, that takes the fitresult and new input Xnew and returns predictions on the new patterns.
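The three steps can be seen in isolation in a minimal sketch. MyMeanRegressor is a hypothetical model, not part of MLJ, that ignores X and always predicts the mean of the training target. Because it has no hyperparameters, it is declared with a plain struct, as MyClassifier is further down.

```julia
import MLJBase
using Statistics

# Hypothetical minimal Deterministic model: ignores X and always predicts
# the mean of the training target. It has no hyperparameters.
struct MyMeanRegressor <: MLJBase.Deterministic end

# fit returns (fitresult, cache, report); here the fitresult is just a number
function MLJBase.fit(::MyMeanRegressor, verbosity, X, y)
    fitresult = mean(y)
    return fitresult, nothing, nothing
end

# predict returns one prediction per row of Xnew, in the same form as y
MLJBase.predict(::MyMeanRegressor, fitresult, Xnew) =
    fill(fitresult, MLJBase.nrows(Xnew))
```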

In the examples below, the training input X of fit, and the new input Xnew passed to predict, are tables. Each training target y is an AbstractVector.

The predictions returned by predict have the same form as y for deterministic models, but are Vectors of distributions for probabilistic models.
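To illustrate these data forms: a NamedTuple of equal-length vectors is one kind of table MLJBase accepts. MLJBase.matrix and MLJBase.nrows, both of which the examples in this section rely on, convert such a table to a matrix and count its rows. The column names below are arbitrary.

```julia
import MLJBase

# A column table: a NamedTuple of equal-length vectors (names are arbitrary)
X = (x1 = [1.0, 2.0, 3.0], x2 = [4.0, 5.0, 6.0])
y = [0.5, 1.5, 2.5]        # a target is an AbstractVector

n = MLJBase.nrows(X)       # number of rows (observations)
A = MLJBase.matrix(X)      # 3×2 matrix, one column per table column
```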

Advanced model functionality not addressed here includes: (i) an optional update method to avoid redundant calculations when calling fit! on a machine a second time; (ii) reporting extra training-related statistics; (iii) exposing model-specific functionality; (iv) checking the scientific type of data passed to your model at machine construction; and (v) checking the validity of hyperparameter values. All of this is described in Adding Models for General Use.

For an unsupervised model, implement transform and, optionally, inverse_transform, using the same signature as predict below.
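Here is a minimal sketch of that pattern. It assumes MLJBase's Unsupervised supertype and that an unsupervised fit takes no target y. The model MyCenterer is hypothetical: fit learns the column means, transform subtracts them, and inverse_transform adds them back. To keep things short, transform returns a matrix rather than a table.

```julia
import MLJBase
using Statistics

# Hypothetical unsupervised model that centers each column of a table
struct MyCenterer <: MLJBase.Unsupervised end

# Unsupervised fit takes no target; the fitresult is the vector of column means
function MLJBase.fit(::MyCenterer, verbosity, X)
    x = MLJBase.matrix(X)
    fitresult = vec(mean(x, dims=1))
    return fitresult, nothing, nothing
end

# transform has the same (model, fitresult, Xnew) signature as predict;
# it returns a matrix here for simplicity
MLJBase.transform(::MyCenterer, fitresult, Xnew) =
    MLJBase.matrix(Xnew) .- fitresult'

# inverse_transform undoes the centering
MLJBase.inverse_transform(::MyCenterer, fitresult, Z) =
    Z .+ fitresult'
```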

A simple deterministic regressor

Here's a quick-and-dirty implementation of a ridge regressor with no intercept:

import MLJBase
using LinearAlgebra

mutable struct MyRegressor <: MLJBase.Deterministic
    lambda::Float64                           # ridge penalty
end
MyRegressor(; lambda=0.1) = MyRegressor(lambda)

# fit returns coefficients minimizing a penalized least-squares (ridge) loss:
function MLJBase.fit(model::MyRegressor, verbosity, X, y)
    x = MLJBase.matrix(X)                     # convert table to matrix
    fitresult = (x'x + model.lambda*I)\(x'y)  # the coefficients
    cache = nothing                           # nothing to cache
    report = nothing                          # no extra training statistics
    return fitresult, cache, report
end

# predict uses coefficients to make a new prediction:
MLJBase.predict(::MyRegressor, fitresult, Xnew) = MLJBase.matrix(Xnew) * fitresult
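You can quickly check that the fit formula gives the ridge solution. At β = (x'x + λI)\(x'y), the gradient of ‖y − xβ‖² + λ‖β‖², which is −2x'(y − xβ) + 2λβ, should vanish. The sketch below verifies this numerically on random synthetic data, using only LinearAlgebra and Random and no MLJ.

```julia
using LinearAlgebra, Random

Random.seed!(0)
x = randn(50, 3)    # synthetic inputs: 50 observations, 3 features
y = randn(50)       # synthetic target
lambda = 0.1

# closed-form ridge coefficients, same formula as in the fit method
beta = (x'x + lambda*I) \ (x'y)

# gradient of the penalized loss ||y - x*beta||^2 + lambda*||beta||^2
grad = -2 * x' * (y - x*beta) + 2 * lambda * beta

println(norm(grad))  # zero up to floating-point rounding
```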

After loading this code, all of MLJ's basic meta-algorithms can be applied to MyRegressor:

julia> X, y = @load_boston;

julia> model = MyRegressor(lambda=1.0)
MyRegressor(
  lambda = 1.0)

julia> regressor = machine(model, X, y)
untrained Machine; caches model-specific representations of data
  model: MyRegressor(lambda = 1.0)
  args:
    1:  Source @047 ⏎ Table{AbstractVector{Continuous}}
    2:  Source @138 ⏎ AbstractVector{Continuous}

julia> evaluate!(regressor, resampling=CV(), measure=rms, verbosity=0)
PerformanceEvaluation object with these fields:
  model, measure, operation, measurement, per_fold,
  per_observation, fitted_params_per_fold,
  report_per_fold, train_test_rows, resampling, repeats
Extract:
┌────────────────────────┬───────────┬─────────────┬─────────┬──────────────────
│ measure                │ operation │ measurement │ 1.96*SE │ per_fold        ⋯
├────────────────────────┼───────────┼─────────────┼─────────┼──────────────────
│ RootMeanSquaredError() │ predict   │ 5.94        │ 2.58    │ [2.71, 4.44, 5. ⋯
└────────────────────────┴───────────┴─────────────┴─────────┴──────────────────
                                                                1 column omitted
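Cross-validation trains on all folds but one, computes rms on the held-out fold, and repeats this for every fold. This is the origin of the per_fold values. The sketch below runs that loop in plain Julia with the ridge formula from MyRegressor, on synthetic data rather than the Boston data. Several details are assumptions meant to resemble CV()'s defaults rather than reproduce them: the folds are contiguous and unshuffled, there are 6 of them, and the per-fold losses are combined with a simple mean, which may differ from how MLJ aggregates rms into the measurement column. The helpers ridge_fit, rms_loss and cv_rms are hypothetical.

```julia
using LinearAlgebra, Random, Statistics

# Same closed-form ridge solution as the fit method of MyRegressor
ridge_fit(x, y, lambda) = (x'x + lambda*I) \ (x'y)

# Root-mean-squared error of predictions yhat against targets y
rms_loss(yhat, y) = sqrt(mean((yhat .- y) .^ 2))

# k-fold cross-validation with contiguous (unshuffled) folds;
# returns the mean of the per-fold rms losses and the per-fold values
function cv_rms(x, y, lambda; nfolds=6)
    n = size(x, 1)
    edges = round.(Int, range(0, n, length=nfolds + 1))  # fold boundaries
    per_fold = Float64[]
    for k in 1:nfolds
        test = (edges[k] + 1):edges[k + 1]   # held-out rows
        train = setdiff(1:n, test)           # remaining rows
        beta = ridge_fit(x[train, :], y[train], lambda)
        push!(per_fold, rms_loss(x[test, :] * beta, y[test]))
    end
    return mean(per_fold), per_fold
end

Random.seed!(1)
x = randn(60, 3)                                  # synthetic inputs
y = x * [1.0, -2.0, 0.5] + 0.1 * randn(60)        # noisy linear target
m, per_fold = cv_rms(x, y, 1.0)
```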

A simple probabilistic classifier

The following probabilistic model simply fits a probability distribution to the Multiclass training target (i.e., it ignores X) and returns this distribution for every new pattern:

import MLJBase
import Distributions

struct MyClassifier <: MLJBase.Probabilistic
end

# `fit` ignores the inputs X and returns the training target y
# probability distribution:
function MLJBase.fit(model::MyClassifier, verbosity, X, y)
    fitresult = Distributions.fit(MLJBase.UnivariateFinite, y)
    cache = nothing
    report = nothing
    return fitresult, cache, report
end

# `predict` returns the passed fitresult (pdf) for all new patterns:
MLJBase.predict(model::MyClassifier, fitresult, Xnew) =
    [fitresult for r in 1:MLJBase.nrows(Xnew)]

julia> X, y = @load_iris
julia> mach = fit!(machine(MyClassifier(), X, y))
julia> predict(mach, selectrows(X, 1:2))
2-element Array{UnivariateFinite{String,UInt32,Float64},1}:
 UnivariateFinite(setosa=>0.333, versicolor=>0.333, virginica=>0.333)
 UnivariateFinite(setosa=>0.333, versicolor=>0.333, virginica=>0.333)
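The 0.333 probabilities above fit the fact that the iris target contains 50 observations of each of its three species. The plain-Julia sketch below computes empirical class frequencies, which is presumably what the fitted UnivariateFinite in MyClassifier holds. class_frequencies is a hypothetical helper and is not MLJBase's implementation.

```julia
# Empirical class probabilities: count each label, divide by the total
function class_frequencies(y::AbstractVector)
    counts = Dict{eltype(y),Int}()
    for label in y
        counts[label] = get(counts, label, 0) + 1
    end
    n = length(y)
    return Dict(label => c / n for (label, c) in counts)
end

# A target with the same class balance as iris: 50 of each species
y = vcat(fill("setosa", 50), fill("versicolor", 50), fill("virginica", 50))
p = class_frequencies(y)
```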