Skip to content

Commit

Permalink
[Gurobi#374] Add a simple implementation of validity domain
Browse files Browse the repository at this point in the history
It works on a neural network. Should work on other objects
except column transformers that are a special case.
  • Loading branch information
pobonomo committed Nov 18, 2024
1 parent c92deee commit d12f205
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 10 deletions.
5 changes: 4 additions & 1 deletion docs/examples/example3_adversarial_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@
# Change last layer activation to identity
nn.out_activation_ = "identity"
# Code to add the neural network to the constraints
pred_constr = add_mlp_regressor_constr(m, nn, x, y)
print(X.shape)
pred_constr = add_mlp_regressor_constr(m, nn, x, y,
validity_domain={"method": "box",
"X":X})

# Restore activation
nn.out_activation_ = "softmax"
Expand Down
4 changes: 3 additions & 1 deletion docs/examples/example4_price_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,9 @@ def peak_season(row):

from gurobi_ml import add_predictor_constr

pred_constr = add_predictor_constr(m, lin_reg, feats, d)
pred_constr = add_predictor_constr(m, lin_reg, feats, d,
validity_domain={"method": "box",
"X":X, "y":y})

pred_constr.print_stats()

Expand Down
7 changes: 0 additions & 7 deletions notebooks/dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
ipywidgets
matplotlib
myst_nb
notebook
pandas
seaborn
Sphinx
sphinx-copybutton
sphinx-pyproject
sphinx-rtd-theme
tensorflow
torch
skorch
../.
36 changes: 35 additions & 1 deletion src/gurobi_ml/modeling/base_predictor_constr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from abc import ABC, abstractmethod

import numpy as np
import gurobipy as gp

from ..exceptions import ParameterError
Expand Down Expand Up @@ -88,7 +89,39 @@ def _validate(self):
self._input = input_vars
self._output = output_vars

def _build_submodel(self, gp_model, *args, **kwargs):
def add_validity_domain(self, validity_domain=None, **kwargs):
if validity_domain is None:
return
try:
X = validity_domain["X"]
except KeyError:
X = None
try:
y = validity_domain["y"]
except KeyError:
y = None
try:
method = validity_domain["method"]
except KeyError:
return

if method is None or method == "none":
return

if method != "box":
raise NotImplementedError('validity domain {} not implemented')

print("Adding boxes")
if X is not None:
self.input.UB = np.minimum(self.input.UB, X.max(axis=0))
self.input.LB = np.maximum(self.input.LB, X.min(axis=0))
self.gp_model.update()

if y is not None:
self.output.UB = np.minimum(self.output.UB, y.max(axis=0))
self.output.LB = np.maximum(self.output.UB, y.min(axis=0))

def _build_submodel(self, gp_model, **kwargs):
"""Predict output from input using predictor or transformer."""
self._input, columns, index = validate_input_vars(self.gp_model, self._input)
self._input_index = index
Expand All @@ -101,6 +134,7 @@ def _build_submodel(self, gp_model, *args, **kwargs):
self._validate()
self._mip_model(**kwargs)
assert self._output is not None
self.add_validity_domain(**kwargs)
return self

def _print_container_steps(self, iterations_name, iterable, file):
Expand Down
2 changes: 2 additions & 0 deletions src/gurobi_ml/sklearn/mlpregressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def _mip_model(self, **kwargs):
input_vars = self._input
output = None

kwargs.pop("validity_domain")

for i in range(neural_net.n_layers_ - 1):
layer_coefs = neural_net.coefs_[i]
layer_intercept = neural_net.intercepts_[i]
Expand Down
2 changes: 2 additions & 0 deletions src/gurobi_ml/sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def _mip_model(self, **kwargs):
transformers["ColumnTransformer"] = add_column_transformer_constr
kwargs["validate_input"] = True

kwargs.pop("validity_domain")

for transformer in pipeline[:-1]:
convertor = get_convertor(transformer, transformers)
steps.append(convertor(gp_model, transformer, input_vars, **kwargs))
Expand Down

0 comments on commit d12f205

Please sign in to comment.