Boston with LightGBM

Download the notebook, the raw script, or the annotated script for this tutorial (right-click on the link and save). Main author: Yaqub Alwan (IQVIA).

Getting started

using MLJ
using PrettyPrinting
import DataFrames
import Statistics
using PyPlot
using StableRNGs

@load LGBMRegressor
LGBMRegressor(
    num_iterations = 10,
    learning_rate = 0.1,
    num_leaves = 31,
    max_depth = -1,
    tree_learner = "serial",
    histogram_pool_size = -1.0,
    min_data_in_leaf = 20,
    min_sum_hessian_in_leaf = 0.001,
    lambda_l1 = 0.0,
    lambda_l2 = 0.0,
    min_gain_to_split = 0.0,
    feature_fraction = 1.0,
    feature_fraction_seed = 2,
    bagging_fraction = 1.0,
    bagging_freq = 0,
    bagging_seed = 3,
    early_stopping_round = 0,
    max_bin = 255,
    init_score = "",
    objective = "regression",
    categorical_feature = Int64[],
    data_random_seed = 1,
    is_sparse = true,
    is_unbalance = false,
    metric = ["l2"],
    metric_freq = 1,
    is_training_metric = false,
    ndcg_at = [1, 2, 3, 4, 5],
    num_machines = 1,
    num_threads = 0,
    local_listen_port = 12400,
    time_out = 120,
    machine_list_file = "",
    save_binary = false,
    device_type = "cpu") @990

Let us try LightGBM out by doing a regression task on the Boston house prices dataset. This is a commonly used dataset, so there is a loader built into MLJ.

Here, the objective is to show how LightGBM can do better than a linear regressor with minimal effort (we come back to this comparison at the end).

We start out by taking a quick peek at the data itself and its statistical properties.

features, targets = @load_boston
features = DataFrames.DataFrame(features)
@show size(features)
@show targets[1:3]
first(features, 3) |> pretty
size(features) = (506, 12)
targets[1:3] = [24.0, 21.6, 34.7]
┌────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐
│ Crim       │ Zn         │ Indus      │ NOx        │ Rm         │ Age        │ Dis        │ Rad        │ Tax        │ PTRatio    │ Black      │ LStat      │
│ Float64    │ Float64    │ Float64    │ Float64    │ Float64    │ Float64    │ Float64    │ Float64    │ Float64    │ Float64    │ Float64    │ Float64    │
│ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┼────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ 0.00632    │ 18.0       │ 2.31       │ 0.538      │ 6.575      │ 65.2       │ 4.09       │ 1.0        │ 296.0      │ 15.3       │ 396.9      │ 4.98       │
│ 0.02731    │ 0.0        │ 7.07       │ 0.469      │ 6.421      │ 78.9       │ 4.9671     │ 2.0        │ 242.0      │ 17.8       │ 396.9      │ 9.14       │
│ 0.02729    │ 0.0        │ 7.07       │ 0.469      │ 7.185      │ 61.1       │ 4.9671     │ 2.0        │ 242.0      │ 17.8       │ 392.83     │ 4.03       │
└────────────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┘

We can also describe the dataframe.

DataFrames.describe(features)
12×8 DataFrame
│ Row │ variable │ mean     │ min     │ median  │ max     │ nunique │ nmissing │ eltype   │
│     │ Symbol   │ Float64  │ Float64 │ Float64 │ Float64 │ Nothing │ Nothing  │ DataType │
├─────┼──────────┼──────────┼─────────┼─────────┼─────────┼─────────┼──────────┼──────────┤
│ 1   │ Crim     │ 3.61352  │ 0.00632 │ 0.25651 │ 88.9762 │         │          │ Float64  │
│ 2   │ Zn       │ 11.3636  │ 0.0     │ 0.0     │ 100.0   │         │          │ Float64  │
│ 3   │ Indus    │ 11.1368  │ 0.46    │ 9.69    │ 27.74   │         │          │ Float64  │
│ 4   │ NOx      │ 0.554695 │ 0.385   │ 0.538   │ 0.871   │         │          │ Float64  │
│ 5   │ Rm       │ 6.28463  │ 3.561   │ 6.2085  │ 8.78    │         │          │ Float64  │
│ 6   │ Age      │ 68.5749  │ 2.9     │ 77.5    │ 100.0   │         │          │ Float64  │
│ 7   │ Dis      │ 3.79504  │ 1.1296  │ 3.20745 │ 12.1265 │         │          │ Float64  │
│ 8   │ Rad      │ 9.54941  │ 1.0     │ 5.0     │ 24.0    │         │          │ Float64  │
│ 9   │ Tax      │ 408.237  │ 187.0   │ 330.0   │ 711.0   │         │          │ Float64  │
│ 10  │ PTRatio  │ 18.4555  │ 12.6    │ 19.05   │ 22.0    │         │          │ Float64  │
│ 11  │ Black    │ 356.674  │ 0.32    │ 391.44  │ 396.9   │         │          │ Float64  │
│ 12  │ LStat    │ 12.6531  │ 1.73    │ 11.36   │ 37.97   │         │          │ Float64  │

Do the usual train/test partitioning. This is important so that we can estimate the generalisation error.

train, test = partition(eachindex(targets), 0.70, shuffle=true,
                        rng=StableRNG(52))
([52, 17, 330, 191, 265, 172, 19, 481, 94, 493, 490, 463, 430, 282, 169, 408, 437, 89, 234, 72, 130, 461, 135, 398, 287, 351, 50, 363, 182, 346, 99, 80, 203, 255, 60, 366, 459, 87, 71, 359, 390, 200, 21, 142, 483, 369, 304, 505, 377, 121, 29, 37, 20, 65, 258, 281, 133, 362, 116, 426, 27, 211, 489, 165, 55, 420, 256, 149, 98, 495, 6, 69, 53, 260, 434, 318, 128, 168, 160, 70, 131, 90, 181, 331, 187, 321, 157, 223, 343, 110, 427, 198, 144, 348, 44, 176, 350, 113, 31, 482, 36, 303, 190, 396, 189, 380, 152, 251, 415, 245, 45, 114, 237, 141, 132, 136, 78, 283, 353, 16, 117, 263, 344, 63, 262, 424, 306, 11, 312, 372, 277, 43, 123, 86, 308, 416, 57, 289, 497, 405, 284, 272, 103, 236, 32, 502, 356, 75, 414, 28, 302, 208, 407, 143, 423, 59, 368, 466, 115, 197, 334, 473, 174, 242, 428, 462, 317, 111, 469, 84, 1, 153, 393, 498, 127, 231, 247, 79, 421, 381, 222, 455, 374, 9, 66, 395, 224, 18, 345, 139, 300, 202, 34, 192, 339, 24, 432, 352, 299, 354, 243, 328, 313, 292, 218, 269, 503, 4, 285, 394, 35, 389, 47, 291, 41, 436, 413, 442, 148, 496, 384, 173, 166, 48, 409, 288, 367, 81, 225, 338, 42, 112, 323, 324, 171, 446, 315, 464, 229, 319, 364, 491, 73, 137, 422, 365, 227, 38, 105, 271, 207, 74, 92, 298, 209, 278, 453, 275, 379, 470, 204, 444, 22, 332, 39, 458, 457, 163, 250, 178, 460, 201, 196, 314, 186, 268, 106, 316, 228, 468, 164, 124, 270, 406, 297, 118, 307, 341, 97, 7, 376, 474, 257, 467, 371, 62, 274, 412, 438, 439, 445, 311, 327, 241, 399, 403, 215, 125, 375, 180, 183, 146, 122, 388, 161, 488, 431, 238, 401, 476, 235, 433, 177, 452, 220, 296, 101, 383, 61, 456, 261, 46, 440, 342, 309, 167, 232, 23, 417, 347, 109, 355, 322, 253, 216, 120, 244, 26, 397, 170, 162, 67, 230, 154], [233, 385, 30, 91, 140, 335, 25, 358, 254, 226, 279, 156, 301, 214, 492, 96, 76, 249, 259, 400, 280, 83, 449, 159, 305, 88, 337, 3, 477, 199, 185, 494, 184, 402, 425, 392, 360, 378, 295, 479, 104, 441, 108, 290, 58, 273, 56, 294, 193, 10, 325, 478, 386, 49, 326, 82, 54, 206, 349, 129, 194, 107, 480, 475, 472, 195, 77, 252, 485, 248, 499, 155, 15, 219, 188, 382, 100, 504, 239, 387, 221, 443, 447, 179, 210, 418, 64, 150, 410, 465, 404, 12, 102, 276, 5, 212, 119, 486, 8, 501, 51, 310, 320, 340, 429, 95, 240, 451, 213, 175, 500, 293, 138, 333, 448, 145, 357, 134, 471, 205, 68, 126, 329, 391, 14, 33, 336, 370, 264, 217, 411, 13, 266, 454, 435, 267, 147, 286, 40, 2, 158, 373, 419, 85, 487, 450, 506, 151, 246, 484, 93, 361])

Let us investigate some of the commonly tweaked LightGBM parameters. We start by looking at a learning curve for the number of boosting iterations.

lgb = LGBMRegressor() # initialise a model with default params
lgbm = machine(lgb, features[train, :], targets[train, 1])
boostrange = range(lgb, :num_iterations, lower=2, upper=500)
curve = learning_curve!(lgbm, resampling=CV(nfolds=5),
                        range=boostrange, resolution=100,
                        measure=rms)


figure(figsize=(8,6))
plot(curve.parameter_values, curve.measurements)
xlabel("Number of rounds", fontsize=14)
ylabel("RMSE", fontsize=14)

It looks like we don't need to go much past 100 boosting iterations.
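To sanity-check that reading of the curve, we can cross-validate a model with the number of iterations fixed near that point. The snippet below is a minimal sketch using MLJ's evaluate function; the exact score will depend on the folds and the random seed.

# Sketch: cross-validate with num_iterations fixed near the elbow of the curve above.
lgb_100 = LGBMRegressor(num_iterations=100)
evaluate(lgb_100, features[train, :], targets[train, 1],
         resampling=CV(nfolds=5), measure=rms)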

Since LightGBM is a gradient-based learning method, it also has a learning rate parameter, which controls the size of the gradient updates. Let us look at a learning curve for this parameter too.

lgb = LGBMRegressor() # initialise a model with default params
lgbm = machine(lgb, features[train, :], targets[train, 1])
learning_range = range(lgb, :learning_rate, lower=1e-3, upper=1, scale=:log)
curve = learning_curve!(lgbm, resampling=CV(nfolds=5),
                        range=learning_range, resolution=100,
                        measure=rms)


figure(figsize=(8,6))
plot(curve.parameter_values, curve.measurements)
xscale("log")
xlabel("Learning rate (log scale)", fontsize=14)
ylabel("RMSE", fontsize=14)

It seems like somewhere near 0.5 is a reasonable choice. Bear in mind that lower learning rates may require more boosting iterations to converge, so the default number of iterations might not be sufficient for convergence at the low end of this range; we leave exploring that as an exercise for the reader. We can still try to tune this parameter, however.
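To illustrate that interaction between learning rate and the number of boosting iterations, here is a minimal sketch (not part of the original tutorial) comparing the same low learning rate with few versus many iterations; we would expect the second model to score noticeably better.

# Sketch: a low learning rate with few vs. many boosting iterations.
# Exact numbers will vary; the point is only that the slow learner needs more rounds.
slow_few  = LGBMRegressor(learning_rate=0.01, num_iterations=10)
slow_many = LGBMRegressor(learning_rate=0.01, num_iterations=1000)
for model in (slow_few, slow_many)
    e = evaluate(model, features[train, :], targets[train, 1],
                 resampling=CV(nfolds=5), measure=rms)
    println(e.measurement)
end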

Finally, let us check the minimum number of data points required to form a leaf in an individual tree. This parameter controls the complexity of the individual trees, and too low a value might lead to overfitting.

lgb = LGBMRegressor() # initialise a model with default params
lgbm = machine(lgb, features[train, :], targets[train, 1])
Machine{LGBMRegressor} @403 trained 0 times.
  args: 
    1:	Source @853 ⏎ `Table{AbstractArray{Continuous,1}}`
    2:	Source @131 ⏎ `AbstractArray{Continuous,1}`

The dataset is small enough that the lower and upper bounds chosen below effectively determine how many leaves the trees can have.

leaf_range = range(lgb, :min_data_in_leaf, lower=1, upper=50)


curve = learning_curve!(lgbm, resampling=CV(nfolds=5),
                        range=leaf_range, resolution=50,
                        measure=rms)

figure(figsize=(8,6))
plot(curve.parameter_values, curve.measurements)
xlabel("Min data in leaf", fontsize=14)
ylabel("RMSE", fontsize=14)

It does not seem like there is a huge risk of overfitting, and lower values are better for this parameter.

Using the learning curves above, we can select some smallish ranges and jointly search for the best combination of these parameters via cross-validation.

r1 = range(lgb, :num_iterations, lower=50, upper=100)
r2 = range(lgb, :min_data_in_leaf, lower=2, upper=10)
r3 = range(lgb, :learning_rate, lower=1e-1, upper=1e0)
tm = TunedModel(model=lgb, tuning=Grid(resolution=5),
                resampling=CV(rng=StableRNG(123)), ranges=[r1,r2,r3],
                measure=rms)
mtm = machine(tm, features, targets)
fit!(mtm, rows=train);

Let's see what the best model parameters found by cross-validation turned out to be.

best_model = fitted_params(mtm).best_model
@show best_model.learning_rate
@show best_model.min_data_in_leaf
@show best_model.num_iterations
best_model.learning_rate = 0.325
best_model.min_data_in_leaf = 10
best_model.num_iterations = 50
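
Beyond the best model itself, the machine's report keeps the full tuning history, which is useful for checking how sensitive the score is to the choice of grid point. A minimal sketch, assuming the report fields provided by MLJTuning (best_history_entry and history; adjust if your version differs):

r = report(mtm)
@show r.best_history_entry.measurement  # CV measurement of the winning combination
@show length(r.history)                 # number of grid points evaluated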

Great, and now let's predict using the held-out data.

predictions = predict(mtm, rows=test)
rms_score = round(rms(predictions, targets[test, 1]), sigdigits=4)

@show rms_score
rms_score = 3.744
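
Earlier we claimed that LightGBM can beat a linear regressor with minimal effort. The sketch below fits such a baseline on the same split, assuming the MLJLinearModels package is installed; we would expect its RMSE to come out noticeably higher than the score above.

# Sketch of a linear baseline, assuming MLJLinearModels is available.
@load LinearRegressor pkg=MLJLinearModels
lr = machine(LinearRegressor(), features, targets)
fit!(lr, rows=train)
lr_rms = round(rms(predict(lr, rows=test), targets[test, 1]), sigdigits=4)
@show lr_rms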