Skip to content

Commit

Permalink
[c_api] Improve ANSI compatibility by avoiding <stdbool.h> (#4697)
Browse files Browse the repository at this point in the history
* [c_api] Improve ANSI compatibility by avoiding <stdbool.h>

* fixes in response to CI linting

* inline NOLINT instead of separate test

* moving length declaration to non-ANSI C conditional

* [c_api] Align expected return type in `basic.py` with new c_api type.
  • Loading branch information
drewmiller authored Nov 15, 2021
1 parent 874e635 commit bfb346c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
11 changes: 8 additions & 3 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <stdbool.h>
#endif


Expand Down Expand Up @@ -434,12 +433,12 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
/* --- start Booster interfaces */

/*!
* \brief Get boolean representing whether booster is fitting linear trees.
* \brief Get int representing whether booster is fitting linear trees.
* \param handle Handle of booster
* \param[out] out The address to hold linear trees indicator
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out);
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, int* out);

/*!
* \brief Create a new boosting learner.
Expand Down Expand Up @@ -1361,11 +1360,17 @@ static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everythin
#endif
/*!
* \brief Set string message of the last error.
* \note
* This will call unsafe ``sprintf`` when compiled using C standards before C99.
* \param msg Error message
*/
INLINE_FUNCTION void LGBM_SetLastError(const char* msg) {
#if !defined(__cplusplus) && (!defined(__STDC__) || (__STDC_VERSION__ < 199901L))
sprintf(LastErrorMsg(), "%s", msg); /* NOLINT(runtime/printf) */
#else
const int err_buf_len = 512;
snprintf(LastErrorMsg(), err_buf_len, "%s", msg);
#endif
}

#endif /* LIGHTGBM_C_API_H_ */
4 changes: 2 additions & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3598,7 +3598,7 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_bool(False)
out_is_linear = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetLinear(
self.handle,
ctypes.byref(out_is_linear)))
Expand All @@ -3607,7 +3607,7 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
params=self.params,
default_value=None
)
new_params["linear_tree"] = out_is_linear.value
new_params["linear_tree"] = bool(out_is_linear.value)
train_set = Dataset(data, label, silent=True, params=new_params)
new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set)
Expand Down
8 changes: 6 additions & 2 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1639,10 +1639,14 @@ int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
API_END();
}

int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out) {
int LGBM_BoosterGetLinear(BoosterHandle handle, int* out) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out = ref_booster->GetBoosting()->IsLinear();
if (ref_booster->GetBoosting()->IsLinear()) {
*out = 1;
} else {
*out = 0;
}
API_END();
}

Expand Down

0 comments on commit bfb346c

Please sign in to comment.