Allow JVM-Package to access inplace predict method (#9167)
---------

Co-authored-by: Stephan T. Lavavej <[email protected]>
Co-authored-by: Jiaming Yuan <[email protected]>
Co-authored-by: Joe <[email protected]>
4 people authored Sep 11, 2023
1 parent 9027686 commit d05ea58
Showing 5 changed files with 384 additions and 18 deletions.
jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java
@@ -39,6 +39,21 @@ public class Booster implements Serializable, KryoSerializable {
// handle to the booster.
private long handle = 0;
private int version = 0;
/**
* Type of prediction, used for inplace_predict.
*/
public enum PredictionType {
kValue(0),
kMargin(1);

private Integer ptype;
private PredictionType(final Integer ptype) {
this.ptype = ptype;
}
public Integer getPType() {
return ptype;
}
}

/**
* Create a new Booster with empty stage.
@@ -375,6 +390,97 @@ private synchronized float[][] predict(DMatrix data,
return predicts;
}

/**
* Perform thread-safe prediction.
*
* @param data Flattened input matrix of features for prediction
* @param nrow The number of predictions to make (count of input matrix rows)
* @param ncol The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
*
* @return Prediction result matrix
*/
public float[][] inplace_predict(float[] data,
int nrow,
int ncol,
float missing) throws XGBoostError {
int[] iteration_range = new int[]{0, 0};
return this.inplace_predict(data, nrow, ncol,
missing, iteration_range, PredictionType.kValue, null);
}

/**
* Perform thread-safe prediction.
*
* @param data Flattened input matrix of features for prediction
* @param nrow The number of predictions to make (count of input matrix rows)
* @param ncol The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
* @param iteration_range Specifies which layers of trees are used in prediction. For
* example, if a random forest is trained with 100 rounds and
* `iteration_range=[10, 20)` is specified, then only the trees
* built during rounds [10, 20) (a half-open interval) are used for
* this prediction.
*
* @return Prediction result matrix
*/
public float[][] inplace_predict(float[] data,
int nrow,
int ncol,
float missing, int[] iteration_range) throws XGBoostError {
return this.inplace_predict(data, nrow, ncol,
missing, iteration_range, PredictionType.kValue, null);
}


/**
* Perform thread-safe prediction.
*
* @param data Flattened input matrix of features for prediction
* @param nrow The number of predictions to make (count of input matrix rows)
* @param ncol The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
* @param iteration_range Specifies which layers of trees are used in prediction. For
* example, if a random forest is trained with 100 rounds and
* `iteration_range=[10, 20)` is specified, then only the trees
* built during rounds [10, 20) (a half-open interval) are used for
* this prediction.
* @param predict_type What kind of prediction to run.
* @param base_margin Optional base margin applied to the prediction; may be null.
* @return Prediction result matrix
*/
public float[][] inplace_predict(float[] data,
int nrow,
int ncol,
float missing,
int[] iteration_range,
PredictionType predict_type,
float[] base_margin) throws XGBoostError {
if (iteration_range.length != 2) {
throw new XGBoostError("Iteration range is expected to be [begin, end).");
}
int ptype = predict_type.getPType();

int begin = iteration_range[0];
int end = iteration_range[1];

float[][] rawPredicts = new float[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredictFromDense(handle, data, nrow, ncol,
missing,
begin, end, ptype, base_margin, rawPredicts));

int col = rawPredicts[0].length / nrow;
float[][] predicts = new float[nrow][col];
int r, c;
for (int i = 0; i < rawPredicts[0].length; i++) {
r = i / col;
c = i % col;
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}

/**
* Predict leaf indices given the data
*
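For illustration, here is a minimal, self-contained sketch of how the API added above can be called. The class name, model path, and feature values are hypothetical, and the booster is assumed to have been trained on four features; error handling is reduced to a bare throws clause.

import java.io.IOException;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class InplacePredictExample {
  public static void main(String[] args) throws XGBoostError, IOException {
    // Hypothetical model file, assumed to be trained on 4 features.
    Booster booster = XGBoost.loadModel("model.ubj");

    // Two rows, flattened in row-major order:
    // row i occupies data[i * ncol] .. data[i * ncol + ncol - 1].
    float[] data = {
        1.0f, 2.0f, Float.NaN, 4.0f,   // row 0 (NaN marks a missing value)
        0.5f, 1.5f, 2.5f, 3.5f         // row 1
    };
    int nrow = 2;
    int ncol = 4;

    // Simple overload: regular value predictions; internally forwards
    // iteration_range = [0, 0), i.e. all trained layers.
    float[][] preds = booster.inplace_predict(data, nrow, ncol, Float.NaN);

    // Full overload: only trees built during rounds [10, 20), raw margin output,
    // and no base margin.
    float[][] margins = booster.inplace_predict(data, nrow, ncol, Float.NaN,
        new int[]{10, 20}, Booster.PredictionType.kMargin, null);

    System.out.println("rows: " + preds.length + ", outputs per row: " + preds[0].length);
    System.out.println("first margin: " + margins[0][0]);
  }
}

Because the native call operates directly on the flattened array, no intermediate DMatrix is constructed, which is the main appeal of inplace prediction for low-latency scoring.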
jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
@@ -119,6 +119,10 @@ public final static native int XGBoosterEvalOneIter(long handle, int iter, long[
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
int ntree_limit, float[][] predicts);

public final static native int XGBoosterPredictFromDense(long handle, float[] data,
long nrow, long ncol, float missing, int iteration_begin, int iteration_end, int predict_type, float[] margin,
float[][] predicts);

public final static native int XGBoosterLoadModel(long handle, String fname);

public final static native int XGBoosterSaveModel(long handle, String fname);
@@ -154,10 +158,6 @@ final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);

@Deprecated
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);

public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);

79 changes: 79 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterPredictFromDense
* Signature: (J[FJJFIII[F[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense(
JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features,
jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type,
jfloatArray jmargin, jobjectArray jout) {
API_BEGIN();
BoosterHandle handle = reinterpret_cast<BoosterHandle>(jhandle);

/**
* Create array interface.
*/
namespace linalg = xgboost::linalg;
jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr);
xgboost::Context ctx;
auto t_data = linalg::MakeTensorView(
ctx.Device(),
xgboost::common::Span{data, static_cast<std::size_t>(num_rows * num_features)}, num_rows,
num_features);
auto s_array = linalg::ArrayInterfaceStr(t_data);

/**
* Create configuration object.
*/
xgboost::Json config{xgboost::Object{}};
config["cache_id"] = xgboost::Integer{};
config["type"] = xgboost::Integer{static_cast<std::int32_t>(predict_type)};
config["iteration_begin"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_begin)};
config["iteration_end"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_end)};
config["missing"] = xgboost::Number{static_cast<float>(missing)};
config["strict_shape"] = xgboost::Boolean{true};
std::string s_config;
xgboost::Json::Dump(config, &s_config);

/**
* Handle base margin
*/
BoosterHandle proxy{nullptr};

float *margin{nullptr};
if (jmargin) {
margin = jenv->GetFloatArrayElements(jmargin, nullptr);
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
JVM_CHECK_CALL(
XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
}

bst_ulong const *out_shape;
bst_ulong out_dim;
float const *result;
auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape,
&out_dim, &result);

jenv->ReleaseFloatArrayElements(jdata, data, 0);
if (proxy) {
XGDMatrixFree(proxy);
jenv->ReleaseFloatArrayElements(jmargin, margin, 0);
}

if (ret != 0) {
return ret;
}

std::size_t n{1};
for (std::size_t i = 0; i < out_dim; ++i) {
n *= out_shape[i];
}

jfloatArray jarray = jenv->NewFloatArray(n);

jenv->SetFloatArrayRegion(jarray, 0, n, result);
jenv->SetObjectArrayElement(jout, 0, jarray);

API_END();
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
16 changes: 8 additions & 8 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h
