185 changes: 185 additions & 0 deletions doc/tutorials/custom_metric_obj.rst
@@ -90,6 +90,30 @@ namely prediction and labels. For implementing ``SLE``, we define:
hess = hessian(predt, dtrain)
return grad, hess

.. code-block:: r

library(xgboost)

gradient <- function(predt, dtrain) {
      # Compute the gradient for squared log error.
y <- getinfo(dtrain, "label")
return((log1p(predt) - log1p(y)) / (predt + 1))
}

hessian <- function(predt, dtrain) {
# Compute the hessian for squared log error.
y <- getinfo(dtrain, "label")
return((-log1p(predt) + log1p(y) + 1) / (predt + 1)^2)
}

squared_log <- function(predt, dtrain) {
# Squared Log Error objective. A simplified version for RMSLE used as
# objective function.
predt[predt < -1] <- -1 + 1e-6
grad <- gradient(predt, dtrain)
hess <- hessian(predt, dtrain)
return(list(grad = grad, hess = hess))
}

In the above code snippet, ``squared_log`` is the objective function we want. It accepts a
numpy array ``predt`` as model prediction, and the training DMatrix for obtaining required
@@ -104,6 +128,13 @@ a callback function for XGBoost during training by passing it as an argument to
num_boost_round=10,
obj=squared_log)

.. code-block:: r

xgb.train(list(tree_method = 'hist', seed = 1994), # any other tree method is fine.
data = dtrain,
nrounds = 10,
obj = squared_log)

Notice that in our definition of the objective, whether we subtract the labels from the
prediction or the other way around is important. If you find the training error goes up
instead of down, this might be the reason.
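
For squared error, for instance, the gradient taken with respect to the prediction is
``predt - y`` rather than ``y - predt``. A minimal R sketch of the correct sign
convention (assuming the same ``dtrain`` as above):

.. code-block:: r

    squared_error_obj <- function(predt, dtrain) {
      # The gradient of 0.5 * (predt - y)^2 with respect to the prediction is predt - y;
      # flipping the sign would make training ascend the loss instead of descending it.
      y <- getinfo(dtrain, "label")
      grad <- predt - y
      hess <- rep(1, length(y))
      return(list(grad = grad, hess = hess))
    }
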
@@ -126,6 +157,16 @@ monitor our model's performance. As mentioned above, the default metric for ``S
elements = np.power(np.log1p(y) - np.log1p(predt), 2)
return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))

.. code-block:: r

    rmsle <- function(predt, dtrain) {
      # Root mean squared log error metric.
      y <- getinfo(dtrain, "label")
      predt[predt < -1] <- -1 + 1e-6
      elements <- (log1p(y) - log1p(predt))^2
      # The R interface expects a named list with the metric name and its value.
      return(list(metric = "RRMSLE", value = sqrt(sum(elements) / length(y))))
    }

Since we are demonstrating in Python, the metric or objective need not be a function; any
callable object will suffice. Similar to the objective function, our metric also
accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself and
@@ -143,6 +184,17 @@ a floating point value as the result. After passing it into XGBoost as argument
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results)

.. code-block:: r

    bst <- xgb.train(list(tree_method = 'hist', seed = 1994,
                          disable_default_eval_metric = 1),
                     data = dtrain,
                     nrounds = 10,
                     obj = squared_log,
                     feval = rmsle,
                     watchlist = list(dtrain = dtrain, dtest = dtest))
    # The R interface has no `evals_result` argument; the per-round metric values
    # are stored in the returned booster as `bst$evaluation_log`.

We will be able to see XGBoost printing something like:

.. code-block:: none
@@ -209,6 +261,30 @@ metric functions implementing the same underlying metric for comparison,
errors[y != out] = 1.0
return 'PyMError', np.sum(errors) / dtrain.num_row()

.. code-block:: r

library(xgboost)

merror_with_transform <- function(predt, dtrain) {
# Used when custom objective is supplied.
y <- getinfo(dtrain, "label")
n_classes <- length(predt) / nrow(dtrain)
# Like custom objective, the predt is untransformed leaf weight when custom objective
# is provided.

# With the use of `feval` parameter in train function, custom metric receives
# raw input only when custom objective is also being used. Otherwise custom metric
# will receive transformed prediction.
stopifnot(length(predt) == nrow(dtrain) * n_classes)
predt_mat <- matrix(predt, nrow = nrow(dtrain), ncol = n_classes, byrow = TRUE)
out <- apply(predt_mat, 1, which.max) - 1 # R is 1-indexed, adjust to 0-indexed

stopifnot(length(y) == length(out))

errors <- as.numeric(y != out)
      return(list(metric = "RMError", value = sum(errors) / nrow(dtrain)))
}

The above function is only needed when we want to use a custom objective and XGBoost doesn't
know how to transform the prediction. The normal implementation for multi-class error
function is:
@@ -222,6 +298,16 @@ function is:
errors[y != out] = 1.0
return 'PyMError', np.sum(errors) / dtrain.num_row()

.. code-block:: r

merror <- function(predt, dtrain) {
# Used when there's no custom objective.
y <- getinfo(dtrain, "label")
      # No need to do the transform; XGBoost handles it internally.
      errors <- as.numeric(y != predt)
      return(list(metric = "RMError", value = sum(errors) / nrow(dtrain)))
}


Next we need the custom softprob objective:

@@ -237,6 +323,18 @@

return grad, hess

.. code-block:: r

    softprob_obj <- function(predt, dtrain) {
      # Loss function. Computes the gradient and approximated hessian (diagonal).
      # A sketch mirroring the Python reimplementation of `multi:softprob` above.
      y <- getinfo(dtrain, "label")
      n <- nrow(dtrain)
      n_classes <- length(predt) / n
      # Reshape the raw leaf weights and apply a row-wise, numerically stable softmax.
      predt_mat <- matrix(predt, nrow = n, ncol = n_classes, byrow = TRUE)
      prob <- exp(predt_mat - apply(predt_mat, 1, max))
      prob <- prob / rowSums(prob)
      # One-hot encode the 0-based labels, then form the gradient and hessian.
      onehot <- matrix(0, nrow = n, ncol = n_classes)
      onehot[cbind(seq_len(n), y + 1)] <- 1
      grad <- prob - onehot
      hess <- pmax(2 * prob * (1 - prob), 1e-6)
      # Flatten back to the row-major layout that matches `predt`.
      return(list(grad = as.vector(t(grad)), hess = as.vector(t(hess))))
    }

Lastly we can train the model using ``obj`` and ``custom_metric`` parameters:

.. code-block:: python
Expand All @@ -252,6 +350,19 @@ Lastly we can train the model using ``obj`` and ``custom_metric`` parameters:
evals=[(m, "train")],
)

.. code-block:: r

    m <- xgb.DMatrix(X, label = y)
    booster <- xgb.train(
      list(num_class = kClasses, disable_default_eval_metric = 1),
      data = m,
      nrounds = kRounds,
      obj = softprob_obj,
      feval = merror_with_transform,
      watchlist = list(train = m)
    )
    # The evaluation history is available afterwards as `booster$evaluation_log`.

Or if you don't need the custom objective and just want to supply a metric that's not
available in XGBoost:

@@ -271,6 +382,22 @@
evals=[(m, "train")],
)

.. code-block:: r

    booster <- xgb.train(
      list(
        num_class = kClasses,
        disable_default_eval_metric = 1,
        objective = "multi:softmax"
      ),
      data = m,
      nrounds = kRounds,
      # Use the simpler metric implementation.
      feval = merror,
      watchlist = list(train = m)
    )

We use ``multi:softmax`` to illustrate the difference in the transformed prediction. With
``softprob`` the output prediction array has shape ``(n_samples, n_classes)`` while for
``softmax`` it's ``(n_samples, )``. A demo for multi-class objective function is also
@@ -297,6 +424,29 @@ function (not scoring functions) from scikit-learn out of the box:
)
reg.fit(X, y, eval_set=[(X, y)])

.. code-block:: r

library(xgboost)

    # Stand-in for the scikit-learn diabetes dataset used in the Python example:
    # generate a small synthetic regression problem instead.
    set.seed(0)
    X <- matrix(rnorm(100 * 10), nrow = 100)
    y <- rnorm(100)

    mae_metric <- function(predt, dtrain) {
      y <- getinfo(dtrain, "label")
      mae <- mean(abs(y - predt))
      return(list(metric = "MAE", value = mae))
    }

# Using the low-level interface
dtrain <- xgb.DMatrix(X, label = y)
model <- xgb.train(
list(tree_method = "hist"),
data = dtrain,
nrounds = 100,
feval = mae_metric,
watchlist = list(train = dtrain)
)

Also, for custom objective functions, users can define the objective without having to
access ``DMatrix``:

@@ -322,3 +472,38 @@ access ``DMatrix``:
return grad, hess

clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj)

.. code-block:: r

softmax <- function(x) {
exp_x <- exp(x - max(x)) # subtract max for numerical stability
return(exp_x / sum(exp_x))
}

softprob_obj <- function(labels, predt) {
# Note: In R, this simplified interface is not directly available
# You would typically use the DMatrix-based interface shown earlier
rows <- length(labels)
classes <- ncol(predt)
grad <- matrix(0, nrow = rows, ncol = classes)
hess <- matrix(0, nrow = rows, ncol = classes)
eps <- 1e-6

for (r in 1:rows) {
target <- labels[r] + 1 # R is 1-indexed
p <- softmax(predt[r, ])
for (c in 1:classes) {
g <- ifelse(c == target, p[c] - 1.0, p[c])
h <- max(2.0 * p[c] * (1.0 - p[c]), eps)
grad[r, c] <- g
hess[r, c] <- h
}
}

grad <- as.vector(t(grad))
hess <- as.vector(t(hess))
return(list(grad = grad, hess = hess))
}

    # To use this with the low-level R interface, wrap it into a function(predt, dtrain)
    # that fetches the labels via getinfo() and reshapes `predt` into an
    # (n_samples, n_classes) matrix, then pass the wrapper as `obj` to xgb.train().
18 changes: 18 additions & 0 deletions doc/tutorials/dart.rst
@@ -109,3 +109,21 @@ Sample Script
num_round = 50
bst = xgb.train(param, dtrain, num_round)
preds = bst.predict(dtest)

.. code-block:: r

library(xgboost)
# read in data
dtrain <- xgb.DMatrix('demo/data/agaricus.txt.train?format=libsvm')
dtest <- xgb.DMatrix('demo/data/agaricus.txt.test?format=libsvm')
# specify parameters via list
param <- list(booster = 'dart',
max_depth = 5, learning_rate = 0.1,
objective = 'binary:logistic',
sample_type = 'uniform',
normalize_type = 'tree',
rate_drop = 0.1,
skip_drop = 0.5)
num_round <- 50
bst <- xgb.train(param, dtrain, num_round)
preds <- predict(bst, dtest)
19 changes: 19 additions & 0 deletions doc/tutorials/feature_interaction_constraint.rst
@@ -158,6 +158,12 @@ Suppose the following code fits your model without feature interaction constrain
num_boost_round = 1000, evals = evallist,
early_stopping_rounds = 10)

.. code-block:: r

model_no_constraints <- xgb.train(params, dtrain,
nrounds = 1000, watchlist = evallist,
early_stopping_rounds = 10)

Then fitting with feature interaction constraints only requires adding a single
parameter:

@@ -174,6 +180,19 @@ parameter:
num_boost_round = 1000, evals = evallist,
early_stopping_rounds = 10)

.. code-block:: r

params_constrained <- params
    # Define interaction constraints as nested lists of 0-based feature indices,
    # passed here as a JSON-style string.
params_constrained$interaction_constraints <- '[[0, 2], [1, 3, 4], [5, 6]]'
# Features 0 and 2 are allowed to interact with each other but with no other feature
# Features 1, 3, 4 are allowed to interact with one another but with no other feature
# Features 5 and 6 are allowed to interact with each other but with no other feature

model_with_constraints <- xgb.train(params_constrained, dtrain,
nrounds = 1000, watchlist = evallist,
early_stopping_rounds = 10)

**************************
Using feature name instead
**************************
36 changes: 35 additions & 1 deletion doc/tutorials/intercept.rst
@@ -41,6 +41,40 @@ and multi-class, the ``base_margin`` is a matrix with size ``(n_samples, n_targe
reg_1.fit(X, y, base_margin=m)
reg_1.predict(X, base_margin=m)

.. code-block:: r

library(xgboost)

# Generate regression data
set.seed(42)
n_samples <- 100
n_features <- 20
X <- matrix(rnorm(n_samples * n_features), nrow = n_samples, ncol = n_features)
y <- rnorm(n_samples)

# First model
dtrain <- xgb.DMatrix(X, label = y)
reg <- xgb.train(
list(objective = "reg:squarederror"),
data = dtrain,
nrounds = 10
)

    # Request the raw, untransformed prediction (outputmargin = TRUE)
m <- predict(reg, dtrain, outputmargin = TRUE)

# Second model with base_margin
dtrain_with_margin <- xgb.DMatrix(X, label = y, base_margin = m)
reg_1 <- xgb.train(
list(objective = "reg:squarederror"),
data = dtrain_with_margin,
nrounds = 10
)

# Predict with base_margin
dtest <- xgb.DMatrix(X, base_margin = m)
predict(reg_1, dtest)


It specifies the bias for each sample and can be used for stacking an XGBoost model on top
of other models, see :ref:`sphx_glr_python_examples_boost_from_prediction.py` for a worked
@@ -136,4 +170,4 @@ We have:
E[c_i] &= \exp{(F(x_i) + \ln{\gamma_i})} \\
E[c_i] &= g^{-1}(F(x_i) + g(\gamma_i))

As you can see, we can use the ``base_margin`` for modeling with an offset, similar to GLMs.
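
For example, a Poisson count model with an exposure offset can be expressed by passing the
log of the exposure as ``base_margin``. A minimal sketch, assuming user-supplied vectors
``counts`` and ``exposure`` and a feature matrix ``X``:

.. code-block:: r

    library(xgboost)

    # log(exposure) plays the role of the GLM offset: it is added to the raw prediction
    # before the exponential inverse link is applied.
    dtrain <- xgb.DMatrix(X, label = counts, base_margin = log(exposure))
    bst <- xgb.train(list(objective = "count:poisson"), data = dtrain, nrounds = 10)
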