# 1. magic to print version
# 2. magic so that the notebook will reload external python modules
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import pandas as pd
import m2cgen as m2c
import sklearn.datasets as datasets
from xgboost import XGBClassifier, XGBRegressor
import onnxruntime as rt
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes
from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost
# prevent scientific notations
pd.set_option('display.float_format', lambda x: '%.3f' % x)
%watermark -a 'Ethen' -u -d -v -p numpy,pandas,sklearn,m2cgen,xgboost
Once we train our machine learning model, depending on the use case, we may wish to operationalize it by putting it behind a service for (near) real time inferencing. We could definitely generate predictions in batch offline, store them in downstream tables or lookup services, and fetch the pre-computed predictions when needed. Although this batch prediction approach might sound easier to implement, and we wouldn't have to worry about the latency issues that come with real time services, this paradigm does come with its own limitations.
It's very common in industry settings to prototype a machine learning model in Python and translate it into other languages such as C++, Java, etc. when it comes time to deploy. This usually happens when the core application is written in another language such as C++ or Java, and the application is extremely time sensitive, so we can't afford the cost of calling an external API to fetch the model prediction.
In this section, we'll be looking at how we can achieve this with Gradient Boosted Trees, specifically XGBoost. Different libraries might have different ways of doing this, but the concept should be similar.
Tree Structure
A typical model dump from XGBoost looks like the following:
booster[0]:
0:[bmi<0.00942232087] yes=1,no=2,missing=1
1:[bmi<-0.0218342301] yes=3,no=4,missing=3
3:[bmi<-0.0584798381] yes=7,no=8,missing=7
7:leaf=25.84091
8:leaf=33.0292702
4:[bp<0.0270366594] yes=9,no=10,missing=9
9:leaf=38.7487526
10:leaf=51.0882378
2:[bp<0.0235937908] yes=5,no=6,missing=5
5:leaf=53.0696678
6:leaf=69.4000015
booster[1]:
0:[bmi<0.00511107268] yes=1,no=2,missing=1
1:[bp<0.0390867069] yes=3,no=4,missing=3
3:[bmi<-0.0207564179] yes=7,no=8,missing=7
7:leaf=21.0474758
8:leaf=27.7326946
4:[bmi<0.000799824367] yes=9,no=10,missing=9
9:leaf=36.1850548
10:leaf=14.9188232
2:[bmi<0.0730132312] yes=5,no=6,missing=5
5:[bp<6.75072661e-05] yes=11,no=12,missing=11
11:leaf=31.3889732
12:leaf=43.4056664
6:[bp<-0.0498541184] yes=13,no=14,missing=13
13:leaf=13.0395498
14:leaf=59.377037
There are 3 distinct pieces of information:
booster
Gradient Boosted Trees is an ensemble tree method; each new booster indicates the start of a new tree. The total number of boosters will be equal to the number of trees we specified for the model (e.g. for the sklearn XGBoost API, n_estimators controls this) multiplied by the number of distinct classes. For a regression or binary classification model, the number of boosters in the model dump will be exactly equal to the number of trees we've specified. Whereas for multi-class classification, say we have 3 classes, tree 0 will contribute to the raw prediction of class 0, tree 1 to class 1, tree 2 to class 2, tree 3 to class 0, and so on.
node
Following the booster is each tree's if-else structure. e.g. for node 0, if the feature bmi is less than a threshold, then it will branch to node 1, else it will branch to node 2. The missing field specifies which branch to take when the feature's value is missing.
leaf
Once we reach a leaf, we accumulate the response prediction. e.g. node 7 is a leaf, and the prediction for this node is 25.84091.
Raw Prediction
We mentioned that to get the prediction for a given input, we sum up the response predictions associated with each tree's leaf node. This holds true for regression models, but for other models, we will need to perform a transformation on top of the raw prediction to get to the probabilities. e.g. when building a binary classification model, a logistic transformation is needed on top of the raw prediction, whereas for multi-class classification, a softmax transformation is needed. We'll verify this in code after each example below.
All the examples below, be it regression, binary classification, or multi-class classification, follow the same structure.
X, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
X = X[["age", "sex", "bmi", "bp"]]
X.head()
regression_model_params = {
'n_estimators': 2,
'max_depth': 3,
'base_score': 0.0
}
regression_model = XGBRegressor(**regression_model_params).fit(X, y)
regression_model
regression_model.get_booster().dump_model("regression.txt")
regression_model.predict(X.iloc[[0]])
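To make the tree traversal concrete, here's a hand-translated Python sketch of the two boosters from the dump shown earlier (assuming that dump came from this exact regression model). Traversing both trees and summing the leaf values should reproduce regression_model.predict, since base_score is 0.0 and regression needs no further transformation.
# a minimal sketch: the if-else structure of the two boosters, hand-translated
# from the dump above; the prediction is simply the sum of the two leaf values
def booster0(bmi, bp):
    if bmi < 0.00942232087:
        if bmi < -0.0218342301:
            return 25.84091 if bmi < -0.0584798381 else 33.0292702
        return 38.7487526 if bp < 0.0270366594 else 51.0882378
    return 53.0696678 if bp < 0.0235937908 else 69.4000015

def booster1(bmi, bp):
    if bmi < 0.00511107268:
        if bp < 0.0390867069:
            return 21.0474758 if bmi < -0.0207564179 else 27.7326946
        return 36.1850548 if bmi < 0.000799824367 else 14.9188232
    if bmi < 0.0730132312:
        return 31.3889732 if bp < 6.75072661e-05 else 43.4056664
    return 13.0395498 if bp < -0.0498541184 else 59.377037

row = X.iloc[0]
booster0(row['bmi'], row['bp']) + booster1(row['bmi'], row['bp'])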
X, y = datasets.make_classification(n_samples=10000, n_features=5, random_state=42, n_classes=2)
X
binary_model_params = {
'n_estimators': 3,
'max_depth': 3,
'tree_method': 'hist',
'grow_policy': 'lossguide'
}
binary_model = XGBClassifier(**binary_model_params).fit(X, y)
binary_model
binary_model.get_booster().dump_model("binary_class.txt")
inputs = np.array([[0.0, 0.2, 0.4, 0.6, 0.8]])
binary_model.predict_proba(inputs)
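As a quick check of the logistic transformation described earlier, we can ask the model for its raw margin via the sklearn API's output_margin flag and apply the sigmoid ourselves. A minimal verification sketch:
from scipy.special import expit  # logistic sigmoid

# the raw margin is the sum of leaf values across all boosters (plus the
# base score in logit space); applying the logistic transformation on it
# should recover the positive class probability
raw_margin = binary_model.predict(inputs, output_margin=True)
np.allclose(expit(raw_margin), binary_model.predict_proba(inputs)[:, 1])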
X, y = datasets.load_iris(return_X_y=True, as_frame=True)
X.head()
multi_class_model_params = {
'n_estimators': 2,
'max_depth': 3
}
multi_class_model = XGBClassifier(**multi_class_model_params).fit(X, y)
multi_class_model
multi_class_model.get_booster().dump_model("multi_class.txt")
inputs = np.array([[5.1, 3.5, 1.4, 0.2]])
multi_class_model.predict_proba(inputs)
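Likewise for multi-class: each class accumulates the leaf values from its own trees, and applying a softmax over the per-class raw margins should recover predict_proba. A small sketch:
# each class accumulates the leaf values of its own trees; a softmax over
# the per-class raw margins recovers the predicted probabilities above
raw_margin = multi_class_model.predict(inputs, output_margin=True)
exp_margin = np.exp(raw_margin)
exp_margin / exp_margin.sum(axis=1, keepdims=True)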
The rest of the content is about implementing the boosted tree inferencing logic in C++; all the code resides in the gbt_inference folder for those interested. In practice, we don't always have to rely on naive code that we've implemented ourselves to solidify our understanding. e.g. the m2cgen (Model 2 Code Generator) project is one of the many projects out there that focus on converting a trained model into native code. If we export our regression model, we can see that the inferencing logic is indeed a bunch of if-else statements followed by a summation at the very end.
code = m2c.export_to_c(regression_model)
print(code)
Another way of achieving this is through ONNX; directly quoting from its documentation:
ONNX Runtime provides an easy way to run machine learned models with high performance on CPU or GPU without dependencies on the training framework. Machine learning frameworks are usually optimized for batch training rather than for prediction, which is a more common scenario in applications, sites, and services
We'll walk through the process of converting our boosted tree model into ONNX format and benchmark the inference runtime. Here, we are doing it for a classification model, but the process should be similar for regression based models.
n_features = 5
X, y = datasets.make_classification(n_samples=10000, n_features=n_features, random_state=42, n_classes=2)
feature_names = [f'f{i}' for i in range(n_features)]
print(f'num rows: {X.shape[0]}, num cols: {X.shape[1]}')
X
tree = XGBClassifier(
n_estimators=20,
max_depth=3,
learning_rate=0.2,
tree_method='hist',
verbosity=1
)
tree.fit(X, y, eval_set=[(X, y)])
tree.predict_proba(X[:1])
xgboost_checkpoint = 'model.json'
tree.save_model(xgboost_checkpoint)
tree_loaded = XGBClassifier()
tree_loaded.load_model(xgboost_checkpoint)
assert np.allclose(tree.predict_proba(X[:1]), tree_loaded.predict_proba(X[:1]))
input_payloads = [
{
'f0': -2.24456934,
'f1': -1.36232827,
'f2': 1.55433334,
'f3': -2.0869092,
'f4': -1.27760482
}
]
rows = []
for input_payload in input_payloads:
row = [input_payload[feature] for feature in feature_names]
rows.append(row)
np_rows = np.array(rows, dtype=np.float32)
tree.predict_proba(np_rows)[:, 1]
%%timeit
rows = []
for input_payload in input_payloads:
row = [input_payload[feature] for feature in feature_names]
rows.append(row)
np_rows = np.array(rows, dtype=np.float32)
tree.predict_proba(np_rows)[:, 1]
def convert_xgboost_to_onnx(model, num_features: int, checkpoint: str):
    # boilerplate code for registering the xgboost converter
update_registered_converter(
XGBClassifier, 'XGBoostXGBClassifier',
calculate_linear_classifier_output_shapes, convert_xgboost,
options={'nocl': [True, False], 'zipmap': [True, False, 'columns']}
)
# perform the actual conversion specifying the types of our inputs,
# at the time of writing this, it doesn't support categorical types
# that are common in boosted tree libraries such as xgboost or lightgbm
model_onnx = convert_sklearn(
model, 'xgboost',
[('input', FloatTensorType([None, num_features]))],
target_opset={'': 15, 'ai.onnx.ml': 2}
)
with open(checkpoint, "wb") as f:
f.write(model_onnx.SerializeToString())
onnx_model_checkpoint = 'xgboost.onnx'
convert_xgboost_to_onnx(tree, len(feature_names), onnx_model_checkpoint)
Upon porting our model to onnx format, we can use it for inferencing. This section uses the Python API for benchmarking.
sess = rt.InferenceSession(onnx_model_checkpoint)
input_name = sess.get_inputs()[0].name
output_names = [output.name for output in sess.get_outputs()]
np_rows = np.array(rows, dtype=np.float32)
onnx_predict_label, onnx_predict_score = sess.run(output_names, {input_name: np_rows})
onnx_predict_score
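It's worth sanity checking that the onnx runtime's scores match the original model's. Note that with skl2onnx's default zipmap behavior, the probability output comes back as a list of {class: probability} dicts (an assumption worth verifying against the printed output above), hence the extraction step; small float32 discrepancies are expected.
# with the default zipmap behavior each row's score is a {class: probability}
# dict, so pull out the positive class before comparing against xgboost
onnx_scores = np.array([score[1] for score in onnx_predict_score])
np.allclose(tree.predict_proba(np_rows)[:, 1], onnx_scores, atol=1e-5)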
%%timeit
rows = []
for input_payload in input_payloads:
row = [input_payload[feature] for feature in feature_names]
rows.append(row)
np_rows = np.array(rows, dtype=np.float32)
onnx_predict_label, onnx_predict_score = sess.run(output_names, {input_name: np_rows})
Note, at the time of writing this document, the onnx converter doesn't support categorical variable splits from common boosted tree libraries such as xgboost or lightgbm, so we will have to handle categorical variables in other ways if we wish to leverage onnx for inferencing. One common workaround is sketched below.
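For instance, we can encode categoricals numerically (e.g. with sklearn's OrdinalEncoder) before training, so both the model and the exported onnx graph only ever see float inputs. A sketch with a hypothetical string column:
from sklearn.preprocessing import OrdinalEncoder

# hypothetical example: map string categories to floats up front; the fitted
# encoder then becomes part of the preprocessing applied at inference time
colors = np.array([['red'], ['green'], ['blue'], ['green']])
encoder = OrdinalEncoder()
encoder.fit_transform(colors).astype(np.float32)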