Crabs with XGBoost

Download the notebook, the raw script, or the annotated script for this tutorial (right-click on the link and save). This example is inspired by this post showing how to use XGBoost.

First steps

Again, the crabs dataset is so common that there is a simple load function for it:

using MLJ
using StatsBase
using Random
using PyPlot
using CategoricalArrays
using PrettyPrinting
import DataFrames
using LossFunctions

X, y = @load_crabs
X = DataFrames.DataFrame(X)
@show size(X)
@show y[1:3]
first(X, 3) |> pretty
size(X) = (200, 5)
y[1:3] = CategoricalArrays.CategoricalValue{String,UInt32}["B", "B", "B"]
┌────────────┬────────────┬────────────┬────────────┬────────────┐
│ FL         │ RW         │ CL         │ CW         │ BD         │
│ Float64    │ Float64    │ Float64    │ Float64    │ Float64    │
│ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │
├────────────┼────────────┼────────────┼────────────┼────────────┤
│ 8.1        │ 6.7        │ 16.1       │ 19.0       │ 7.0        │
│ 8.8        │ 7.7        │ 18.1       │ 20.8       │ 7.4        │
│ 9.2        │ 7.8        │ 19.0       │ 22.4       │ 7.7        │
└────────────┴────────────┴────────────┴────────────┴────────────┘

It's a classification problem with the following classes:

levels(y) |> pprint
["B", "O"]

Note that the dataset is currently sorted by target; let's shuffle it to avoid the obvious issues this may cause.

Random.seed!(523)
perm = randperm(length(y))
X = X[perm,:]
y = y[perm];

It's not a very big dataset, so we will likely overfit it badly using something as sophisticated as XGBoost, but it will do for a demonstration.

train, test = partition(eachindex(y), 0.70, shuffle=true, rng=52)
@load XGBoostClassifier
xgb_model = XGBoostClassifier()
XGBoostClassifier(
    num_round = 100,
    booster = "gbtree",
    disable_default_eval_metric = 0,
    eta = 0.3,
    gamma = 0.0,
    max_depth = 6,
    min_child_weight = 1.0,
    max_delta_step = 0.0,
    subsample = 1.0,
    colsample_bytree = 1.0,
    colsample_bylevel = 1.0,
    lambda = 1.0,
    alpha = 0.0,
    tree_method = "auto",
    sketch_eps = 0.03,
    scale_pos_weight = 1.0,
    updater = "auto",
    refresh_leaf = 1,
    process_type = "default",
    grow_policy = "depthwise",
    max_leaves = 0,
    max_bin = 256,
    predictor = "cpu_predictor",
    sample_type = "uniform",
    normalize_type = "tree",
    rate_drop = 0.0,
    one_drop = 0,
    skip_drop = 0.0,
    feature_selector = "cyclic",
    top_k = 0,
    tweedie_variance_power = 1.5,
    objective = "automatic",
    base_score = 0.5,
    eval_metric = "mlogloss",
    seed = 0) @696

Let's check whether the training data is balanced; StatsBase.countmap is useful for that:

countmap(y[train]) |> pprint
Dict("B" => 70, "O" => 70)

which is pretty balanced. You could run the same check on the test set and on the full dataset, and the same comment would still hold.
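
For completeness, here is how that check could look on the test set and on the full dataset (a small sketch, not part of the original walkthrough; it only reuses countmap and pprint, which are already loaded):

countmap(y[test]) |> pprint   # class counts on the test set
countmap(y) |> pprint         # class counts on the full dataset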

XGBoost machine

Wrap a machine around an XGBoost model (XGB) and the data:

xgb  = XGBoostClassifier()
xgbm = machine(xgb, X, y)
Machine{XGBoostClassifier} @364 trained 0 times.
  args: 
    1:	Source @659 ⏎ `Table{AbstractArray{Continuous,1}}`
    2:	Source @860 ⏎ `AbstractArray{Multiclass{2},1}`

We will tune it by varying the number of rounds used, and generate a learning curve:

r = range(xgb, :num_round, lower=50, upper=500)
curve = learning_curve!(xgbm, range=r, resolution=50,
                        measure=HingeLoss())
(parameter_name = "num_round",
 parameter_scale = :linear,
 parameter_values = [50, 59, 68, 78, 87, 96, 105, 114, 123, 133, 142, 151, 160, 169, 179, 188, 197, 206, 215, 224, 234, 243, 252, 261, 270, 280, 289, 298, 307, 316, 326, 335, 344, 353, 362, 371, 381, 390, 399, 408, 417, 427, 436, 445, 454, 463, 472, 482, 491, 500],
 measurements = [0.3120960593223572, 0.3002642095088959, 0.28969401121139526, 0.28333112597465515, 0.27584463357925415, 0.2703813910484314, 0.26844921708106995, 0.26478245854377747, 0.26034364104270935, 0.2557550072669983, 0.2547350525856018, 0.2518833875656128, 0.24957433342933655, 0.24690766632556915, 0.24586206674575806, 0.2421659678220749, 0.24172064661979675, 0.23799732327461243, 0.2363770306110382, 0.23531433939933777, 0.23471292853355408, 0.23354791104793549, 0.23355154693126678, 0.23136228322982788, 0.23119983077049255, 0.23020349442958832, 0.22977910935878754, 0.22889092564582825, 0.2243921160697937, 0.22364655137062073, 0.22226248681545258, 0.2218494713306427, 0.22111089527606964, 0.22052700817584991, 0.22054848074913025, 0.21970060467720032, 0.21936337649822235, 0.2194405496120453, 0.21801182627677917, 0.21841095387935638, 0.21867485344409943, 0.21839170157909393, 0.21813097596168518, 0.21723081171512604, 0.21732456982135773, 0.21644990146160126, 0.21593405306339264, 0.21572065353393555, 0.21549715101718903, 0.21505625545978546],)

Let's have a look:

figure(figsize=(8,6))
plot(curve.parameter_values, curve.measurements)
xlabel("Number of rounds", fontsize=14)
ylabel("HingeLoss", fontsize=14)
xticks([10, 100, 200, 500], fontsize=12)
Hinge loss vs. number of rounds

So, in short, using more rounds helps. Let's arbitrarily fix it to 200.

xgb.num_round = 200;
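
As an aside, if you would rather pick the number of rounds programmatically than by eye, you could look up the value minimising the measured loss along the curve (a minimal sketch, not from the original listing; on this curve the loss keeps decreasing slowly, which is why the text above settles on a cheaper 200 rounds rather than the largest value):

best_idx = argmin(curve.measurements)     # index of the smallest measured loss
@show curve.parameter_values[best_idx];   # corresponding number of rounds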

More tuning (1)

Let's now tune the maximum depth of each tree and the minimum child weight in the boosting.

r1 = range(xgb, :max_depth, lower=3, upper=10)
r2 = range(xgb, :min_child_weight, lower=0, upper=5)

tm = TunedModel(model=xgb, tuning=Grid(resolution=8),
                resampling=CV(rng=11), ranges=[r1,r2],
                measure=cross_entropy)
mtm = machine(tm, X, y)
fit!(mtm, rows=train)
Machine{ProbabilisticTunedModel{Grid,…}} @786 trained 1 time.
  args: 
    1:	Source @035 ⏎ `Table{AbstractArray{Continuous,1}}`
    2:	Source @114 ⏎ `AbstractArray{Multiclass{2},1}`

Great. As always, we can investigate the tuning results using report and, for instance, plot a heatmap of the measurements:

r = report(mtm)

res = r.plotting

md = res.parameter_values[:,1]
mcw = res.parameter_values[:,2]

figure(figsize=(8,6))
tricontourf(md, mcw, res.measurements)

xlabel("Maximum tree depth", fontsize=14)
ylabel("Minimum child weight", fontsize=14)
xticks(3:2:10, fontsize=12)
yticks(fontsize=12)
Hyperparameter heatmap

Let's extract the optimal model and inspect its parameters:

xgb = fitted_params(mtm).best_model
@show xgb.max_depth
@show xgb.min_child_weight
xgb.max_depth = 3
xgb.min_child_weight = 2.857142857142857

More tuning (2)

Let's examine the effect of gamma:

xgbm = machine(xgb, X, y)
r = range(xgb, :gamma, lower=0, upper=10)
curve = learning_curve!(xgbm, range=r, resolution=30,
                        measure=cross_entropy);

It looks like the gamma parameter substantially affects model performance:

@show round(minimum(curve.measurements), sigdigits=3)
@show round(maximum(curve.measurements), sigdigits=3)
round(minimum(curve.measurements), sigdigits = 3) = 0.211
round(maximum(curve.measurements), sigdigits = 3) = 0.464
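
To see the full shape of the effect rather than just the extremes, you can reuse the same PyPlot recipe as for the number of rounds (a sketch, not part of the original listing):

figure(figsize=(8,6))
plot(curve.parameter_values, curve.measurements)   # cross entropy vs gamma
xlabel("Gamma", fontsize=14)
ylabel("Cross entropy", fontsize=14)
xticks(fontsize=12)
yticks(fontsize=12)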

More tuning (3)

Let's examine the effect of subsample and colsample_bytree:

r1 = range(xgb, :subsample, lower=0.6, upper=1.0)
r2 = range(xgb, :colsample_bytree, lower=0.6, upper=1.0)
tm = TunedModel(model=xgb, tuning=Grid(resolution=8),
                resampling=CV(rng=234), ranges=[r1,r2],
                measure=cross_entropy)
mtm = machine(tm, X, y)
fit!(mtm, rows=train)
Machine{ProbabilisticTunedModel{Grid,…}} @668 trained 1 time.
  args: 
    1:	Source @505 ⏎ `Table{AbstractArray{Continuous,1}}`
    2:	Source @574 ⏎ `AbstractArray{Multiclass{2},1}`

We then follow the usual procedure to visualise it:

r = report(mtm)

res = r.plotting

ss = res.parameter_values[:,1]
cbt = res.parameter_values[:,2]

figure(figsize=(8,6))
tricontourf(ss, cbt, res.measurements)

xlabel("Sub sample", fontsize=14)
ylabel("Col sample by tree", fontsize=14)
xticks(fontsize=12)
yticks(fontsize=12)
Hyperparameter heatmap

Let's retrieve the best model:

xgb = fitted_params(mtm).best_model
@show xgb.subsample
@show xgb.colsample_bytree
xgb.subsample = 0.6
xgb.colsample_bytree = 0.9428571428571428

We could continue with more fine-tuning, but given how small the dataset is, it doesn't make much sense. How does it fare on the test set?

ŷ = predict_mode(mtm, rows=test)
round(accuracy(ŷ, y[test]), sigdigits=3)
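
If you want more detail than a single accuracy figure, MLJ's built-in measures give a per-class breakdown (a short sketch, not part of the original tutorial):

confusion_matrix(ŷ, y[test])                             # per-class breakdown of the predictions
round(misclassification_rate(ŷ, y[test]), sigdigits=3)   # complement of the accuracy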