Issue with xgboost export in python: not same values splits #537

Open
antoinemertz opened this issue Aug 19, 2022 · 2 comments

Comments

@antoinemertz

Hi,

Thanks for the work on this package. I'm using m2cgen to convert an XGBoost model into VBA code, but the code produced by m2cgen gives some predictions that are very different from the ones I get from the trained model in Python. Here are some examples:

[Screenshot of example predictions, not reproduced here]

I then inspected the XGBoost booster after training and compared it to the Python output of m2cgen. Here is what m2cgen produces:

import math
def sigmoid(x):
    if x < 0.0:
        z = math.exp(x)
        return z / (1.0 + z)
    return 1.0 / (1.0 + math.exp(-x))
def score(input):
    if input[1] < 2.0:
        if input[1] < 1.0:
            var0 = -0.3193863
        else:
            var0 = -0.046659842
    else:
        if input[7] < 867.94:
            var0 = 0.058621403
        else:
            var0 = 0.25975806
    if input[4] < 5654.47:
        if input[3] < 0.38662624:
            var1 = -0.029487507
        else:
            var1 = 0.16083813
    else:
        if input[1] < 2.0:
            var1 = -0.32378462
        else:
            var1 = -0.08247565
    if input[0] < 1.0:
        if input[2] < 0.8:
            var2 = -0.15353489
        else:
            var2 = 0.081936955
    else:
        if input[4] < 2989.61:
            var2 = 0.13463722
        else:
            var2 = -0.042515814
    if input[5] < 0.11556604:
        if input[12] < 0.11059804:
            var3 = -0.1621976
        else:
            var3 = 0.30593434
    else:
        if input[11] < 661.39:
            var3 = 0.0063493266
        else:
            var3 = 0.15387529
    if input[9] < 0.12683104:
        if input[19] < 197.56:
            var4 = -0.25690553
        else:
            var4 = 0.06560632
    else:
        if input[8] < 0.11749347:
            var4 = -0.018011741
        else:
            var4 = 0.10678521
    if input[7] < 1790.11:
        if input[8] < 0.11749347:
            var5 = -0.091719724
        else:
            var5 = 0.048037946
    else:
        if input[1] < 3.0:
            var5 = 0.058297392
        else:
            var5 = 0.18175843
    if input[6] < 1351.78:
        if input[10] < 3.0:
            var6 = -0.0012290713
        else:
            var6 = 0.10081242
    else:
        if input[17] < 0.07381933:
            var6 = -0.12741692
        else:
            var6 = 0.038392954
    if input[1] < 3.0:
        if input[15] < 0.12838633:
            var7 = -0.081163615
        else:
            var7 = 0.019387348
    else:
        if input[20] < 0.29835963:
            var7 = 0.1156334
        else:
            var7 = -0.17409053
    if input[5] < 0.062735535:
        if input[3] < 0.5642857:
            var8 = -0.2049814
        else:
            var8 = 0.12192867
    else:
        if input[13] < 17.0:
            var8 = -0.0035746796
        else:
            var8 = 0.10629323
    if input[19] < 179.98:
        if input[4] < 15379.7:
            var9 = -0.010353668
        else:
            var9 = -0.19715081
    else:
        if input[21] < 1744.96:
            var9 = 0.08414988
        else:
            var9 = -0.31387258
    if input[9] < 0.12683104:
        if input[19] < 90.45:
            var10 = -0.15493616
        else:
            var10 = 0.05997152
    else:
        if input[11] < -1390.57:
            var10 = -0.12933072
        else:
            var10 = 0.028274538
    if input[14] < 3.0:
        if input[7] < 652.72:
            var11 = -0.061523404
        else:
            var11 = 0.018090146
    else:
        if input[20] < -0.015413969:
            var11 = 0.122180216
        else:
            var11 = -0.07323579
    if input[18] < 35.0:
        if input[17] < 0.105689526:
            var12 = -0.058067013
        else:
            var12 = 0.035271224
    else:
        if input[20] < 0.42494825:
            var12 = 0.067990474
        else:
            var12 = -0.13910332
    if input[8] < 0.11749347:
        if input[22] < 0.06889495:
            var13 = -0.109115146
        else:
            var13 = -0.011202088
    else:
        if input[16] < -161.82:
            var13 = -0.01581455
        else:
            var13 = 0.10806873
    if input[18] < 8.0:
        if input[17] < 0.0007647209:
            var14 = -0.10060249
        else:
            var14 = 0.04555326
    else:
        if input[15] < 0.15912667:
            var14 = 0.0012086431
        else:
            var14 = 0.061486576
    if input[11] < -1708.65:
        if input[1] < 4.0:
            var15 = -0.14637202
        else:
            var15 = 0.10264576
    else:
        if input[19] < 2421.29:
            var15 = 0.008009123
        else:
            var15 = 0.17349313
    if input[20] < 0.21551265:
        if input[20] < -0.14049701:
            var16 = -0.069627054
        else:
            var16 = 0.012490782
    else:
        if input[7] < 4508.38:
            var16 = -0.13310793
        else:
            var16 = 0.2982378
    if input[4] < 10364.37:
        if input[18] < 46.0:
            var17 = -0.00067418563
        else:
            var17 = 0.07025912
    else:
        if input[19] < 32.3:
            var17 = -0.11449907
        else:
            var17 = 0.102952585
    if input[12] < 0.11059804:
        if input[9] < 0.06418919:
            var18 = -0.12425961
        else:
            var18 = -0.0036558604
    else:
        if input[9] < 0.06418919:
            var18 = 0.3158906
        else:
            var18 = 0.06434954
    var19 = sigmoid(var0 + var1 + var2 + var3 + var4 + var5 + var6 + var7 + var8 + var9 + var10 + var11 + var12 + var13 + var14 + var15 + var16 + var17 + var18)
    return [1.0 - var19, var19]

And this is what I have in the booster:

def score_booster(input):
    if input[1]<2:
        if input[1]<1:
            var0=-0.319386303
        else:
            var0=-0.0466598421
    else:
        if input[7]<867.940002:
            var0=0.0586214028
        else:
            var0=0.259758055

    if input[4]<5654.47021:
        if input[3]<0.386626244:
            var1=-0.0294875074
        else:
            var1=0.160838127
    else:
        if input[1]<2:
            var1=-0.32378462
        else:
            var1=-0.0824756473

    if input[0]<1:
        if input[2]<0.800000012:
            var2=-0.153534889
        else:
            var2=0.0819369555
    else:
        if input[4]<2989.61011:
            var2=0.134637222
        else:
            var2=-0.0425158143

    if input[5]<0.115566038:
        if input[12]<0.110598043:
            var3=-0.162197605
        else:
            var3=0.30593434
    else:
        if input[11]<661.390015:
            var3=0.00634932658
        else:
            var3=0.153875291

    if input[9]<0.12683104:
        if input[19]<197.559998:
            var4=-0.256905526
        else:
            var4=0.0656063184
    else:
        if input[8]<0.117493473:
            var4=-0.0180117413
        else:
            var4=0.106785208

    if input[7]<1790.10999:
        if input[8]<0.117493473:
            var5=-0.0917197242
        else:
            var5=0.0480379462
    else:
        if input[1]<3:
            var5=0.058297392
        else:
            var5=0.181758434

    if input[6]<1351.78003:
        if input[10]<3:
            var6=-0.00122907129
        else:
            var6=0.10081242
    else:
        if input[17]<0.0738193318:
            var6=-0.127416924
        else:
            var6=0.0383929536

    if input[1]<3:
        if input[15]<0.128386334:
            var7=-0.081163615
        else:
            var7=0.0193873476
    else:
        if input[20]<0.298359632:
            var7=0.115633398
        else:
            var7=-0.174090534

    if input[5]<0.0627355352:
        if input[3]<0.564285696:
            var8=-0.204981402
        else:
            var8=0.12192867
    else:
        if input[13]<17:
            var8=-0.00357467961
        else:
            var8=0.106293231

    if input[19]<179.979996:
        if input[4]<15379.7002:
            var9=-0.0103536677
        else:
            var9=-0.197150812
    else:
        if input[21]<1744.95996:
            var9=0.0841498822
        else:
            var9=-0.313872576
    
    if input[9]<0.12683104:
        if input[19]<90.4499969:
            var10=-0.154936165
        else:
            var10=0.0599715188
    else:
        if input[11]<-1390.56995:
            var10=-0.129330724
        else:
            var10=0.028274538

    if input[14]<3:
        if input[7]<652.719971:
            var11=-0.061523404
        else:
            var11=0.0180901457
    else:
        if input[20]<-0.0154139688:
            var11=0.122180216
        else:
            var11=-0.0732357875

    if input[18]<35:
        if input[17]<0.105689526:
            var12=-0.0580670126
        else:
            var12=0.0352712236
    else:
        if input[20]<0.424948245:
            var12=0.0679904744
        else:
            var12=-0.139103323

    if input[8]<0.117493473:
        if input[22]<0.0688949525:
            var13=-0.109115146
        else:
            var13=-0.0112020876
    else:
        if input[16]<-161.820007:
            var13=-0.0158145502
        else:
            var13=0.108068727

    if input[18]<8:
        if input[17]<0.000764720899:
            var14=-0.100602493
        else:
            var14=0.0455532596
    else:
        if input[15]<0.159126669:
            var14=0.00120864308
        else:
            var14=0.0614865758
    
    if input[11]<-1708.65002:
        if input[1]<4:
            var15=-0.14637202
        else:
            var15=0.102645762
    else:
        if input[19]<2421.29004:
            var15=0.00800912268
        else:
            var15=0.173493132

    if input[20]<0.215512648:
        if input[20]<-0.140497014:
            var16=-0.069627054
        else:
            var16=0.012490782
    else:
        if input[7]<4508.37988:
            var16=-0.13310793
        else:
            var16=0.298237801
    
    if input[4]<10364.3701:
        if input[18]<46:
            var17=-0.000674185634
        else:
            var17=0.0702591166
    else:
        if input[19]<32.2999992:
            var17=-0.11449907
        else:
            var17=0.102952585

    if input[12]<0.110598043:
        if input[9]<0.0641891882:
            var18=-0.124259613
        else:
            var18=-0.00365586043
    else:
        if input[9]<0.0641891882:
            var18=0.31589061
        else:
            var18=0.0643495396
    
    return (var0, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10, var11, var12, var13, var14, var15, var16, var17, var18)
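`score_booster` returns the raw leaf value of each tree, while m2cgen's `score` returns class probabilities, so the two cannot be compared directly. A minimal sketch for putting them on the same scale, assuming the model uses the `binary:logistic` objective with a zero base margin (which is what the generated `score` implies, since it applies the sigmoid to the plain sum of leaves). `proba_from_leaves` and `base_margin` are hypothetical names, not part of m2cgen or XGBoost:

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    # Numerically stable logistic function, same form as the one m2cgen emits.
    if x < 0.0:
        z = math.exp(x)
        return z / (1.0 + z)
    return 1.0 / (1.0 + math.exp(-x))


def proba_from_leaves(leaves: Sequence[float], base_margin: float = 0.0) -> List[float]:
    """Turn per-tree leaf values into [P(class 0), P(class 1)].

    Assumes binary:logistic: the margin is the sum of the leaves plus a
    base margin (0.0 here, as implied by the generated score function).
    The leaves are summed left to right, like the generated code does.
    """
    p = sigmoid(base_margin + sum(leaves))
    return [1.0 - p, p]


# Usage with a leaf tuple such as the one returned by score_booster(row):
print(proba_from_leaves((0.25, -0.25)))  # [0.5, 0.5]
```

Comparing `proba_from_leaves(score_booster(row))` with `score(row)` row by row then shows whether a discrepancy comes from a different leaf being selected (a jump of a whole leaf value) or only from rounding.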

I'm not an expert, but it seems to me that in the function generated by m2cgen the split values in the if/else conditions are 32-bit floats, while in the booster they are not. So if a sample in my data has exactly the split value, the m2cgen code does not give back the right value. Is there a way to force the floats to 64 bits?
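A minimal sketch of this precision effect, assuming that the booster rounds feature values to 32-bit floats before comparing them with its 32-bit split thresholds (XGBoost's DMatrix stores data as float32), while the generated Python code compares 64-bit floats. `to_f32` is a hypothetical helper built on the standard `struct` module, and the feature value is made up for illustration:

```python
import struct


def to_f32(x: float) -> float:
    """Round a 64-bit Python float to the nearest 32-bit float (returned as a Python float)."""
    return struct.unpack("f", struct.pack("f", x))[0]


# The same split threshold as written by m2cgen and as printed in the booster dump.
m2cgen_threshold = 867.94
booster_threshold = 867.940002

# Both decimal strings round to the same 32-bit value.
print(to_f32(m2cgen_threshold) == to_f32(booster_threshold))  # True

x = 867.9399999  # made-up feature value just below the threshold

# 64-bit comparison, as in the generated Python code.
left_f64 = x < m2cgen_threshold
# 32-bit comparison, assuming the booster rounds the feature to float32 first.
left_f32 = to_f32(x) < to_f32(booster_threshold)

print(left_f64, left_f32)  # True False
```

In this sketch, printing the threshold with more digits would not help on its own: 867.9399999 is also below 867.940002 in 64-bit arithmetic, so matching the booster requires rounding the input to 32 bits as well.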

Thanks in advance for your reply,

Antoine

@ghost

ghost commented Feb 24, 2023

I have a similar issue with the export to C.
In Python,

model.predict_proba(xi)

returns [0.25445127 0.7455487 ],
while the C code returns [0.19037, 0.8097].

When the probabilities are close to 0.5, the predicted class differs at a probability threshold of 0.5 (64 mismatches out of 2000 samples).

Any ideas how I could debug that?
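One way to narrow this down is to list the rows where the two implementations disagree on the predicted class and then inspect their feature values against the split thresholds. A minimal sketch, with a hypothetical one-split model standing in for both sides: `reference_score` mimics the booster by comparing in 32-bit precision and `exported_score` mimics generated code that compares in 64-bit precision; the 867.94 threshold, the leaf values and the rows are made up for illustration:

```python
import math
import struct
from typing import Callable, List, Sequence

Score = Callable[[Sequence[float]], float]

THRESHOLD = 867.94  # hypothetical split value


def to_f32(x: float) -> float:
    # Round a 64-bit float to the nearest 32-bit float.
    return struct.unpack("f", struct.pack("f", x))[0]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def reference_score(row: Sequence[float]) -> float:
    # Stand-in for the booster: feature and threshold compared as float32.
    margin = -0.3 if to_f32(row[0]) < to_f32(THRESHOLD) else 0.3
    return sigmoid(margin)


def exported_score(row: Sequence[float]) -> float:
    # Stand-in for the exported code: plain 64-bit comparison.
    margin = -0.3 if row[0] < THRESHOLD else 0.3
    return sigmoid(margin)


def label_mismatches(ref: Score, cand: Score, rows: Sequence[Sequence[float]],
                     cutoff: float = 0.5) -> List[int]:
    """Indices of rows whose predicted class differs at the given cutoff."""
    return [i for i, row in enumerate(rows)
            if (ref(row) >= cutoff) != (cand(row) >= cutoff)]


rows = [[100.0], [867.9399999], [867.94], [2000.0]]
print(label_mismatches(reference_score, exported_score, rows))  # [1]
```

With the real model, `reference_score` would return `model.predict_proba` for a single row and `exported_score` the corresponding generated code (for example the output of `m2cgen.export_to_python` for the same model); the rows it reports are the ones to check feature by feature against the thresholds in the booster dump.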

@hqliling

I have the same problem. Any ideas?
