Skip to content

Commit

Permalink
Fix tests for keras with new activation
Browse files: browse the repository at this point in the history
  • Loading branch information
pobonomo committed Jan 2, 2025
1 parent 0807a08 commit 8893a6e
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 252 deletions.
143 changes: 126 additions & 17 deletions Generate_keras_test_network.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,54 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "c8d57e80-9075-4d63-bc36-f9aaad08ea2f",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf"
"import keras"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"id": "90fa4efb-f9d5-40fb-8e4a-5e3ed1094740",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'3.7.0'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"keras.__version__"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5d98d000-661e-4495-bef0-49c5eb180aff",
"metadata": {},
"outputs": [],
"source": [
"nn = tf.keras.models.Sequential(\n",
"nn = keras.models.Sequential(\n",
" [\n",
" tf.keras.layers.InputLayer((8,)),\n",
" tf.keras.layers.Dense(30, activation='relu'),\n",
" tf.keras.layers.Dense(1),\n",
" keras.layers.InputLayer((8,)),\n",
" keras.layers.Dense(30, activation='relu'),\n",
" keras.layers.Dense(1),\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"id": "ba3cf3ee-bd25-4180-95c0-2ff42d858a34",
"metadata": {},
"outputs": [],
Expand All @@ -38,29 +59,29 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"id": "247bd200-8026-4f08-8739-9aabb3c37e99",
"metadata": {},
"outputs": [],
"source": [
"(X_train, y_train), (X_test, y_test) = tf.keras.datasets.california_housing.load_data(\n",
" version=\"large\"\n",
" version=\"small\"\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"id": "a29325dd-1ab1-4cce-81c0-2528e892adb6",
"metadata": {},
"outputs": [],
"source": [
"normalize = tf.keras.layers.Normalization(axis=-1)"
"normalize = keras.layers.Normalization(axis=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"id": "cbecbd91-e100-4568-9424-efd9e3b6d5fc",
"metadata": {},
"outputs": [],
Expand All @@ -72,18 +93,106 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"id": "5656d2da-ee2d-4a8f-aef3-65876c20193b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 36ms/step - loss: 51257974784.0000 - val_loss: 48780189696.0000\n",
"Epoch 2/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 34ms/step - loss: 51058966528.0000 - val_loss: 48779857920.0000\n",
"Epoch 3/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 30ms/step - loss: 56175738880.0000 - val_loss: 48779501568.0000\n",
"Epoch 4/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - loss: 48874921984.0000 - val_loss: 48779141120.0000\n",
"Epoch 5/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 25ms/step - loss: 52104830976.0000 - val_loss: 48778752000.0000\n",
"Epoch 6/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 41ms/step - loss: 53767278592.0000 - val_loss: 48778342400.0000\n",
"Epoch 7/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 25ms/step - loss: 51997323264.0000 - val_loss: 48777920512.0000\n",
"Epoch 8/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 34ms/step - loss: 52127023104.0000 - val_loss: 48777490432.0000\n",
"Epoch 9/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 31ms/step - loss: 55014318080.0000 - val_loss: 48777023488.0000\n",
"Epoch 10/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 27ms/step - loss: 50627502080.0000 - val_loss: 48776540160.0000\n",
"Epoch 11/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 28ms/step - loss: 52081172480.0000 - val_loss: 48776024064.0000\n",
"Epoch 12/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 21ms/step - loss: 55939633152.0000 - val_loss: 48775487488.0000\n",
"Epoch 13/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 34ms/step - loss: 51670016000.0000 - val_loss: 48774975488.0000\n",
"Epoch 14/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 26ms/step - loss: 55131279360.0000 - val_loss: 48774389760.0000\n",
"Epoch 15/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 30ms/step - loss: 51200266240.0000 - val_loss: 48773820416.0000\n",
"Epoch 16/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 30ms/step - loss: 53789458432.0000 - val_loss: 48773218304.0000\n",
"Epoch 17/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 29ms/step - loss: 50551488512.0000 - val_loss: 48772616192.0000\n",
"Epoch 18/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - loss: 50127593472.0000 - val_loss: 48771956736.0000\n",
"Epoch 19/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 31ms/step - loss: 48622862336.0000 - val_loss: 48771301376.0000\n",
"Epoch 20/20\n",
"\u001b[1m15/15\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 30ms/step - loss: 53927636992.0000 - val_loss: 48770617344.0000\n"
]
},
{
"data": {
"text/plain": [
"<keras.src.callbacks.history.History at 0x7eb858fa50>"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nn.fit(X_train, y_train, epochs=20, validation_data=(X_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "128b1ba9-55e9-4d78-9b31-d2a0da9bb165",
"metadata": {},
"outputs": [],
"source": [
"nn.fit(X_train, y_train, epochs=100, validation_data=(X_test, y_test))"
"nn.save('toto.keras')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "9c820767-db55-48ba-8dd3-675d06fb5c3d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Sequential name=sequential_1, built=True>"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"keras.saving.load_model('toto.keras')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "128b1ba9-55e9-4d78-9b31-d2a0da9bb165",
"id": "68653ae3-8f6a-4359-be49-b0049e1e0d5b",
"metadata": {},
"outputs": [],
"source": []
Expand All @@ -105,7 +214,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.11.2"
},
"license": {
"full_text": "# Copyright © 2023 Gurobi Optimization, LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =============================================================================="
Expand Down
4 changes: 2 additions & 2 deletions src/gurobi_ml/modeling/neuralnet/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self):
if not _HAS_NLEXPR:
raise NoModel(self, "Can't use logistic activation without Gurobi ≥ 12.0")

def mip_model(self, layer, predict_function='predict_proba', **kwargs):
def mip_model(self, layer, predict_function="predict_proba", **kwargs):
"""MIP model for logistic activation on a layer.
Parameters
Expand Down Expand Up @@ -181,7 +181,7 @@ class SoftMax:
def __init__(self):
pass

def mip_model(self, layer, predict_function='predict_proba', **kwargs):
def mip_model(self, layer, predict_function="predict_proba", **kwargs):
"""MIP model for SoftMax activation on a layer.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion src/gurobi_ml/sklearn/mlpregressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def __init__(
clean_predictor=False,
**kwargs,
):
assert predictor.out_activation_ in ("identity", )
assert predictor.out_activation_ in ("identity",)
SKRegressor.__init__(
self,
predictor,
Expand Down
Loading

0 comments on commit 8893a6e

Please sign in to comment.