Adding Models for General Use


Models implementing the MLJ model interface according to the instructions given here should import MLJModelInterface version 1.0.0 or higher. This is enforced with a statement such as MLJModelInterface = "^1" under [compat] in the Project.toml file of the package containing the implementation.

This guide outlines the specification of the MLJ model interface and provides detailed guidelines for implementing the interface for models intended for general use. See also the more condensed Quick-Start Guide to Adding Models.

For sample implementations, see MLJModels/src.

Interface code can be hosted by the package providing the core machine learning algorithm, or by a stand-alone "interface-only" package, using the template MLJExampleInterface.jl (see Where to place code implementing new models below).

The machine learning tools provided by MLJ can be applied to the models in any package that imports the package MLJModelInterface and implements the API defined there, as outlined below. For a quick-and-dirty implementation of user-defined models see Simple User Defined Models. To make new models available to all MLJ users, see Where to place code implementing new models.


MLJModelInterface is a very light-weight interface allowing you to define your interface, but does not provide the functionality required to use or test your interface; this requires MLJBase. So, while you only need to add MLJModelInterface to your project's [deps], for testing purposes you need to add MLJBase to your project's [extras] and [targets]. In testing, simply use MLJBase in place of MLJModelInterface.

It is assumed the reader has read Getting Started. To implement the API described here, some familiarity with the following packages is also helpful:

  • ScientificTypes.jl (for specifying model requirements of data)

  • Distributions.jl (for probabilistic predictions)

  • CategoricalArrays.jl (essential if you are implementing a model handling data of Multiclass or OrderedFactor scitype; familiarity with CategoricalPool objects required)

  • Tables.jl (if your algorithm needs input data in a novel format).

In MLJ, the basic interface exposed to the user, built atop the model interface described here, is the machine interface. After a first reading of this document, the reader may wish to refer to MLJ Internals for context.


A model is an object storing hyperparameters associated with some machine learning algorithm, and that is all. In MLJ, hyperparameters include configuration parameters, like the number of threads, and special instructions, such as "compute feature rankings", which may or may not affect the final learning outcome. However, the logging level (verbosity below) is excluded. Learned parameters (such as the coefficients in a linear model) have no place in the model struct.

The name of the Julia type associated with a model indicates the associated algorithm (e.g., DecisionTreeClassifier). The outcome of training a learning algorithm is called a fitresult. For ordinary multivariate regression, for example, this would be the coefficients and intercept. For a general supervised model, it is the (generally minimal) information needed to make new predictions.

The ultimate supertype of all models is MLJModelInterface.Model, which has two abstract subtypes:

abstract type Supervised <: Model end
abstract type Unsupervised <: Model end

Supervised models are further divided according to whether they are able to furnish probabilistic predictions of the target (which they will then do by default) or directly predict "point" estimates, for each new input pattern:

abstract type Probabilistic <: Supervised end
abstract type Deterministic <: Supervised end

Further division of model types is realized through Trait declarations.

Associated with every concrete subtype of Model there must be a fit method, which implements the associated algorithm to produce the fitresult. Additionally, every Supervised model has a predict method, while Unsupervised models must have a transform method. More generally, methods such as these, that are dispatched on a model instance and a fitresult (plus other data), are called operations. Probabilistic supervised models optionally implement a predict_mode operation (in the case of classifiers) or predict_mean and/or predict_median operations (in the case of regressors), although MLJModelInterface also provides fallbacks that will suffice in most cases. Unsupervised models may implement an inverse_transform operation.

New model type declarations and optional clean! method

Here is an example of a concrete supervised model type declaration, for a model with a single hyper-parameter:

import MLJModelInterface
const MMI = MLJModelInterface

mutable struct RidgeRegressor <: MMI.Deterministic
    lambda::Float64
end

Models (which are mutable) should not be given internal constructors. It is recommended that they be given an external lazy keyword constructor of the same name. This constructor defines default values for every field, and optionally corrects invalid field values by calling a clean! method (whose fallback returns an empty message string):

function MMI.clean!(model::RidgeRegressor)
    warning = ""
    if model.lambda < 0
        warning *= "Need lambda ≥ 0. Resetting lambda=0. "
        model.lambda = 0
    end
    return warning
end

# keyword constructor
function RidgeRegressor(; lambda=0.0)
    model = RidgeRegressor(lambda)
    message = MMI.clean!(model)
    isempty(message) || @warn message
    return model
end

Important. The clean! method must be idempotent: once clean!(model) has been called, a second call clean!(model) must leave model unchanged, for any instance model.
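As a quick check of the constructor and clean! behavior described above, the following sketch repeats the RidgeRegressor example in full and inspects the result of passing an invalid value. It assumes MLJModelInterface is installed; the names model and second_message are illustrative only.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

mutable struct RidgeRegressor <: MMI.Deterministic
    lambda::Float64
end

# Correct invalid hyperparameters and report what was changed.
function MMI.clean!(model::RidgeRegressor)
    warning = ""
    if model.lambda < 0
        warning *= "Need lambda ≥ 0. Resetting lambda=0. "
        model.lambda = 0
    end
    return warning
end

# Keyword constructor: sets defaults, then cleans.
function RidgeRegressor(; lambda=0.0)
    model = RidgeRegressor(lambda)
    message = MMI.clean!(model)
    isempty(message) || @warn message
    return model
end

model = RidgeRegressor(lambda=-1.0)   # logs a warning and resets lambda
second_message = MMI.clean!(model)    # the model is already valid at this point
```

After construction the model holds a valid value, and the second clean! call has nothing left to correct, so it returns an empty message.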

Although not essential, try to avoid Union types for model fields. For example, a field declaration features::Vector{Symbol} with a default of Symbol[] (detected with isempty method) is preferred to features::Union{Vector{Symbol}, Nothing} with a default of nothing.

Hyper-parameters for parallelization options

The section Acceleration and Parallelism indicates how MLJ models specify an option to run an algorithm using distributed processing or multi-threading. A hyper-parameter specifying such an option should be called acceleration. Its value a should satisfy a isa AbstractResource where AbstractResource is defined in the ComputationalResources.jl package. An option to run on a GPU is ordinarily indicated with the CUDALibs() resource.

Hyper-parameter access and mutation

To support hyper-parameter optimization (see Tuning Models) any hyper-parameter to be individually controlled must be:

  • property-accessible; nested property access allowed, as in model.detector.K
  • mutable

For an un-nested hyper-parameter, the requirement is that getproperty(model, :param_name) and setproperty!(model, :param_name, value) have the expected behavior. (In hyper-parameter tuning, recursive access is implemented using MLJBase.recursive_getproperty and MLJBase.recursive_setproperty!.)

Combining hyper-parameters in a named tuple does not generally work, because, although property-accessible (with nesting), an individual value cannot be mutated.
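The limitation can be seen in plain Julia, without any MLJ packages. In this sketch the names nt_params and DetectorParams are hypothetical and serve only to contrast a named tuple with a mutable struct.

```julia
# Hyperparameters bundled in a NamedTuple: property-accessible but immutable.
nt_params = (K=5, metric="euclidean")

mutation_failed = try
    setproperty!(nt_params, :K, 10)   # throws: a NamedTuple cannot be changed
    false
catch
    true
end

# The same hyperparameters in a mutable struct can be changed in place.
mutable struct DetectorParams   # hypothetical container
    K::Int
    metric::String
end

params = DetectorParams(5, "euclidean")
setproperty!(params, :K, 10)    # equivalent to `params.K = 10`
```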

For a suggested way to deal with hyper-parameters varying in number, see the implementation of Stack, where the model struct stores a varying number of base models internally as a vector, but components are named at construction and accessed by overloading getproperty/setproperty! appropriately.
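The general idea can be sketched in plain Julia as follows. This is not the actual Stack implementation; the type NamedComponents and its fields are hypothetical, and strings stand in for the component models.

```julia
# Hypothetical container storing a varying number of named components.
mutable struct NamedComponents
    names::Vector{Symbol}
    components::Vector{Any}
end

# Keyword constructor: each keyword names one component.
function NamedComponents(; kwargs...)
    return NamedComponents(collect(keys(kwargs)), Any[values(kwargs)...])
end

# Look components up by name; the two real fields stay reachable.
function Base.getproperty(c::NamedComponents, name::Symbol)
    name in (:names, :components) && return getfield(c, name)
    i = findfirst(==(name), getfield(c, :names))
    i === nothing && error("no component named $name")
    return getfield(c, :components)[i]
end

# Replace a component by name.
function Base.setproperty!(c::NamedComponents, name::Symbol, value)
    i = findfirst(==(name), getfield(c, :names))
    i === nothing && error("no component named $name")
    getfield(c, :components)[i] = value
    return value
end

Base.propertynames(c::NamedComponents) = Tuple(getfield(c, :names))

stack = NamedComponents(tree="tree model", knn="knn model")
stack.knn = "tuned knn model"   # mutate one named component
```

Each component is then both property-accessible and mutable, which is what hyper-parameter tuning requires.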

Macro shortcut

An alternative to declaring the model struct, clean! method and keyword constructor, is to use the @mlj_model macro, as in the following example:

@mlj_model mutable struct YourModel <: MMI.Deterministic
    a::Float64 = 0.5::(_ > 0)
    b::String  = "svd"::(_ in ("svd","qr"))
end

This declaration specifies:

  • A keyword constructor (here YourModel(; a=..., b=...)),
  • Default values for the hyperparameters,
  • Constraints on the hyperparameters where _ refers to a value passed.

For example, a::Float64 = 0.5::(_ > 0) indicates that the field a is a Float64, takes 0.5 as default value, and expects its value to be positive.

You cannot use the @mlj_model macro if your model struct has type parameters.
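A short usage sketch of the declaration above follows, assuming MLJModelInterface is installed and the macro is invoked as MMI.@mlj_model. The variable names are illustrative. The last line relies on the generated clean! method replacing a value that violates its constraint with the default, after logging a warning.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

MMI.@mlj_model mutable struct YourModel <: MMI.Deterministic
    a::Float64 = 0.5::(_ > 0)
    b::String  = "svd"::(_ in ("svd","qr"))
end

default_model = YourModel()               # uses the declared defaults
custom_model  = YourModel(a=2.0, b="qr")  # both values satisfy their constraints
reset_model   = YourModel(a=-1.0)         # violates `_ > 0`: warns, falls back to default
```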

Known issue with @mlj_model

Defaults with negative values can trip up the @mlj_model macro (see this issue). So, for example, this does not work:

@mlj_model mutable struct Bar
    a::Int = -1::(_ > -2)
end

But this does:

@mlj_model mutable struct Bar
    a::Int = (-)(1)::(_ > -2)
end

Supervised models

Mathematical assumptions

At present, MLJ's performance estimate functionality (resampling using evaluate/evaluate!) tacitly assumes that feature-label pairs of observations (X1, y1), (X2, y2), (X3, y3), ... are being modelled as independent and identically distributed (i.i.d.) random variables, and constructs some kind of representation of an estimate of the conditional probability p(y | X) (y and X single observations). It may be that a model implementing the MLJ interface has the potential to make predictions under weaker assumptions (e.g., time series forecasting models). However, the output of the compulsory predict method described below should be the output of the model under the i.i.d. assumption.

In the future newer methods may be introduced to handle weaker assumptions (see, e.g., The predict_joint method below).

Summary of methods

The compulsory and optional methods to be implemented for each concrete type SomeSupervisedModel <: MMI.Supervised are summarized below.

An = indicates the return value for a fallback version of the method.

Compulsory:

MMI.fit(model::SomeSupervisedModel, verbosity, X, y) -> fitresult, cache, report
MMI.predict(model::SomeSupervisedModel, fitresult, Xnew) -> yhat

Optional, to check and correct invalid hyperparameter values:

MMI.clean!(model::SomeSupervisedModel) = ""

Optional, to return user-friendly form of fitted parameters:

MMI.fitted_params(model::SomeSupervisedModel, fitresult) = (fitresult=fitresult,)

Optional, to avoid redundant calculations when re-fitting machines associated with a model:

MMI.update(model::SomeSupervisedModel, verbosity, old_fitresult, old_cache, X, y) =
    MMI.fit(model, verbosity, X, y)

Optional, to specify default hyperparameter ranges (for use in tuning):

MMI.hyperparameter_ranges(T::Type) = Tuple(fill(nothing, length(fieldnames(T))))

Optional, if SomeSupervisedModel <: Probabilistic:

MMI.predict_mode(model::SomeSupervisedModel, fitresult, Xnew) =
    mode.(predict(model, fitresult, Xnew))
MMI.predict_mean(model::SomeSupervisedModel, fitresult, Xnew) =
    mean.(predict(model, fitresult, Xnew))
MMI.predict_median(model::SomeSupervisedModel, fitresult, Xnew) =
    median.(predict(model, fitresult, Xnew))

Required, if the model is to be registered (findable by general users):

MMI.load_path(::Type{<:SomeSupervisedModel})    = ""
MMI.package_name(::Type{<:SomeSupervisedModel}) = "Unknown"
MMI.package_uuid(::Type{<:SomeSupervisedModel}) = "Unknown"
MMI.input_scitype(::Type{<:SomeSupervisedModel}) = Unknown

Strongly recommended, to constrain the form of target data passed to fit:

MMI.target_scitype(::Type{<:SomeSupervisedModel}) = Unknown

Optional but recommended:

MMI.package_url(::Type{<:SomeSupervisedModel})  = "unknown"
MMI.is_pure_julia(::Type{<:SomeSupervisedModel}) = false
MMI.package_license(::Type{<:SomeSupervisedModel}) = "unknown"

If SomeSupervisedModel supports sample weights or class weights, then instead of the fit above, one implements

MMI.fit(model::SomeSupervisedModel, verbosity, X, y, w=nothing) -> fitresult, cache, report

and, if appropriate

MMI.update(model::SomeSupervisedModel, verbosity, old_fitresult, old_cache, X, y, w=nothing) =
    MMI.fit(model, verbosity, X, y, w)

Additionally, if SomeSupervisedModel supports sample weights, one must declare

MMI.supports_weights(model::Type{<:SomeSupervisedModel}) = true

Optionally, an implementation may add a data front-end, for transforming user data (such as a table) into some model-specific format (such as a matrix), and for adding methods to specify how said format is resampled. (This alters the meaning of X, y and w in the signatures of fit, update, predict, etc; see Implementing a data front-end for details). This can provide the MLJ user certain performance advantages when fitting a machine.

MLJModelInterface.reformat(model::SomeSupervisedModel, args...) = args
MLJModelInterface.selectrows(model::SomeSupervisedModel, I, data...) = data

Optionally, to customize support for serialization of machines (see Serialization), overload

MMI.save(filename, model::SomeModel, fitresult; kwargs...) = fitresult

and possibly

MMI.restore(filename, model::SomeModel, serializable_fitresult) -> serializable_fitresult

These last two are unlikely to be needed if wrapping pure Julia code.

The form of data for fitting and predicting

The model implementer does not have absolute control over the types of data X, y and Xnew appearing in the fit and predict methods they must implement. Rather, they can specify the scientific type of this data by making appropriate declarations of the traits input_scitype and target_scitype discussed later under Trait declarations.

Important Note. Unless it genuinely makes little sense to do so, the MLJ recommendation is to specify a Table scientific type for X (and hence Xnew) and an AbstractVector scientific type (e.g., AbstractVector{Continuous}) for targets y. Algorithms requiring matrix input can coerce their inputs appropriately; see below.

Additional type coercions

If the core algorithm being wrapped requires data in a different or more specific form, then fit will need to coerce the table into the form desired (and the same coercions applied to X will have to be repeated for Xnew in predict). To assist with common cases, MLJ provides the convenience method MMI.matrix. MMI.matrix(Xtable) has type Matrix{T} where T is the tightest common type of elements of Xtable, and Xtable is any table. (If Xtable is itself just a wrapped matrix, Xtable=Tables.table(A), then A=MMI.matrix(Xtable) will be returned without any copying.)

Alternatively, a more performant option is to implement a data front-end for your model; see Implementing a data front-end.

Other auxiliary methods provided by MLJModelInterface for handling tabular data are: selectrows, selectcols, select and schema (for extracting the size, names and eltypes of a table's columns). See Convenience methods below for details.

Important convention

It is to be understood that the columns of the table X correspond to features and the rows to observations. So, for example, the predict method for a linear regression model might look like predict(model, w, Xnew) = MMI.matrix(Xnew)*w, where w is the vector of learned coefficients.
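The convention can be illustrated in plain Julia. In this sketch the helper to_matrix is a hypothetical stand-in for MMI.matrix that only handles a named tuple of equal-length vectors, and the column names and coefficients are made up.

```julia
# A column table: each named column is a feature, each index an observation.
Xnew = (height=[1.0, 2.0, 3.0], weight=[0.0, 1.0, 0.0])

# Stand-in for MMI.matrix, valid only for this NamedTuple-of-vectors table.
to_matrix(table) = hcat(values(table)...)

w = [2.0, -1.0]        # one learned coefficient per feature (column)
A = to_matrix(Xnew)    # 3 observations (rows) by 2 features (columns)
yhat = A * w           # one prediction per observation
```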

The fit method

A compulsory fit method returns three objects:

MMI.fit(model::SomeSupervisedModel, verbosity, X, y) -> fitresult, cache, report

  1. fitresult is the fitresult in the sense above (which becomes an argument for predict discussed below).

  2. report is a (possibly empty) NamedTuple, for example, report=(deviance=..., dof_residual=..., stderror=..., vcov=...). Any training-related statistics, such as internal estimates of the generalization error, and feature rankings, should be returned in the report tuple. How, or if, these are generated should be controlled by hyperparameters (the fields of model). Fitted parameters, such as the coefficients of a linear model, do not go in the report as they will be extractable from fitresult (and accessible to MLJ through the fitted_params method described below).

  3. The value of cache can be nothing, unless one is also defining an update method (see below). The Julia type of cache is not presently restricted.


The fit (and update) methods should not mutate the model. If necessary, fit can create a deepcopy of model first.

It is not necessary for fit to provide type or dimension checks on X or y or to call clean! on the model; MLJ will carry out such checks.

The types of X and y are constrained by the input_scitype and target_scitype trait declarations; see Trait declarations below. (That is, unless a data front-end is implemented, in which case these traits refer instead to the arguments of the overloaded reformat method, and the types of X and y are determined by the output of reformat.)

The method fit should never alter hyperparameter values, the sole exception being fields of type <:AbstractRNG. If the package is able to suggest better hyperparameters, as a byproduct of training, return these in the report field.

The verbosity level (0 for silent) is for passing to the learning algorithm itself. A fit method wrapping such an algorithm should generally avoid doing any of its own logging.

Sample weight support. If supports_weights(::Type{<:SomeSupervisedModel}) has been declared true, then one instead implements the following variation on the above fit:

MMI.fit(model::SomeSupervisedModel, verbosity, X, y, w=nothing) -> fitresult, cache, report
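To make the three return values concrete, here is a minimal sketch of fit and predict for a ridge regressor without intercept. It assumes MLJBase is installed, since loading it is what makes MMI.matrix functional. The closed-form solve is one simple choice for illustration, not a recommended implementation.

```julia
import MLJModelInterface
import MLJBase              # loading MLJBase activates MMI.matrix (assumed installed)
using LinearAlgebra

const MMI = MLJModelInterface

mutable struct RidgeRegressor <: MMI.Deterministic
    lambda::Float64
end
RidgeRegressor(; lambda=0.0) = RidgeRegressor(lambda)

function MMI.fit(model::RidgeRegressor, verbosity, X, y)
    A = MMI.matrix(X)                                # rows are observations
    coefs = (A' * A + model.lambda * I) \ (A' * y)   # regularized normal equations
    fitresult = coefs          # the minimal information predict needs
    cache = nothing            # no update method is defined
    report = NamedTuple()      # nothing extra to report
    return fitresult, cache, report
end

MMI.predict(model::RidgeRegressor, fitresult, Xnew) = MMI.matrix(Xnew) * fitresult

X = (x1=[1.0, 0.0, 1.0], x2=[0.0, 1.0, 1.0])
y = [2.0, -1.0, 1.0]           # equals 2*x1 - x2 exactly

fitresult, cache, report = MMI.fit(RidgeRegressor(lambda=0.0), 0, X, y)
yhat = MMI.predict(RidgeRegressor(), fitresult, X)
```

Note that fit neither mutates model nor logs anything itself, in line with the requirements above.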

The fitted_params method

A fitted_params method may be optionally overloaded. Its purpose is to provide MLJ access to a user-friendly representation of the learned parameters of the model (as opposed to the hyperparameters). They must be extractable from fitresult.

MMI.fitted_params(model::SomeSupervisedModel, fitresult) -> friendly_fitresult::NamedTuple

For a linear model, for example, one might declare something like friendly_fitresult=(coefs=[...], bias=...).

The fallback is to return (fitresult=fitresult,).
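A minimal sketch of such an overloading follows, assuming MLJModelInterface is installed. The model SomeLinearModel is hypothetical, as is the assumption that its fit method stores a (coefficients, intercept) tuple as the fitresult.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

mutable struct SomeLinearModel <: MMI.Deterministic   # hypothetical model
    lambda::Float64
end

# Assumption for this sketch: fit stores a (coefficients, intercept) tuple.
function MMI.fitted_params(::SomeLinearModel, fitresult)
    coefs, bias = fitresult
    return (coefs=coefs, bias=bias)
end

params = MMI.fitted_params(SomeLinearModel(0.1), ([1.5, -2.0], 0.3))
```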

The predict method

A compulsory predict method has the form

MMI.predict(model::SomeSupervisedModel, fitresult, Xnew) -> yhat

Here Xnew will have the same form as the X passed to fit.

Note that while Xnew generally consists of multiple observations (e.g., has multiple rows in the case of a table) it is assumed, in view of the i.i.d. assumption recalled above, that calling predict(..., Xnew) is equivalent to broadcasting some method predict_one(..., x) over the individual observations x in Xnew (a method implementing the probability distribution p(y | X) above).

Prediction types for deterministic responses

In the case of Deterministic models, yhat should have the same scitype as the y passed to fit (see above). Any CategoricalValue elements of yhat must have a pool == to the pool of the target y presented in training, even if not all levels appear in the training data or prediction itself. For example, in the case of a univariate target, such as scitype(y) <: AbstractVector{Multiclass{3}}, one requires MLJ.classes(yhat[i]) == MLJ.classes(y[j]) for all admissible i and j. (The method classes is described under Convenience methods below).

Unfortunately, code not written with the preservation of categorical levels in mind poses special problems. To help with this, MLJModelInterface provides three utility methods: int (for converting a CategoricalValue into an integer, the ordering of these integers being consistent with that of the pool), decoder (for constructing a callable object that decodes the integers back into CategoricalValue objects), and classes, for extracting all the CategoricalValue objects sharing the pool of a particular value. Refer to Convenience methods below for important details.

Note that a decoder created during fit may need to be bundled with fitresult to make it available to predict during re-encoding. So, for example, if the core algorithm being wrapped by fit expects a nominal target yint of type Vector{<:Integer} then a fit method may look something like this:

function MMI.fit(model::SomeSupervisedModel, verbosity, X, y)
    yint = MMI.int(y)                      # integer encoding of the target
    a_target_element = y[1]                # a CategoricalValue/String
    decode = MMI.decoder(a_target_element) # can be called on integers

    core_fitresult = SomePackage.fit(X, yint, verbosity=verbosity)

    fitresult = (decode, core_fitresult)
    cache = nothing
    report = nothing
    return fitresult, cache, report
end

while a corresponding deterministic predict operation might look like this:

function MMI.predict(model::SomeSupervisedModel, fitresult, Xnew)
    decode, core_fitresult = fitresult
    yhat = SomePackage.predict(core_fitresult, Xnew)
    return decode.(yhat)  # or decode(yhat) also works
end

For a concrete example, refer to the code for SVMClassifier.

Of course, if you are coding a learning algorithm from scratch, rather than wrapping an existing one, these extra measures may be unnecessary.

Prediction types for probabilistic responses

In the case of Probabilistic models with univariate targets, yhat must be an AbstractVector whose elements are distributions (one distribution per row of Xnew).

Presently, a distribution is any object d for which MMI.isdistribution(::d) = true, which is the case for objects of type Distributions.Sampleable.

Use the distribution MMI.UnivariateFinite for Probabilistic models predicting a target with Finite scitype (classifiers). In this case the eltype of the training target y will be a CategoricalValue.

For efficiency, one should not construct UnivariateFinite instances one at a time. Rather, once a probability vector or matrix is known, construct an instance of UnivariateFiniteVector <: AbstractArray{<:UnivariateFinite,1} to return. Both UnivariateFinite and UnivariateFiniteVector objects are constructed using the single UnivariateFinite function.

For example, suppose the target y arrives as a subsample of some ybig and is missing some classes:

ybig = categorical([:a, :b, :a, :a, :b, :a, :rare, :a, :b])
y = ybig[1:6]

Your fit method has bundled the first element of y with the fitresult to make it available to predict for purposes of tracking the complete pool of classes. Let's call this an_element = y[1]. Then, supposing the corresponding probabilities of the observed classes [:a, :b] are in an n x 2 matrix probs (where n is the number of rows of Xnew), you return

yhat = UnivariateFinite([:a, :b], probs, pool=an_element)

This object automatically assigns zero-probability to the unseen class :rare (i.e., pdf.(yhat, :rare) works and returns a zero vector). If you would like to assign :rare non-zero probabilities, simply add it to the first vector (the support) and supply a larger probs matrix.

If instead of raw labels [:a, :b] you have the corresponding CategoricalValue objects (from, e.g., filter(cv->cv in unique(y), classes(y))) then you can use these instead and drop the pool specifier.

In a binary classification problem it suffices to specify a single vector of probabilities, provided you specify augment=true, as in the following example. Note carefully that these probabilities are associated with the last (second) class you specify in the constructor:

y = categorical([:TRUE, :FALSE, :FALSE, :TRUE, :TRUE])
an_element = y[1]
probs = rand(10)
yhat = UnivariateFinite([:FALSE, :TRUE], probs, augment=true, pool=an_element)

The constructor has a lot of options, including passing a dictionary instead of vectors. See UnivariateFinite for details.

See LinearBinaryClassifier for an example of a Probabilistic classifier implementation.

Important note on binary classifiers. There is no "Binary" scitype distinct from Multiclass{2} or OrderedFactor{2}; Binary is just an alias for Union{Multiclass{2},OrderedFactor{2}}. The target_scitype of a binary classifier will generally be AbstractVector{<:Binary} and, according to the MLJ scitype convention, elements of y have type CategoricalValue, and not Bool. See BinaryClassifier for an example.

The predict_joint method


The following API is experimental. It is subject to breaking changes during minor or major releases without warning.

MMI.predict_joint(model::SomeSupervisedModel, fitresult, Xnew) -> yhat

Any Probabilistic model type SomeModel may optionally implement a predict_joint method, which has the same signature as predict, but whose predictions are a single distribution (rather than a vector of per-observation distributions).

Specifically, the output yhat of predict_joint should be an instance of Distributions.Sampleable{<:Multivariate,V}, where scitype(V) = target_scitype(SomeModel) and samples have length n, where n is the number of observations in Xnew.

If a new model type subtypes JointProbabilistic <: Probabilistic then implementation of predict_joint is compulsory.

Trait declarations

Two trait functions allow the implementer to restrict the types of data X, y and Xnew discussed above. The MLJ task interface uses these traits for data type checks but also for model search. If they are omitted (and your model is registered) then a general user may attempt to use your model with inappropriately typed data.

The trait functions input_scitype and target_scitype take scientific data types as values. We assume here familiarity with ScientificTypes.jl (see Getting Started for the basics).

For example, to ensure that the X presented to the DecisionTreeClassifier fit method is a table whose columns all have Continuous element type (and hence AbstractFloat machine type), one declares

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)

or, equivalently,

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Continuous)

If, instead, columns were allowed to have either: (i) a mixture of Continuous and Missing values, or (ii) Count (i.e., integer) values, then the declaration would be

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Union{Continuous,Missing},Count)

Similarly, to ensure the target is an AbstractVector whose elements have Finite scitype (and hence CategoricalValue machine type) we declare

MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:Finite}
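The declarations above can be exercised directly, as in the following sketch. It assumes MLJModelInterface is installed; SomeClassifier is a hypothetical model used in place of DecisionTreeClassifier so that the example is self-contained.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

mutable struct SomeClassifier <: MMI.Probabilistic   # hypothetical model
    max_depth::Int
end

# Inputs: any table with Continuous columns. Target: any vector of Finite elements.
MMI.input_scitype(::Type{<:SomeClassifier}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:SomeClassifier}) = AbstractVector{<:MMI.Finite}

# Traits are queried on the model type.
declared_input  = MMI.input_scitype(SomeClassifier)
declared_target = MMI.target_scitype(SomeClassifier)

# A vector scitype with two-class elements satisfies the target declaration.
binary_target_ok = AbstractVector{MMI.Finite{2}} <: declared_target
```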

Multivariate targets

The above remarks continue to hold unchanged for the case of multivariate targets. For example, if we declare

target_scitype(SomeSupervisedModel) = Table(Continuous)

then this constrains the target to be any table whose columns have Continuous element scitype (i.e., AbstractFloat), while

target_scitype(SomeSupervisedModel) = Table(Continuous, Finite{2})

restricts to tables with continuous or binary (ordered or unordered) columns.

For predicting variable length sequences of, say, binary values (CategoricalValues with some common size-two pool) we declare

target_scitype(SomeSupervisedModel) = AbstractVector{<:NTuple{<:Finite{2}}}

The trait functions controlling the form of data are summarized as follows:

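| method         | return type | declarable return values | fallback value |
|----------------|-------------|--------------------------|----------------|
| input_scitype  | Type        | some scientific type     | Unknown        |
| target_scitype | Type        | some scientific type     | Unknown        |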

Additional trait functions tell MLJ's @load macro how to find your model if it is registered, and provide other self-explanatory metadata about the model:

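| method           | return type | declarable return values | fallback value |
|------------------|-------------|--------------------------|----------------|
| is_pure_julia    | Bool        | true or false            | false          |
| supports_weights | Bool        | true or false            | false          |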

Here is the complete list of trait function declarations for DecisionTreeClassifier, whose core algorithms are provided by DecisionTree.jl, but whose interface actually lives at MLJDecisionTreeInterface.jl.

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:MMI.Finite}
MMI.load_path(::Type{<:DecisionTreeClassifier}) = "MLJDecisionTreeInterface.DecisionTreeClassifier"
MMI.package_name(::Type{<:DecisionTreeClassifier}) = "DecisionTree"
MMI.package_uuid(::Type{<:DecisionTreeClassifier}) = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
MMI.package_url(::Type{<:DecisionTreeClassifier}) = ""
MMI.is_pure_julia(::Type{<:DecisionTreeClassifier}) = true

Alternatively, these traits can be declared using the MMI.metadata_pkg and MMI.metadata_model helper functions, whose keyword arguments are documented below.



Important. Do not omit the load_path specification. If unsure what it should be, post an issue at MLJ.
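A hedged sketch of declaring traits with the two helper functions follows, assuming MLJModelInterface is installed. The model SomeClassifier and every metadata value (package name, UUID, URL, license, load path) are placeholders for a hypothetical package, not real registry data.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

mutable struct SomeClassifier <: MMI.Probabilistic   # hypothetical model
    max_depth::Int
end

# Package-level metadata; all values are placeholders.
MMI.metadata_pkg(SomeClassifier,
    package_name="SomePackage",
    package_uuid="00000000-0000-0000-0000-000000000000",
    package_url="https://example.com/SomePackage.jl",
    is_pure_julia=true,
    package_license="MIT",
    is_wrapper=false)

# Model-level metadata, including the load_path that must not be omitted.
MMI.metadata_model(SomeClassifier,
    input_scitype=MMI.Table(MMI.Continuous),
    target_scitype=AbstractVector{<:MMI.Finite},
    supports_weights=false,
    load_path="SomePackage.SomeClassifier")
```

The traits can afterwards be queried in the usual way, for example with MMI.load_path(SomeClassifier).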

metadata_pkg(T; args...)

Helper function to write the metadata for a package providing model T. Use it with broadcasting to define the metadata of the package providing a series of models.


  • package_name="unknown" : package name
  • package_uuid="unknown" : package uuid
  • package_url="unknown" : package url
  • is_pure_julia=missing : whether the package is pure julia
  • package_license="unknown": package license
  • is_wrapper=false : whether the package is a wrapper


metadata_pkg.((KNNRegressor, KNNClassifier),

metadata_model(T; args...)

Helper function to write the metadata for a model T.


  • input_scitype=Unknown : allowed scientific type of the input data
  • target_scitype=Unknown: allowed sc. type of the target (supervised)
  • output_scitype=Unknown: allowed sc. type of the transformed data (unsupervised)
  • supports_weights=false : whether the model supports sample weights
  • docstring="" : short description of the model
  • load_path="" : where the model is (usually PackageName.ModelName)


    docstring="K-Nearest Neighbors classifier: ...",

Iterative models and the update method

An update method may be optionally overloaded to enable a call by MLJ to retrain a model (on the same training data) to avoid repeating computations unnecessarily.

MMI.update(model::SomeSupervisedModel, verbosity, old_fitresult, old_cache, X, y) -> fitresult, cache, report
MMI.update(model::SomeSupervisedModel, verbosity, old_fitresult, old_cache, X, y, w=nothing) -> fitresult, cache, report

Here the second variation applies if SomeSupervisedModel supports sample weights.

If an MLJ Machine is being fit! and it is not the first time, then update is called instead of fit, unless the machine fit! has been called with a new rows keyword argument. However, MLJModelInterface defines a fallback for update which just calls fit. For context, see MLJ Internals.

Learning networks wrapped as models constitute one use-case (see Composing Models): one would like each component model to be retrained only when hyperparameter changes "upstream" make this necessary. In this case MLJ provides a fallback (specifically, the fallback is for any subtype of SupervisedNetwork = Union{DeterministicNetwork,ProbabilisticNetwork}). A second, more generally relevant use-case is iterative models, where a call to increase the number of iterations restarts the iterative procedure only if other hyperparameters have also changed. (A useful method for inspecting model changes in such cases is MLJModelInterface.is_same_except.) For an example, see the MLJ ensemble code.

A third use-case is to avoid repeating time-consuming preprocessing of X and y required by some models.

In the event that the argument fitresult (returned by a preceding call to fit) is not sufficient for performing an update, the author can arrange for fit to output in its cache return value any additional information required (for example, pre-processed versions of X and y), as this is also passed as an argument to the update method.
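The iterative use-case can be sketched with a toy model, assuming MLJModelInterface is installed. The model IterativeMean and the helper iterate_mean are hypothetical: the model ignores X and simply approaches the mean of y by repeated damped steps, which is enough to show how cache and is_same_except let update continue a previous run instead of restarting.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical iterative model: approaches mean(y) by repeated damped steps.
mutable struct IterativeMean <: MMI.Deterministic
    n::Int          # number of iterations
    rate::Float64   # step size
end

# Run `n` further steps starting from the estimate `theta`.
function iterate_mean(theta, target, rate, n)
    for _ in 1:n
        theta += rate * (target - theta)
    end
    return theta
end

function MMI.fit(model::IterativeMean, verbosity, X, y)
    target = sum(y) / length(y)
    theta = iterate_mean(0.0, target, model.rate, model.n)
    cache = (old_model=deepcopy(model), target=target)   # what update will need
    return theta, cache, NamedTuple()
end

function MMI.update(model::IterativeMean, verbosity, old_fitresult, old_cache, X, y)
    old_model = old_cache.old_model
    # Continue only if nothing but `n` changed and `n` did not decrease.
    if MMI.is_same_except(model, old_model, :n) && model.n >= old_model.n
        extra = model.n - old_model.n
        theta = iterate_mean(old_fitresult, old_cache.target, model.rate, extra)
        cache = (old_model=deepcopy(model), target=old_cache.target)
        return theta, cache, NamedTuple()
    end
    return MMI.fit(model, verbosity, X, y)   # otherwise restart from scratch
end

X, y = nothing, [1.0, 2.0, 3.0, 6.0]         # X is unused by this toy model
model = IterativeMean(5, 0.5)
theta5, cache5, _ = MMI.fit(model, 0, X, y)
model.n = 8
theta8, _, _ = MMI.update(model, 0, theta5, cache5, X, y)   # 3 extra steps only
theta8_scratch, _, _ = MMI.fit(model, 0, X, y)              # 8 steps from zero
```

Storing a deepcopy of the model in cache, rather than the model itself, keeps the record of the old hyperparameters safe from later mutation by the user.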

Implementing a data front-end


It is suggested that packages implementing MLJ's model API, that later implement a data front-end, should tag their changes in a breaking release. (The changes will not break use of models for the ordinary MLJ user, who interacts with models exclusively through the machine interface. However, it will break usage for some external packages that have chosen to depend directly on the model API.)

MLJModelInterface.reformat(model, args...) -> data
MLJModelInterface.selectrows(::Model, I, data...) -> sampled_data

Models optionally overload reformat to define transformations of user-supplied data into some model-specific representation (e.g., from a table to a matrix). Computational overheads associated with multiple fit!/predict/transform calls (on MLJ machines) are then avoided, when memory resources allow. The fallback returns args (no transformation).

The selectrows(::Model, I, data...) method is overloaded to specify how the model-specific data is to be subsampled, for some observation indices I (a colon, :, or instance of AbstractVector{<:Integer}). In this way, implementing a data front-end also allows more efficient resampling of data (in user calls to evaluate!).

After detailing formal requirements for implementing a data front-end, we give a Sample implementation. A simple implementation also appears in the EvoTrees.jl package.

Here "user-supplied data" is what the MLJ user supplies when constructing a machine, as in machine(model, args...), which coincides with the arguments expected by fit(model, verbosity, args...) when reformat is not overloaded.

Implementing a reformat data front-end is permitted for any Model subtype, except for subtypes of Static. Here is a complete list of responsibilities for such an implementation, for some model::SomeModelType (a sample implementation follows after):

  • A reformat(model::SomeModelType, args...) -> data method must be implemented for each form of args... appearing in a valid machine construction machine(model, args...) (there will be one for each possible signature of fit(::SomeModelType, ...)).

  • Additionally, if not included above, there must be a single argument form of reformat, reformat(model::SomeModelType, arg) -> (data,), serving as a data front-end for operations like predict. It must always hold that reformat(model, args...)[1] == reformat(model, args[1])[1].

Important. reformat(model::SomeModelType, args...) must always return a tuple of the same length as args, even if this is one.

  • fit(model::SomeModelType, verbosity, data...) should be implemented as if data is the output of reformat(model, args...), where args is the data an MLJ user has bound to model in some machine. The same applies to any overloading of update.

  • Each implemented operation, such as predict and transform (but excluding inverse_transform), must be defined as if its data arguments are reformatted versions of user-supplied data. For example, in the supervised case, data_new in predict(model::SomeModelType, fitresult, data_new) is reformat(model, Xnew), where Xnew is the data provided by the MLJ user in a call predict(mach, Xnew) (mach.model == model).

  • To specify how the model-specific representation of data is to be resampled, implement selectrows(model::SomeModelType, I, data...) -> resampled_data for each overloading of reformat(model::SomeModelType, args...) -> data above. Here I is an arbitrary abstract integer vector or : (type Colon).

Important. selectrows(model::SomeModelType, I, args...) must always return a tuple of the same length as args, even if this is one.

The fallback for selectrows is described at selectrows.

Sample implementation

Suppose a supervised model type SomeSupervised supports sample weights, leading to two different fit signatures, and that it has a single operation predict:

fit(model::SomeSupervised, verbosity, X, y)
fit(model::SomeSupervised, verbosity, X, y, w)

predict(model::SomeSupervised, fitresult, Xnew)

Without a data front-end implemented, suppose X is expected to be a table and y a vector, but suppose the core algorithm always converts X to a matrix with features as rows (features corresponding to columns in the table). Then a new data front-end might look like this:

const MMI = MLJModelInterface

# for fit:
MMI.reformat(::SomeSupervised, X, y) = (MMI.matrix(X, transpose=true), y)
MMI.reformat(::SomeSupervised, X, y, w) = (MMI.matrix(X, transpose=true), y, w)
MMI.selectrows(::SomeSupervised, I, Xmatrix, y) =
    (view(Xmatrix, :, I), view(y, I))
MMI.selectrows(::SomeSupervised, I, Xmatrix, y, w) =
    (view(Xmatrix, :, I), view(y, I), view(w, I))

# for predict:
MMI.reformat(::SomeSupervised, X) = (MMI.matrix(X, transpose=true),)
MMI.selectrows(::SomeSupervised, I, Xmatrix) = (view(Xmatrix, :, I),)

With these additions, fit and predict are refactored, so that X and Xnew represent matrices with features as rows.
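The order of calls that a data front-end is designed for can be sketched in isolation. The following is an illustration under stated assumptions, not the way MLJ machines are implemented: the model type TransposedLS is hypothetical, and the user is assumed to supply X as a plain matrix with observations as rows (avoiding table conversions, which need MLJBase). Data is reformatted once, subsampled with selectrows, and fit and predict only ever see the model-specific representation:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical least-squares model expecting features as rows.
mutable struct TransposedLS <: MMI.Deterministic end

# Data front-end: one `reformat`/`selectrows` pair per call signature.
MMI.reformat(::TransposedLS, X::AbstractMatrix, y) = (permutedims(X), y)
MMI.reformat(::TransposedLS, X::AbstractMatrix) = (permutedims(X),)
MMI.selectrows(::TransposedLS, I, Xt, y) = (view(Xt, :, I), view(y, I))
MMI.selectrows(::TransposedLS, I, Xt) = (view(Xt, :, I),)

# `fit` and `predict` are written for the reformatted data only.
function MMI.fit(::TransposedLS, verbosity, Xt, y)
    coefs = permutedims(Xt) \ y        # least-squares solution
    return coefs, nothing, NamedTuple()
end
MMI.predict(::TransposedLS, coefs, Xt) = Xt' * coefs

# Simulating by hand the calls made on behalf of the user:
X = [1.0 2.0; 3.0 4.0; 5.0 6.0; 7.0 9.0]   # 4 observations, 2 features
y = X * [2.0, -1.0]
model = TransposedLS()

data = MMI.reformat(model, X, y)                 # done once
train = MMI.selectrows(model, 1:3, data...)      # cheap resampling
coefs, _, _ = MMI.fit(model, 0, train...)

data_new = MMI.reformat(model, X)                # single-argument form
test = MMI.selectrows(model, 4:4, data_new...)
yhat = MMI.predict(model, coefs, test...)
```

Note that every reformat and selectrows method returns a tuple, including the single-argument forms, which is what allows the results to be splatted into fit and predict.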

Supervised models with a transform method

A supervised model may optionally implement a transform method, whose signature is the same as predict. In that case the implementation should define a value for the output_scitype trait. A declaration

output_scitype(::Type{<:SomeSupervisedModel}) = T

is an assurance that scitype(transform(model, fitresult, Xnew)) <: T always holds, for any model of type SomeSupervisedModel.

A use-case for a transform method for a supervised model is a neural network that learns feature embeddings for categorical input features as part of overall training. Such a model becomes a transformer that other supervised models can use to transform the categorical features (instead of applying the higher-dimensional one-hot encoding representations).

Models that learn a probability distribution


The following API is experimental. It is subject to breaking changes during minor or major releases without warning. Models implementing this interface will not work with MLJBase versions earlier than 0.17.5.

Models that fit a probability distribution to some data should be regarded as Probabilistic <: Supervised models with target y = data and X = nothing.

The predict method should return a single distribution.

A working implementation of a model that fits a UnivariateFinite distribution to some categorical data using Laplace smoothing controlled by a hyper-parameter alpha is given here.



Serialization


The following API is experimental. It is subject to breaking changes during minor or major releases without warning.

The MLJ user can serialize and deserialize a machine, which means serializing/deserializing:

  • the associated Model object (storing hyperparameters)
  • the fitresult (learned parameters)
  • the report generated during training

These are bundled into a single file or IO stream specified by the user using the package JLSO. There are two scenarios in which a new MLJ model API implementation will want to overload two additional methods save and restore to support serialization:

  1. The algorithm-providing package already has its own serialization format for learned parameters and/or hyper-parameters, which users may want to access. In that case the implementation overloads save.
  2. The fitresult is not a sufficiently persistent object; for example, it is a pointer passed from wrapped C code. In that case the implementation overloads save and restore.

If case 2 applies, then case 1 presumably applies as well, for otherwise MLJ serialization is probably not going to be possible without changes to the algorithm-providing package. An example is given below.

Note that in case 1, MLJ will continue to create its own self-contained serialization of the machine. Below filename refers to the corresponding serialization file name, as specified by the user, but with any final extension (e.g., ".jlso", ".gz") removed. If the user has alternatively specified an IO object for serialization, then filename is a randomly generated numeric string.

The save method

MMI.save(filename, model::SomeModel, fitresult; kwargs...) -> serializable_fitresult

Implement this method to serialize using a format specific to models of type SomeModel. The fitresult is the first return value of fit for such model types; kwargs is a list of keyword arguments specified by the user and understood to relate to some model-specific serialization (cannot be format=... or compression=...). The value of serializable_fitresult should be a persistent representation of fitresult, from which a correct and valid fitresult can be reconstructed using restore (see below).

The fallback of save performs no action and returns fitresult.

The restore method

MMI.restore(filename, model::SomeModel, serializable_fitresult) -> fitresult

Implement this method to reconstruct a fitresult (as returned by fit) from a persistent representation constructed using save, as described above.

The fallback of restore returns serializable_fitresult.


Below is an example drawn from MLJ's XGBoost wrapper. In this example the fitresult returned by fit is a tuple (booster, a_target_element) where booster is the XGBoost.jl object storing the learned parameters (essentially a pointer to some object created by C code) and a_target_element is an ordinary CategoricalValue used to track the target classes (a persistent object, requiring no special treatment).

function MLJModelInterface.save(filename,
                                ::XGBoostClassifier,
                                fitresult;
                                kwargs...)
    booster, a_target_element = fitresult

    xgb_filename = string(filename, ".xgboost.model")
    XGBoost.save(booster, xgb_filename)
    persistent_booster = read(xgb_filename)
    @info "Additional XGBoost serialization file \"$xgb_filename\" generated. "
    return (persistent_booster, a_target_element)
end

function MLJModelInterface.restore(filename,
                                   ::XGBoostClassifier,
                                   serializable_fitresult)
    persistent_booster, a_target_element = serializable_fitresult

    xgb_filename = string(filename, ".tmp")
    open(xgb_filename, "w") do file
        write(file, persistent_booster)
    end
    booster = XGBoost.Booster(model_file=xgb_filename)
    fitresult = (booster, a_target_element)
    return fitresult
end

Unsupervised models

Unsupervised models implement the MLJ model interface in a very similar fashion. The main differences are:

  • The fit method has only one training argument X, as in MLJModelInterface.fit(model, verbosity, X). However, it has the same return value (fitresult, cache, report). An update method (e.g., for iterative models) can be optionally implemented in the same way.

  • A transform method is compulsory and has the same signature as predict, as in MLJModelInterface.transform(model, fitresult, Xnew).

  • Instead of defining the target_scitype trait, one declares an output_scitype trait (see above for the meaning).

  • An inverse_transform can be optionally implemented. The signature is the same as transform, as in MLJModelInterface.inverse_transform(model, fitresult, Xout), which:

    • must make sense for any Xout for which scitype(Xout) <: output_scitype(SomeUnsupervisedModel) (see below); and

    • must return an object Xin satisfying scitype(Xin) <: input_scitype(SomeUnsupervisedModel).

  • A predict method may be optionally implemented, and has the same signature as for supervised models, as in MLJModelInterface.predict(model, fitresult, Xnew). A use-case is clustering algorithms that predict labels and transform new input features into a space of lower-dimension. See Transformers that also predict for an example.
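As a minimal sketch of the points above, here is a toy transformer that centers a numeric vector. The type name Centerer is hypothetical, trait declarations such as output_scitype are omitted for brevity, and the input is assumed to be a plain vector of reals rather than a table:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical transformer that subtracts the training mean.
mutable struct Centerer <: MMI.Unsupervised end

# `fit` takes a single training argument and returns the usual triple.
function MMI.fit(::Centerer, verbosity, X::AbstractVector{<:Real})
    fitresult = sum(X) / length(X)     # the learned parameter: the mean
    cache = nothing
    report = NamedTuple()
    return fitresult, cache, report
end

# Compulsory for unsupervised models; same signature as `predict`.
MMI.transform(::Centerer, mu, Xnew) = Xnew .- mu

# Optional; undoes `transform`.
MMI.inverse_transform(::Centerer, mu, Xout) = Xout .+ mu
```

With this in place, MMI.inverse_transform(model, fitresult, MMI.transform(model, fitresult, X)) should recover X up to floating point error.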

Convenience methods

table(columntable; prototype=nothing)

Convert a named tuple of vectors or tuples columntable, into a table of the "preferred sink type" of prototype. This is often the type of prototype itself, when prototype is a sink; see the Tables.jl documentation. If prototype is not specified, then a named tuple of vectors is returned.

table(A::AbstractMatrix; names=nothing, prototype=nothing)

Wrap an abstract matrix A as a Tables.jl compatible table with the specified column names (a tuple of symbols). If names are not specified, names=(:x1, :x2, ..., :xn) is used, where n=size(A, 2).

If a prototype is specified, then the matrix is materialized as a table of the preferred sink type of prototype, rather than wrapped. Note that if prototype is not specified, then matrix(table(A)) is essentially a no-op.

matrix(X; transpose=false)

If X <: AbstractMatrix, return X or permutedims(X) if transpose=true. If X is a Tables.jl compatible table source, convert X into a Matrix.
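The round trip between matrices and tables can be tried out as follows. This is a sketch that assumes MLJBase is installed, since MLJModelInterface alone does not provide working implementations of these methods:

```julia
using MLJBase   # supplies the functionality behind `table` and `matrix`

A = [1.0 2.0; 3.0 4.0; 5.0 6.0]           # 3 observations, 2 features

t = MLJBase.table(A)                       # wrap as a table; default column names
B = MLJBase.matrix(t)                      # convert back to a matrix
At = MLJBase.matrix(A, transpose=true)     # matrix input, features as rows
```

Here B is expected to equal A, and At to equal permutedims(A).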


int(x; type=nothing)

The positional integer of the CategoricalString or CategoricalValue x, in the ordering defined by the pool of x. The type of int(x) is the reference type of x.

Not to be confused with x.ref, which is unchanged by reordering of the pool of x, but has the same type.


Broadcasted versions of int.

julia> v = categorical([:c, :b, :c, :a])
julia> levels(v)
3-element Array{Symbol,1}:
 :a
 :b
 :c
julia> int(v)
4-element Array{UInt32,1}:
 0x00000003
 0x00000002
 0x00000003
 0x00000001
See also: decoder.


A list of categorical elements in the common pool of classes used to construct d.

v = categorical(["yes", "maybe", "no", "yes"])
d = UnivariateFinite(v[1:2], [0.3, 0.7])
classes(d) # CategoricalArray{String,1,UInt32}["maybe", "no", "yes"]

All the categorical elements with the same pool as x (including x), returned as a list, with an ordering consistent with the pool. Here x has CategoricalValue or CategoricalString type, and classes(x) is a vector of the same eltype. Note that x in classes(x) is always true.

Not to be confused with levels(x.pool). See the example below.

julia> v = categorical([:c, :b, :c, :a])
4-element CategoricalArrays.CategoricalArray{Symbol,1,UInt32}:
 :c
 :b
 :c
 :a

julia> levels(v)
3-element Array{Symbol,1}:
 :a
 :b
 :c

julia> x = v[4]
CategoricalArrays.CategoricalValue{Symbol,UInt32} :a

julia> classes(x)
3-element CategoricalArrays.CategoricalArray{Symbol,1,UInt32}:
 :a
 :b
 :c

julia> levels(x.pool)
3-element Array{Symbol,1}:
 :a
 :b
 :c

d = decoder(x)

A callable object for decoding the integer representation of a CategoricalString or CategoricalValue sharing the same pool as x. (Here x is of one of these two types.) Specifically, one has d(int(y)) == y for all y in classes(x). One can also call d on integer arrays, in which case d is broadcast over all elements.

julia> v = categorical([:c, :b, :c, :a])
julia> int(v)
4-element Array{UInt32,1}:
 0x00000003
 0x00000002
 0x00000003
 0x00000001
julia> d = decoder(v[3])
julia> d(int(v)) == v
true

Warning: It is not true that int(d(u)) == u always holds.

See also: int, classes.

select(X, r, c)

Select element(s) of a table or matrix at row(s) r and column(s) c. An object of the sink type of X (or a matrix) is returned unless c is a single integer or symbol. In that case a vector is returned, unless r is a single integer, in which case a single element is returned.

See also: selectrows, selectcols.

selectrows(X::AbstractNode, r)

Returns a Node object N such that N() = selectrows(X(), r) (and N(rows=s) = selectrows(X(rows=s), r)).

selectrows(X, r)

Select single or multiple rows from a table, abstract vector or matrix X. If X is tabular, the object returned is a table of the preferred sink type of typeof(X), even if only a single row is selected.

If the object is not a table, abstract vector or matrix, X is returned and r is ignored.

MLJModelInterface.selectrows(::Model, I, data...) -> sampled_data

A model overloads selectrows whenever it buys into the optional reformat front-end for data preprocessing. See reformat for details. The fallback assumes data is a tuple and calls selectrows(X, I) for each X in data, returning the results in a new tuple of the same length. This call makes sense when X is a table, abstract vector or abstract matrix. In the last two cases, a new object and not a view is returned.

selectcols(X::AbstractNode, c)

Returns a Node object N such that N() = selectcols(X(), c).

selectcols(X, c)

Select single or multiple columns from a matrix or table X. If c is an abstract vector of integers or symbols, then the object returned is a table of the preferred sink type of typeof(X). If c is a single integer or symbol, then an AbstractVector is returned.


Construct a discrete univariate distribution whose finite support is the elements of the vector support, and whose corresponding probabilities are elements of the vector probs. Alternatively, construct an abstract array of UnivariateFinite distributions by choosing probs to be an array of one higher dimension than the array generated.

Unless pool is specified, support should have type AbstractVector{<:CategoricalValue} and all elements are assumed to share the same categorical pool, which may be larger than support.

Important. All levels of the common pool have associated probabilities, not just those in the specified support. However, these probabilities are always zero (see example below).

If probs is a matrix, it should have a column for each class in support (or one less, if augment=true). More generally, probs will be an array whose size is of the form (n1, n2, ..., nk, c), where c = length(support) (or one less, if augment=true) and the constructor then returns an array of size (n1, n2, ..., nk).

using CategoricalArrays
v = categorical([:x, :x, :y, :x, :z])

julia> UnivariateFinite(classes(v), [0.2, 0.3, 0.5])
UnivariateFinite{Multiclass{3}}(x=>0.2, y=>0.3, z=>0.5)

julia> d = UnivariateFinite([v[1], v[end]], [0.1, 0.9])
UnivariateFinite{Multiclass{3}}(x=>0.1, z=>0.9)

julia> rand(d, 3)
3-element Array{Any,1}:
 CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
 CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
 CategoricalArrays.CategoricalValue{Symbol,UInt32} :z

julia> levels(d)
3-element Array{Symbol,1}:
 :x
 :y
 :z

julia> pdf(d, :y)
0.0

Specifying a pool

Alternatively, support may be a list of raw (non-categorical) elements if pool is:

  • some CategoricalArray, CategoricalValue or CategoricalPool, such that support is a subset of levels(pool)

  • missing, in which case a new categorical pool is created which has support as its only levels.

In the last case, specify ordered=true if the pool is to be considered ordered.

julia> UnivariateFinite([:x, :z], [0.1, 0.9], pool=missing, ordered=true)
UnivariateFinite{OrderedFactor{2}}(x=>0.1, z=>0.9)

julia> d = UnivariateFinite([:x, :z], [0.1, 0.9], pool=v) # v defined above
UnivariateFinite(x=>0.1, z=>0.9) (Multiclass{3} samples)

julia> pdf(d, :y) # allowed as `:y in levels(v)`
0.0

v = categorical([:x, :x, :y, :x, :z, :w])
probs = rand(100, 3)
probs = probs ./ sum(probs, dims=2)
julia> UnivariateFinite([:x, :y, :z], probs, pool=v)
100-element UnivariateFiniteVector{Multiclass{4},Symbol,UInt32,Float64}:
 UnivariateFinite{Multiclass{4}}(x=>0.194, y=>0.3, z=>0.505)
 UnivariateFinite{Multiclass{4}}(x=>0.727, y=>0.234, z=>0.0391)
 UnivariateFinite{Multiclass{4}}(x=>0.674, y=>0.00535, z=>0.321)
 UnivariateFinite{Multiclass{4}}(x=>0.292, y=>0.339, z=>0.369)

Probability augmentation

Unless augment=true, sums of elements along the last axis (row-sums in the case of a matrix) must be equal to one. If augment=true, such an array is instead created by inserting appropriate elements ahead of those provided. This means the provided probabilities are associated with the classes c2, c3, ..., cn.
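The arithmetic behind augmentation can be sketched in plain Julia. This illustrates the idea only and is not the library's implementation; the variable names are made up for the example. Given probabilities for classes c2 and c3, the probability of c1 is whatever is needed to make each row sum to one:

```julia
# Probabilities supplied for classes c2 and c3 only, one row per distribution.
partial = [0.3 0.5;
           0.1 0.2]

# Probability of the first class c1 is the remainder in each row.
first_column = 1 .- sum(partial, dims=2)

# Insert the new column ahead of those provided.
augmented = hcat(first_column, partial)
```

Each row of augmented now sums to one, with the first column holding the implied probabilities of c1.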

UnivariateFinite(prob_given_class; pool=nothing, ordered=false)

Construct a discrete univariate distribution whose finite support is the set of keys of the provided dictionary, prob_given_class, and whose values specify the corresponding probabilities.

The type requirements on the keys of the dictionary are the same as the elements of support given above with this exception: if non-categorical elements (raw labels) are used as keys, then pool=... must be specified and cannot be missing.

If the values (probabilities) are arrays instead of scalars, then an abstract array of UnivariateFinite elements is created, with the same size as the array.

recursive_getproperty(object, nested_name::Expr)

Call getproperty recursively on object to extract the value of some nested property, as in the following example:

julia> object = (X = (x = 1, y = 2), Y = 3)
julia> recursive_getproperty(object, :(X.y))
2

recursively_setproperty!(object, nested_name::Expr, value)

Set a nested property of an object to value, as in the following example:

julia> mutable struct Foo
           X
           Y
       end

julia> mutable struct Bar
           x
           y
       end

julia> object = Foo(Bar(1, 2), 3)
Foo(Bar(1, 2), 3)

julia> recursively_setproperty!(object, :(X.y), 42)
42

julia> object
Foo(Bar(1, 42), 3)

Where to place code implementing new models

Note that different packages can implement models having the same name without causing conflicts, although an MLJ user cannot simultaneously load two such models.

There are two options for making a new model implementation available to all MLJ users:

  1. Native implementations (preferred option). The implementation code lives in the same package that contains the learning algorithms implementing the interface. An example is EvoTrees.jl. In this case, it is sufficient to open an issue at MLJ requesting the package to be registered with MLJ. Registering a package allows the MLJ user to access its models' metadata and to selectively load them.

  2. Separate interface package. Implementation code lives in a separate interface package, which has the algorithm providing package as a dependency. See the template repository MLJExampleInterface.jl.

Additionally, one needs to ensure that the implementation code defines the package_name and load_path model traits appropriately, so that MLJ's @load macro can find the necessary code (see MLJModels/src for examples).

How to add models to the MLJ model registry?

The MLJ model registry is located in the MLJModels.jl repository. To add a model, follow these steps:

  • Ensure your model conforms to the interface defined above.

  • Raise an issue at MLJModels.jl and point out where the MLJ-interface implementation is, e.g. by providing a link to the code.

  • An administrator will then review your implementation and work with you to add the model to the registry.