Skip to content

Commit

Permalink
[Gurobi#374] Implement something for pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
pobonomo committed Nov 19, 2024
1 parent d12f205 commit 7165dc0
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 6 deletions.
7 changes: 3 additions & 4 deletions src/gurobi_ml/modeling/base_predictor_constr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from abc import ABC, abstractmethod

import numpy as np
import gurobipy as gp
import numpy as np

from ..exceptions import ParameterError
from ._submodel import _SubModel
Expand Down Expand Up @@ -109,17 +109,16 @@ def add_validity_domain(self, validity_domain=None, **kwargs):
return

if method != "box":
raise NotImplementedError('validity domain {} not implemented')
raise NotImplementedError("validity domain {} not implemented")

print("Adding boxes")
if X is not None:
self.input.UB = np.minimum(self.input.UB, X.max(axis=0))
self.input.LB = np.maximum(self.input.LB, X.min(axis=0))
self.gp_model.update()

if y is not None:
self.output.UB = np.minimum(self.output.UB, y.max(axis=0))
self.output.LB = np.maximum(self.output.UB, y.min(axis=0))
self.output.LB = np.maximum(self.output.LB, y.min(axis=0))

def _build_submodel(self, gp_model, **kwargs):
"""Predict output from input using predictor or transformer."""
Expand Down
3 changes: 3 additions & 0 deletions src/gurobi_ml/sklearn/column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def __init__(self, gp_model, column_transformer, input_vars, **kwargs):
self._default_name = "col_trans"
super().__init__(gp_model, column_transformer, input_vars, **kwargs)

def add_validity_domain(self, validity_domain=None, **kwargs):
raise NotImplemented("Validity domain not implemented for ColumnTransformer")

# For this class we need to reimplement submodel because we don't want
# to transform input variables to Gurobi variable. We can't do it for categorical
# The input should be unchanged.
Expand Down
2 changes: 1 addition & 1 deletion src/gurobi_ml/sklearn/mlpregressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _mip_model(self, **kwargs):
input_vars = self._input
output = None

kwargs.pop("validity_domain")
kwargs.pop("validity_domain", None)

for i in range(neural_net.n_layers_ - 1):
layer_coefs = neural_net.coefs_[i]
Expand Down
42 changes: 41 additions & 1 deletion src/gurobi_ml/sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
"""


import gurobipy as gp

from gurobi_ml.exceptions import ParameterError

from ..lightgbm_sklearn_api import lightgbm_sklearn_convertors
from ..modeling.base_predictor_constr import AbstractPredictorConstr
from ..modeling.get_convertor import get_convertor
Expand Down Expand Up @@ -87,13 +91,49 @@ def _build_submodel(self, gp_model, *args, **kwargs):
their input and output. They are just containers of other objects that will
do it.
"""
validity_domain = kwargs.pop("validity_domain", None)
self._mip_model(**kwargs)
assert self.output is not None
assert self.input is not None
# We can call validate only after the model is created
self._validate()
self.add_validity_domain(validity_domain, **kwargs)
return self

def add_validity_domain(self, validity_domain=None, **kwargs):
if validity_domain is None:
return
try:
X = validity_domain["X"]
except KeyError:
X = None
try:
y = validity_domain["y"]
except KeyError:
y = None
try:
method = validity_domain["method"]
except KeyError:
return

if method is None or method == "none":
return

if method != "box":
raise NotImplementedError("validity domain {} not implemented")

if X is not None:
for step in self._steps:
if isinstance(step.input, gp.MVar):
step.add_validity_domain({"X": X, "method": method})
break
X = step.transformer.transform(X)
else:
raise ParameterError("No variables in pipeline?")

if y is not None:
self._steps[-1].add_validity_domain({"y": y, "method": method})

def _mip_model(self, **kwargs):
pipeline = self.predictor
gp_model = self.gp_model
Expand All @@ -105,7 +145,7 @@ def _mip_model(self, **kwargs):
transformers["ColumnTransformer"] = add_column_transformer_constr
kwargs["validate_input"] = True

kwargs.pop("validity_domain")
assert "validity_domain" not in kwargs

for transformer in pipeline[:-1]:
convertor = get_convertor(transformer, transformers)
Expand Down

0 comments on commit 7165dc0

Please sign in to comment.