From 3c1dd201cf65b016aaff68f2863ec45497c1151a Mon Sep 17 00:00:00 2001 From: GitHub Action <52708150+marcpinet@users.noreply.github.com> Date: Wed, 24 Apr 2024 02:37:20 +0200 Subject: [PATCH] feat: i don't even remember what i've added --- .gitignore | 3 - .../mnist_loading_saved_model.ipynb | 52 +++---- .../simple_cancer_binary.ipynb | 139 +++++------------- .../simple_diabete_regression.ipynb | 102 ++++++++----- .../simple_mnist_multiclass.ipynb | 90 +++++------- .../simple_cnn_classification_mnist.ipynb | 66 ++++----- .../tic_tac_toe_alternative_dataset_shape.py | 2 +- neuralnetlib/layers.py | 7 + neuralnetlib/model.py | 66 +++++++-- tests/test_model.py | 4 +- 10 files changed, 266 insertions(+), 265 deletions(-) diff --git a/.gitignore b/.gitignore index 3ba8170..f07a984 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# Dist generator -dist_gen.bat - # Datasets formats *.csv *.npz diff --git a/examples/classification-regression/mnist_loading_saved_model.ipynb b/examples/classification-regression/mnist_loading_saved_model.ipynb index d12febb..ce2c3d5 100644 --- a/examples/classification-regression/mnist_loading_saved_model.ipynb +++ b/examples/classification-regression/mnist_loading_saved_model.ipynb @@ -21,8 +21,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T12:52:21.706906Z", - "start_time": "2024-04-21T12:52:18.726598200Z" + "end_time": "2024-04-23T23:32:44.879695500Z", + "start_time": "2024-04-23T23:32:41.806868Z" } }, "outputs": [], @@ -47,8 +47,8 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T12:52:21.915810200Z", - "start_time": "2024-04-21T12:52:21.706906Z" + "end_time": "2024-04-23T23:32:45.056739600Z", + "start_time": "2024-04-23T23:32:44.879695500Z" } }, "outputs": [], @@ -68,8 +68,8 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T12:52:22.072282500Z", - "start_time": "2024-04-21T12:52:21.916810900Z" + "end_time": 
"2024-04-23T23:32:45.166846Z", + "start_time": "2024-04-23T23:32:45.059739600Z" } }, "outputs": [], @@ -92,8 +92,8 @@ "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T12:52:22.233389700Z", - "start_time": "2024-04-21T12:52:22.073284800Z" + "end_time": "2024-04-23T23:32:45.285935300Z", + "start_time": "2024-04-23T23:32:45.167845600Z" } }, "outputs": [], @@ -113,8 +113,8 @@ "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T12:52:22.258467800Z", - "start_time": "2024-04-21T12:52:22.234388100Z" + "end_time": "2024-04-23T23:32:45.329886Z", + "start_time": "2024-04-23T23:32:45.288843800Z" } }, "outputs": [], @@ -134,8 +134,8 @@ "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T12:52:22.323518700Z", - "start_time": "2024-04-21T12:52:22.257467100Z" + "end_time": "2024-04-23T23:32:45.374527900Z", + "start_time": "2024-04-23T23:32:45.314964200Z" } }, "outputs": [ @@ -143,7 +143,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Validation Accuracy: 0.899\n" + "Validation Accuracy: 0.9738333333333333\n" ] } ], @@ -165,8 +165,8 @@ "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T12:52:22.393768500Z", - "start_time": "2024-04-21T12:52:22.318518600Z" + "end_time": "2024-04-23T23:32:45.444303500Z", + "start_time": "2024-04-23T23:32:45.375529400Z" } }, "outputs": [ @@ -174,18 +174,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "Test Accuracy: 0.8863\n", + "Test Accuracy: 0.9549\n", "Confusion Matrix:\n", - "[[ 937 0 0 1 11 7 2 18 1 3]\n", - " [ 0 1097 3 4 0 3 2 4 19 3]\n", - " [ 13 9 858 36 26 1 23 38 16 12]\n", - " [ 8 6 18 899 2 33 2 16 12 14]\n", - " [ 1 0 1 0 944 0 7 2 1 26]\n", - " [ 19 0 0 82 30 701 12 5 23 20]\n", - " [ 18 2 0 0 70 15 849 1 2 1]\n", - " [ 0 9 10 5 15 0 0 945 4 40]\n", - " [ 6 22 3 3 37 26 9 2 803 63]\n", - " [ 3 2 1 11 137 2 0 15 8 830]]\n" + "[[ 958 0 3 0 0 3 7 2 4 3]\n", + " [ 0 1117 1 6 0 1 1 2 6 1]\n", + " [ 5 1 
983 11 3 0 4 16 9 0]\n", + " [ 2 0 10 959 0 13 1 7 8 10]\n", + " [ 2 1 6 0 909 0 6 0 0 58]\n", + " [ 9 1 0 20 0 838 8 2 3 11]\n", + " [ 10 4 4 1 5 6 917 0 10 1]\n", + " [ 1 8 10 6 0 0 0 982 0 21]\n", + " [ 5 3 9 7 4 6 5 7 917 11]\n", + " [ 3 5 3 5 10 4 2 7 1 969]]\n" ] } ], diff --git a/examples/classification-regression/simple_cancer_binary.ipynb b/examples/classification-regression/simple_cancer_binary.ipynb index 0509f99..23f6d6f 100644 --- a/examples/classification-regression/simple_cancer_binary.ipynb +++ b/examples/classification-regression/simple_cancer_binary.ipynb @@ -21,8 +21,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:53.026361300Z", - "start_time": "2024-04-21T13:22:52.339942200Z" + "end_time": "2024-04-23T22:56:24.052135200Z", + "start_time": "2024-04-23T22:56:22.927958200Z" } }, "outputs": [], @@ -51,8 +51,8 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:53.040903100Z", - "start_time": "2024-04-21T13:22:53.026361300Z" + "end_time": "2024-04-23T22:56:24.066137600Z", + "start_time": "2024-04-23T22:56:24.046136900Z" } }, "outputs": [], @@ -73,8 +73,8 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:53.054442700Z", - "start_time": "2024-04-21T13:22:53.042408400Z" + "end_time": "2024-04-23T22:56:24.079136500Z", + "start_time": "2024-04-23T22:56:24.063137500Z" } }, "outputs": [], @@ -99,11 +99,24 @@ "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:53.059957800Z", - "start_time": "2024-04-21T13:22:53.048922300Z" + "end_time": "2024-04-23T22:56:24.340700300Z", + "start_time": "2024-04-23T22:56:24.073137200Z" } }, - "outputs": [], + "outputs": [ + { + "ename": "ValueError", + "evalue": "The first layer must be an Input layer.", + "output_type": "error", + "traceback": [ + "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", + 
"\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[1;32mIn[4], line 8\u001B[0m\n\u001B[0;32m 6\u001B[0m model \u001B[38;5;241m=\u001B[39m Model()\n\u001B[0;32m 7\u001B[0m model\u001B[38;5;241m.\u001B[39madd(Input(input_neurons))\n\u001B[1;32m----> 8\u001B[0m \u001B[43mmodel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43madd\u001B[49m\u001B[43m(\u001B[49m\u001B[43mDense\u001B[49m\u001B[43m(\u001B[49m\u001B[43mhidden_neurons\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mweights_init\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mhe\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mrandom_state\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m42\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 9\u001B[0m model\u001B[38;5;241m.\u001B[39madd(Activation(ReLU()))\n\u001B[0;32m 11\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m _ \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(num_hidden_layers \u001B[38;5;241m-\u001B[39m \u001B[38;5;241m1\u001B[39m):\n", + "File \u001B[1;32m~\\Documents\\Programming\\Python\\Handmade NeuralNetwork\\neuralnetlib\\model.py:38\u001B[0m, in \u001B[0;36mModel.add\u001B[1;34m(self, layer)\u001B[0m\n\u001B[0;32m 36\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlayers:\n\u001B[0;32m 37\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(layer, Input):\n\u001B[1;32m---> 38\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mThe first layer must be an Input layer.\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 39\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m 40\u001B[0m previous_layer 
\u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlayers[\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m]\n", + "\u001B[1;31mValueError\u001B[0m: The first layer must be an Input layer." + ] + } + ], "source": [ "input_neurons = x_train.shape[1:][0] # Cancer dataset has 30 features\n", "num_hidden_layers = 5 # Number of hidden layers\n", @@ -132,40 +145,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:53.085516700Z", - "start_time": "2024-04-21T13:22:53.058950900Z" + "end_time": "2024-04-23T22:56:24.356216400Z", + "start_time": "2024-04-23T22:56:24.343207100Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model\n", - "-------------------------------------------------\n", - "Layer 1: Input(input_shape=(30,))\n", - "Layer 2: Dense(units=100)\n", - "Layer 3: Activation(ReLU)\n", - "Layer 4: Dense(units=100)\n", - "Layer 5: Activation(ReLU)\n", - "Layer 6: Dense(units=100)\n", - "Layer 7: Activation(ReLU)\n", - "Layer 8: Dense(units=100)\n", - "Layer 9: Activation(ReLU)\n", - "Layer 10: Dense(units=100)\n", - "Layer 11: Activation(ReLU)\n", - "Layer 12: Dense(units=1)\n", - "Layer 13: Activation(Sigmoid)\n", - "-------------------------------------------------\n", - "Loss function: BinaryCrossentropy\n", - "Optimizer: Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)\n", - "-------------------------------------------------\n" - ] - } - ], + "outputs": [], "source": [ "model.compile(loss_function=BinaryCrossentropy(), optimizer=Adam(learning_rate=0.0001))\n", "\n", @@ -181,43 +168,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:53.842873Z", - "start_time": "2024-04-21T13:22:53.081003300Z" + "start_time": "2024-04-23T22:56:24.345216800Z" } }, - "outputs": [ - { - "name": "stdout", - 
"output_type": "stream", - "text": [ - "[==============================] 100% Epoch 1/20 - loss: 0.6860 - accuracy_score: 0.6308 - 0.04s\n", - "[==============================] 100% Epoch 2/20 - loss: 0.6677 - accuracy_score: 0.7055 - 0.03s\n", - "[==============================] 100% Epoch 3/20 - loss: 0.6323 - accuracy_score: 0.8066 - 0.04s\n", - "[==============================] 100% Epoch 4/20 - loss: 0.5702 - accuracy_score: 0.8901 - 0.05s\n", - "[==============================] 100% Epoch 5/20 - loss: 0.4731 - accuracy_score: 0.9143 - 0.05s\n", - "[==============================] 100% Epoch 6/20 - loss: 0.3540 - accuracy_score: 0.9297 - 0.04s\n", - "[==============================] 100% Epoch 7/20 - loss: 0.2499 - accuracy_score: 0.9429 - 0.04s\n", - "[==============================] 100% Epoch 8/20 - loss: 0.1816 - accuracy_score: 0.9473 - 0.04s\n", - "[==============================] 100% Epoch 9/20 - loss: 0.1418 - accuracy_score: 0.9648 - 0.05s\n", - "[==============================] 100% Epoch 10/20 - loss: 0.1182 - accuracy_score: 0.9714 - 0.04s\n", - "[==============================] 100% Epoch 11/20 - loss: 0.1034 - accuracy_score: 0.9758 - 0.03s\n", - "[==============================] 100% Epoch 12/20 - loss: 0.0927 - accuracy_score: 0.9758 - 0.03s\n", - "[==============================] 100% Epoch 13/20 - loss: 0.0844 - accuracy_score: 0.9802 - 0.03s\n", - "[==============================] 100% Epoch 14/20 - loss: 0.0777 - accuracy_score: 0.9802 - 0.03s\n", - "[==============================] 100% Epoch 15/20 - loss: 0.0722 - accuracy_score: 0.9824 - 0.03s\n", - "[==============================] 100% Epoch 16/20 - loss: 0.0675 - accuracy_score: 0.9846 - 0.03s\n", - "[==============================] 100% Epoch 17/20 - loss: 0.0635 - accuracy_score: 0.9890 - 0.03s\n", - "[==============================] 100% Epoch 18/20 - loss: 0.0600 - accuracy_score: 0.9890 - 0.03s\n", - "[==============================] 100% Epoch 19/20 - loss: 0.0569 - 
accuracy_score: 0.9890 - 0.04s\n", - "[==============================] 100% Epoch 20/20 - loss: 0.0542 - accuracy_score: 0.9912 - 0.03s\n" - ] - } - ], + "outputs": [], "source": [ - "model.train(x_train, y_train, epochs=20, batch_size=48, metrics=[accuracy_score], random_state=42)" + "model.fit(x_train, y_train, epochs=20, batch_size=48, metrics=[accuracy_score], random_state=42)" ] }, { @@ -229,22 +188,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:53.857412200Z", - "start_time": "2024-04-21T13:22:53.843829400Z" + "start_time": "2024-04-23T22:56:24.347215800Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Test loss: 0.06351246680217817\n" - ] - } - ], + "outputs": [], "source": [ "loss = model.evaluate(x_test, y_test)\n", "print(f'Test loss: {loss}')" @@ -259,11 +209,10 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:53.863439400Z", - "start_time": "2024-04-21T13:22:53.852402800Z" + "start_time": "2024-04-23T22:56:24.348216100Z" } }, "outputs": [], @@ -280,25 +229,13 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:53.873465Z", - "start_time": "2024-04-21T13:22:53.859930800Z" + "start_time": "2024-04-23T22:56:24.350216700Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Accuracy: 0.9736842105263158\n", - "Precision: 0.9741062479117941\n", - "Recall: 0.9692460317460317\n", - "F1 Score: 0.9716700622635057\n" - ] - } - ], + "outputs": [], "source": [ "accuracy = accuracy_score(y_pred, y_test)\n", "precision = precision_score(y_pred, y_test)\n", diff --git a/examples/classification-regression/simple_diabete_regression.ipynb b/examples/classification-regression/simple_diabete_regression.ipynb index 
ae6e71d..64ea18e 100644 --- a/examples/classification-regression/simple_diabete_regression.ipynb +++ b/examples/classification-regression/simple_diabete_regression.ipynb @@ -18,16 +18,17 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { + "is_executing": true, "ExecuteTime": { - "end_time": "2024-04-21T13:22:31.884628500Z", - "start_time": "2024-04-21T13:22:30.866451700Z" + "start_time": "2024-04-24T00:35:13.345636700Z" } }, "outputs": [], "source": [ "from sklearn.datasets import load_diabetes\n", + "import numpy as np\n", "\n", "from neuralnetlib.preprocessing import MinMaxScaler, StandardScaler\n", "from neuralnetlib.activations import Linear, LeakyReLU\n", @@ -47,12 +48,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { - "ExecuteTime": { - "end_time": "2024-04-21T13:22:31.904684200Z", - "start_time": "2024-04-21T13:22:31.886631400Z" - } + "is_executing": true }, "outputs": [], "source": [ @@ -69,11 +67,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 12, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:31.915316200Z", - "start_time": "2024-04-21T13:22:31.905686700Z" + "end_time": "2024-04-24T00:35:13.897190200Z", + "start_time": "2024-04-24T00:35:13.895189700Z" } }, "outputs": [], @@ -95,11 +93,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 13, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:31.922833400Z", - "start_time": "2024-04-21T13:22:31.914315800Z" + "end_time": "2024-04-24T00:35:13.910701900Z", + "start_time": "2024-04-24T00:35:13.901704800Z" } }, "outputs": [], @@ -131,11 +129,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 14, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:31.965385600Z", - "start_time": "2024-04-21T13:22:31.924269100Z" + "end_time": "2024-04-24T00:35:13.950703300Z", + "start_time": "2024-04-24T00:35:13.911705Z" } }, 
"outputs": [ @@ -174,11 +172,11 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 15, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:32.074543200Z", - "start_time": "2024-04-21T13:22:31.939804400Z" + "end_time": "2024-04-24T00:35:14.078227700Z", + "start_time": "2024-04-24T00:35:13.921705400Z" } }, "outputs": [ @@ -186,7 +184,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[==============================] 100% Epoch 1/10 - loss: 1.2716 - - 0.02s\n", + "[==============================] 100% Epoch 1/10 - loss: 1.2716 - - 0.01s\n", "[==============================] 100% Epoch 2/10 - loss: 1.2699 - - 0.01s\n", "[==============================] 100% Epoch 3/10 - loss: 1.2680 - - 0.01s\n", "[==============================] 100% Epoch 4/10 - loss: 1.2659 - - 0.01s\n", @@ -197,10 +195,18 @@ "[==============================] 100% Epoch 9/10 - loss: 1.2531 - - 0.01s\n", "[==============================] 100% Epoch 10/10 - loss: 1.2501 - - 0.01s\n" ] + }, + { + "data": { + "text/plain": "" + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "model.train(x_train, y_train, epochs=10, batch_size=32, random_state=42)" + "model.fit(x_train, y_train, epochs=10, batch_size=32, random_state=42)" ] }, { @@ -212,11 +218,11 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 16, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:32.083579600Z", - "start_time": "2024-04-21T13:22:32.071027500Z" + "end_time": "2024-04-24T00:35:14.078227700Z", + "start_time": "2024-04-24T00:35:14.042229900Z" } }, "outputs": [ @@ -242,11 +248,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 17, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:32.095622600Z", - "start_time": "2024-04-21T13:22:32.083579600Z" + "end_time": "2024-04-24T00:35:14.148266800Z", + "start_time": "2024-04-24T00:35:14.047231200Z" } }, "outputs": [ @@ -267,23 
+273,49 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 8. Printing some metrics" + "## 8. Getting original MAE (without normalization from StandardScaler)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 18, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:22:32.136547300Z", - "start_time": "2024-04-21T13:22:32.095622600Z" + "end_time": "2024-04-24T00:35:14.153268100Z", + "start_time": "2024-04-24T00:35:14.059228300Z" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MAE (original): 66.43730326965814\n" + ] + } + ], "source": [ - "# 8. We don't print metrics such as accuracy or f1-score because this is a regression problem\n", - "# not a classification-regression one." + "y_pred_scaled = model.predict(x_test)\n", + "\n", + "y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()\n", + "y_test_original = scaler_y.inverse_transform(y_test.reshape(-1, 1)).flatten()\n", + "\n", + "mae_original = np.mean(np.abs(y_test_original - y_pred))\n", + "print(f'MAE (original): {mae_original}')" ] + }, + { + "cell_type": "code", + "execution_count": 18, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-24T00:35:14.153268100Z", + "start_time": "2024-04-24T00:35:14.069231300Z" + } + } } ], "metadata": { diff --git a/examples/classification-regression/simple_mnist_multiclass.ipynb b/examples/classification-regression/simple_mnist_multiclass.ipynb index a39c2cd..8cfb668 100644 --- a/examples/classification-regression/simple_mnist_multiclass.ipynb +++ b/examples/classification-regression/simple_mnist_multiclass.ipynb @@ -21,8 +21,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:09:57.920117100Z", - "start_time": "2024-04-21T13:09:53.090418500Z" + "end_time": "2024-04-23T23:29:14.420006500Z", + "start_time": "2024-04-23T23:29:10.910211Z" } }, "outputs": [], @@ 
-52,8 +52,8 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:09:58.104511Z", - "start_time": "2024-04-21T13:09:57.923629Z" + "end_time": "2024-04-23T23:29:14.609051100Z", + "start_time": "2024-04-23T23:29:14.415004400Z" } }, "outputs": [], @@ -73,8 +73,8 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:09:58.215354700Z", - "start_time": "2024-04-21T13:09:58.105511Z" + "end_time": "2024-04-23T23:29:14.698566500Z", + "start_time": "2024-04-23T23:29:14.594050100Z" } }, "outputs": [], @@ -97,8 +97,8 @@ "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:09:58.222377500Z", - "start_time": "2024-04-21T13:09:58.217869300Z" + "end_time": "2024-04-23T23:29:14.711565900Z", + "start_time": "2024-04-23T23:29:14.701566400Z" } }, "outputs": [], @@ -134,8 +134,8 @@ "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:09:58.255484400Z", - "start_time": "2024-04-21T13:09:58.223384700Z" + "end_time": "2024-04-23T23:29:14.741569600Z", + "start_time": "2024-04-23T23:29:14.705566800Z" } }, "outputs": [ @@ -177,8 +177,8 @@ "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:12:25.560796300Z", - "start_time": "2024-04-21T13:09:58.240940Z" + "end_time": "2024-04-23T23:30:26.688161400Z", + "start_time": "2024-04-23T23:29:14.734569600Z" } }, "outputs": [ @@ -186,31 +186,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "[==============================] 100% Epoch 1/20 - loss: 0.5726 - accuracy_score: 0.8099 - 9.12s\n", - "[==============================] 100% Epoch 2/20 - loss: 0.2319 - accuracy_score: 0.9333 - 8.09s\n", - "[==============================] 100% Epoch 3/20 - loss: 0.1948 - accuracy_score: 0.9432 - 7.56s\n", - "[==============================] 100% Epoch 4/20 - loss: 0.1726 - accuracy_score: 0.9502 - 7.38s\n", - "[==============================] 100% Epoch 5/20 - loss: 0.1587 - accuracy_score: 0.9530 - 
7.27s\n", - "[==============================] 100% Epoch 6/20 - loss: 0.1487 - accuracy_score: 0.9563 - 7.69s\n", - "[==============================] 100% Epoch 7/20 - loss: 0.1386 - accuracy_score: 0.9587 - 7.58s\n", - "[==============================] 100% Epoch 8/20 - loss: 0.1349 - accuracy_score: 0.9603 - 7.49s\n", - "[==============================] 100% Epoch 9/20 - loss: 0.1320 - accuracy_score: 0.9609 - 7.48s\n", - "[==============================] 100% Epoch 10/20 - loss: 0.1222 - accuracy_score: 0.9635 - 7.24s\n", - "[==============================] 100% Epoch 11/20 - loss: 0.1165 - accuracy_score: 0.9658 - 7.21s\n", - "[==============================] 100% Epoch 12/20 - loss: 0.1131 - accuracy_score: 0.9666 - 7.07s\n", - "[==============================] 100% Epoch 13/20 - loss: 0.1111 - accuracy_score: 0.9667 - 7.02s\n", - "[==============================] 100% Epoch 14/20 - loss: 0.1065 - accuracy_score: 0.9677 - 6.84s\n", - "[==============================] 100% Epoch 15/20 - loss: 0.1028 - accuracy_score: 0.9685 - 7.02s\n", - "[==============================] 100% Epoch 16/20 - loss: 0.1039 - accuracy_score: 0.9683 - 7.09s\n", - "[==============================] 100% Epoch 17/20 - loss: 0.1000 - accuracy_score: 0.9700 - 7.28s\n", - "[==============================] 100% Epoch 18/20 - loss: 0.0927 - accuracy_score: 0.9719 - 7.05s\n", - "[==============================] 100% Epoch 19/20 - loss: 0.0925 - accuracy_score: 0.9720 - 6.83s\n", - "[==============================] 100% Epoch 20/20 - loss: 0.0917 - accuracy_score: 0.9726 - 6.97s\n" + "[==============================] 100% Epoch 1/10 - loss: 0.5726 - accuracy_score: 0.8099 - 8.15s\n", + "[==============================] 100% Epoch 2/10 - loss: 0.2319 - accuracy_score: 0.9333 - 7.96s\n", + "[==============================] 100% Epoch 3/10 - loss: 0.1948 - accuracy_score: 0.9432 - 7.10s\n", + "[==============================] 100% Epoch 4/10 - loss: 0.1726 - accuracy_score: 0.9502 - 7.08s\n", + 
"[==============================] 100% Epoch 5/10 - loss: 0.1587 - accuracy_score: 0.9530 - 6.98s\n", + "[==============================] 100% Epoch 6/10 - loss: 0.1487 - accuracy_score: 0.9563 - 7.23s\n", + "[==============================] 100% Epoch 7/10 - loss: 0.1386 - accuracy_score: 0.9587 - 6.78s\n", + "[==============================] 100% Epoch 8/10 - loss: 0.1349 - accuracy_score: 0.9603 - 6.91s\n", + "[==============================] 100% Epoch 9/10 - loss: 0.1320 - accuracy_score: 0.9609 - 6.93s\n", + "[==============================] 100% Epoch 10/10 - loss: 0.1222 - accuracy_score: 0.9635 - 6.81s\n" ] } ], "source": [ - "model.train(x_train, y_train, epochs=20, batch_size=48, metrics=[accuracy_score], random_state=42)" + "model.fit(x_train, y_train, epochs=10, batch_size=48, metrics=[accuracy_score], random_state=42)" ] }, { @@ -225,8 +215,8 @@ "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:12:25.625272Z", - "start_time": "2024-04-21T13:12:25.570874800Z" + "end_time": "2024-04-23T23:30:26.752709800Z", + "start_time": "2024-04-23T23:30:26.683876200Z" } }, "outputs": [ @@ -234,7 +224,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Test loss: 0.1739562576224733\n" + "Test loss: 0.17413642094878234\n" ] } ], @@ -255,8 +245,8 @@ "execution_count": 8, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:12:25.669005800Z", - "start_time": "2024-04-21T13:12:25.619723700Z" + "end_time": "2024-04-23T23:30:26.808706400Z", + "start_time": "2024-04-23T23:30:26.747485900Z" } }, "outputs": [], @@ -276,8 +266,8 @@ "execution_count": 9, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:12:25.670014500Z", - "start_time": "2024-04-21T13:12:25.653477500Z" + "end_time": "2024-04-23T23:30:26.814711Z", + "start_time": "2024-04-23T23:30:26.781708700Z" } }, "outputs": [ @@ -285,9 +275,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "accuracy: 0.9568\n", - "f1_score: 0.9565405449722376\n", - 
"recall_score 0.9562654244701111\n" + "accuracy: 0.9549\n", + "f1_score: 0.9548478204173041\n", + "recall_score 0.9543130769611624\n" ] } ], @@ -309,17 +299,15 @@ "execution_count": 10, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:12:25.942550200Z", - "start_time": "2024-04-21T13:12:25.662547800Z" + "end_time": "2024-04-23T23:30:27.058485300Z", + "start_time": "2024-04-23T23:30:26.792708800Z" } }, "outputs": [ { "data": { - "image/png": "", - "text/plain": [ - "
" - ] + "text/plain": "
", + "image/png": "\n" }, "metadata": {}, "output_type": "display_data" @@ -346,8 +334,8 @@ "execution_count": 11, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:12:26.112578100Z", - "start_time": "2024-04-21T13:12:25.931021900Z" + "end_time": "2024-04-23T23:30:27.226005600Z", + "start_time": "2024-04-23T23:30:27.057483700Z" } }, "outputs": [], diff --git a/examples/cnn-classification/simple_cnn_classification_mnist.ipynb b/examples/cnn-classification/simple_cnn_classification_mnist.ipynb index 0ed6117..98385b0 100644 --- a/examples/cnn-classification/simple_cnn_classification_mnist.ipynb +++ b/examples/cnn-classification/simple_cnn_classification_mnist.ipynb @@ -21,8 +21,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:59:06.963987800Z", - "start_time": "2024-04-21T13:59:02.342450300Z" + "end_time": "2024-04-23T23:34:31.702298200Z", + "start_time": "2024-04-23T23:34:27.614924700Z" } }, "outputs": [], @@ -52,8 +52,8 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:59:07.142613500Z", - "start_time": "2024-04-21T13:59:06.964989400Z" + "end_time": "2024-04-23T23:34:31.891335500Z", + "start_time": "2024-04-23T23:34:31.704297500Z" } }, "outputs": [], @@ -73,8 +73,8 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:59:07.253448500Z", - "start_time": "2024-04-21T13:59:07.144680500Z" + "end_time": "2024-04-23T23:34:32.090375Z", + "start_time": "2024-04-23T23:34:31.893335500Z" } }, "outputs": [], @@ -97,8 +97,8 @@ "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:59:07.270516400Z", - "start_time": "2024-04-21T13:59:07.257449600Z" + "end_time": "2024-04-23T23:34:32.105372800Z", + "start_time": "2024-04-23T23:34:32.096373900Z" } }, "outputs": [ @@ -149,8 +149,8 @@ "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T13:59:07.322785100Z", - "start_time": "2024-04-21T13:59:07.267961400Z" + "end_time": 
"2024-04-23T23:34:32.193406400Z", + "start_time": "2024-04-23T23:34:32.104373900Z" } }, "outputs": [ @@ -197,8 +197,8 @@ "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T14:04:14.757698500Z", - "start_time": "2024-04-21T13:59:07.286701700Z" + "end_time": "2024-04-23T23:40:29.450560800Z", + "start_time": "2024-04-23T23:34:32.121373200Z" } }, "outputs": [ @@ -206,21 +206,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "[==============================] 100% Epoch 1/10 - loss: 0.7568 - accuracy_score: 0.7469 - 30.60s - val_accuracy: 0.8824\n", - "[==============================] 100% Epoch 2/10 - loss: 0.3492 - accuracy_score: 0.8896 - 27.70s - val_accuracy: 0.9090\n", - "[==============================] 100% Epoch 3/10 - loss: 0.2760 - accuracy_score: 0.9131 - 27.41s - val_accuracy: 0.9248\n", - "[==============================] 100% Epoch 4/10 - loss: 0.2290 - accuracy_score: 0.9287 - 28.12s - val_accuracy: 0.9306\n", - "[==============================] 100% Epoch 5/10 - loss: 0.1984 - accuracy_score: 0.9385 - 29.91s - val_accuracy: 0.9359\n", - "[==============================] 100% Epoch 6/10 - loss: 0.1761 - accuracy_score: 0.9453 - 29.25s - val_accuracy: 0.9403\n", - "[==============================] 100% Epoch 7/10 - loss: 0.1575 - accuracy_score: 0.9516 - 31.05s - val_accuracy: 0.9446\n", - "[==============================] 100% Epoch 8/10 - loss: 0.1420 - accuracy_score: 0.9566 - 29.53s - val_accuracy: 0.9495\n", - "[==============================] 100% Epoch 9/10 - loss: 0.1281 - accuracy_score: 0.9613 - 27.19s - val_accuracy: 0.9552\n", - "[==============================] 100% Epoch 10/10 - loss: 0.1161 - accuracy_score: 0.9649 - 28.16s - val_accuracy: 0.9597\n" + "[==============================] 100% Epoch 1/10 - loss: 0.7568 - accuracy_score: 0.7469 - 32.51s - val_accuracy: 0.8824\n", + "[==============================] 100% Epoch 2/10 - loss: 0.3492 - accuracy_score: 0.8896 - 35.16s - val_accuracy: 0.9090\n", + 
"[==============================] 100% Epoch 3/10 - loss: 0.2760 - accuracy_score: 0.9131 - 31.36s - val_accuracy: 0.9248\n", + "[==============================] 100% Epoch 4/10 - loss: 0.2290 - accuracy_score: 0.9287 - 30.44s - val_accuracy: 0.9306\n", + "[==============================] 100% Epoch 5/10 - loss: 0.1984 - accuracy_score: 0.9385 - 38.86s - val_accuracy: 0.9359\n", + "[==============================] 100% Epoch 6/10 - loss: 0.1761 - accuracy_score: 0.9453 - 54.44s - val_accuracy: 0.9403\n", + "[==============================] 100% Epoch 7/10 - loss: 0.1575 - accuracy_score: 0.9516 - 32.83s - val_accuracy: 0.9446\n", + "[==============================] 100% Epoch 8/10 - loss: 0.1420 - accuracy_score: 0.9566 - 26.88s - val_accuracy: 0.9495\n", + "[==============================] 100% Epoch 9/10 - loss: 0.1281 - accuracy_score: 0.9613 - 26.80s - val_accuracy: 0.9552\n", + "[==============================] 100% Epoch 10/10 - loss: 0.1161 - accuracy_score: 0.9649 - 27.73s - val_accuracy: 0.9597\n" ] } ], "source": [ - "model.train(x_train, y_train, epochs=10, batch_size=128, metrics=[\n", + "model.fit(x_train, y_train, epochs=10, batch_size=128, metrics=[\n", " accuracy_score], random_state=42, validation_data=(x_test, y_test))" ] }, @@ -236,8 +236,8 @@ "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T14:04:16.312526400Z", - "start_time": "2024-04-21T14:04:14.753676500Z" + "end_time": "2024-04-23T23:40:30.972697200Z", + "start_time": "2024-04-23T23:40:29.470562300Z" } }, "outputs": [ @@ -266,8 +266,8 @@ "execution_count": 8, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T14:04:18.043740600Z", - "start_time": "2024-04-21T14:04:16.310956700Z" + "end_time": "2024-04-23T23:40:32.548225200Z", + "start_time": "2024-04-23T23:40:30.978698900Z" } }, "outputs": [], @@ -287,8 +287,8 @@ "execution_count": 9, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T14:04:18.054771900Z", - "start_time": 
"2024-04-21T14:04:18.046739300Z" + "end_time": "2024-04-23T23:40:32.579219700Z", + "start_time": "2024-04-23T23:40:32.550236700Z" } }, "outputs": [ @@ -320,8 +320,8 @@ "execution_count": 10, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T14:04:18.315666800Z", - "start_time": "2024-04-21T14:04:18.058285200Z" + "end_time": "2024-04-23T23:40:33.035884200Z", + "start_time": "2024-04-23T23:40:32.586224100Z" } }, "outputs": [ @@ -355,8 +355,8 @@ "execution_count": 11, "metadata": { "ExecuteTime": { - "end_time": "2024-04-21T14:04:18.477652200Z", - "start_time": "2024-04-21T14:04:18.317667200Z" + "end_time": "2024-04-23T23:40:33.205916600Z", + "start_time": "2024-04-23T23:40:33.039881500Z" } }, "outputs": [], diff --git a/examples/real-life-applications/tic_tac_toe_alternative_dataset_shape.py b/examples/real-life-applications/tic_tac_toe_alternative_dataset_shape.py index 6bf5da5..623e475 100644 --- a/examples/real-life-applications/tic_tac_toe_alternative_dataset_shape.py +++ b/examples/real-life-applications/tic_tac_toe_alternative_dataset_shape.py @@ -72,7 +72,7 @@ def main(): model.compile(loss_function=BinaryCrossentropy(), optimizer=Adam(learning_rate=0.001)) # 7. Train the model - model.train(x_train, y_train, epochs=500, batch_size=32, metrics=[accuracy_score], random_state=42) + model.fit(x_train, y_train, epochs=500, batch_size=32, metrics=[accuracy_score], random_state=42) # 8. 
Evaluate the model loss = model.evaluate(x_test, y_test) diff --git a/neuralnetlib/layers.py b/neuralnetlib/layers.py index 5ff03a5..4888abc 100644 --- a/neuralnetlib/layers.py +++ b/neuralnetlib/layers.py @@ -120,6 +120,7 @@ def __str__(self): def forward_pass(self, input_data: np.ndarray) -> np.ndarray: if self.weights is None: + assert len(input_data.shape) == 2, f"Dense input must be 2D (batch_size, features), got {input_data.shape}" self.initialize_weights(input_data.shape[1]) self.input = input_data @@ -265,6 +266,7 @@ def __str__(self): def forward_pass(self, input_data: np.ndarray) -> np.ndarray: if self.weights is None: + assert len(input_data.shape) == 4, f"Conv2D input must be 4D (batch_size, channels, height, width), got {input_data.shape}" self.initialize_weights(input_data.shape[1:]) self.input = input_data @@ -363,6 +365,7 @@ def __str__(self): return f'MaxPooling2D(pool_size={self.pool_size}, stride={self.stride}, padding={self.padding})' def forward_pass(self, input_data: np.ndarray) -> np.ndarray: + assert len(input_data.shape) == 4, f"MaxPooling2D input must be 4D (batch_size, channels, height, width), got {input_data.shape}" self.input = input_data output = self._pool(self.input, self.pool_size, self.stride, self.padding) return output @@ -446,6 +449,7 @@ def __str__(self): return 'Flatten' def forward_pass(self, input_data: np.ndarray) -> np.ndarray: + assert len(input_data.shape) >= 2, f"Flatten input must be at least 2D, got {input_data.shape}" self.input_shape = input_data.shape return input_data.reshape(input_data.shape[0], -1) @@ -508,6 +512,7 @@ def __str__(self): def forward_pass(self, input_data: np.ndarray) -> np.ndarray: if self.weights is None: + assert len(input_data.shape) == 3, f"Conv1D input must be 3D (batch_size, steps, features), got {input_data.shape}" self.initialize_weights(input_data.shape[1:]) self.input = input_data @@ -600,6 +605,7 @@ def __str__(self): return f'MaxPooling1D(pool_size={self.pool_size}, 
stride={self.stride}, padding={self.padding})' def forward_pass(self, input_data: np.ndarray) -> np.ndarray: + assert len(input_data.shape) == 3, f"MaxPooling1D input must be 3D (batch_size, steps, features), got {input_data.shape}" self.input = input_data output = self._pool(self.input, self.pool_size, self.stride, self.padding) return output @@ -692,6 +698,7 @@ def __str__(self): def forward_pass(self, input_data: np.ndarray) -> np.ndarray: if self.weights is None: + assert len(input_data.shape) == 2, f"Embedding input must be 2D (batch_size, sequence_length), got {input_data.shape}" self.initialize_weights() self.input = input_data diff --git a/neuralnetlib/model.py b/neuralnetlib/model.py index f5f6a67..2178926 100644 --- a/neuralnetlib/model.py +++ b/neuralnetlib/model.py @@ -3,7 +3,7 @@ import numpy as np -from neuralnetlib.layers import Layer, Input, Activation, Dense, Flatten, Conv2D, Dropout, Conv1D, Embedding +from neuralnetlib.layers import Layer, Input, Dense, Activation, Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, Flatten, Dropout, Embedding from neuralnetlib.losses import LossFunction, CategoricalCrossentropy from neuralnetlib.optimizers import Optimizer from neuralnetlib.utils import shuffle, progress_bar @@ -33,16 +33,43 @@ def summary(self): print(str(self)) def add(self, layer: Layer): - if self.layers and len(self.layers) != 0 and not isinstance(self.layers[-1], Input) and isinstance(layer, (Dense, Conv1D, Embedding)): - prev_layer = [l for l in self.layers if isinstance(l, (Input, Dense, Conv2D, Conv1D, Flatten, Embedding))][-1] - if isinstance(prev_layer, Flatten): - prev_layer = [l for l in self.layers if isinstance(l, (Dense, Conv2D, Conv1D))][-1] - if hasattr(prev_layer, 'output_size') and prev_layer.output_size != layer.input_size: - raise ValueError( - f'Layer input size {layer.input_size} does not match previous layer output size {prev_layer.output_size}.') - elif self.layers and isinstance(layer, Dropout): - if 
isinstance(self.layers[-1], Dropout): - raise ValueError("Cannot add consecutive Dropout layers.") + if not self.layers: + if not isinstance(layer, Input): + raise ValueError("The first layer must be an Input layer.") + else: + previous_layer = self.layers[-1] + + if isinstance(previous_layer, Input): + if not isinstance(layer, (Dense, Conv2D, Conv1D, Embedding)): + raise ValueError("Input layer can only be followed by Dense, Conv2D, Conv1D, or Embedding.") + elif isinstance(previous_layer, Dense): + if not isinstance(layer, (Dense, Activation, Dropout)): + raise ValueError("Dense layer can only be followed by Dense, Activation, or Dropout.") + elif isinstance(previous_layer, Activation): + if not isinstance(layer, (Dense, Conv2D, Conv1D, MaxPooling2D, MaxPooling1D, Flatten, Dropout)): + raise ValueError("Activation layer can only be followed by Dense, Conv2D, Conv1D, MaxPooling2D, MaxPooling1D, Flatten, or Dropout.") + elif isinstance(previous_layer, Conv2D): + if not isinstance(layer, (Conv2D, MaxPooling2D, Activation, Dropout, Flatten)): + raise ValueError("Conv2D layer can only be followed by Conv2D, MaxPooling2D, Activation, Dropout, or Flatten.") + elif isinstance(previous_layer, MaxPooling2D): + if not isinstance(layer, (Conv2D, MaxPooling2D, Flatten)): + raise ValueError("MaxPooling2D layer can only be followed by Conv2D, MaxPooling2D, or Flatten.") + elif isinstance(previous_layer, Conv1D): + if not isinstance(layer, (Conv1D, MaxPooling1D, Activation, Dropout, Flatten)): + raise ValueError("Conv1D layer can only be followed by Conv1D, MaxPooling1D, Activation, Dropout, or Flatten.") + elif isinstance(previous_layer, MaxPooling1D): + if not isinstance(layer, (Conv1D, MaxPooling1D, Flatten)): + raise ValueError("MaxPooling1D layer can only be followed by Conv1D, MaxPooling1D, or Flatten.") + elif isinstance(previous_layer, Flatten): + if not isinstance(layer, (Dense, Dropout)): + raise ValueError("Flatten layer can only be followed by Dense or Dropout.") + 
elif isinstance(previous_layer, Dropout): + if not isinstance(layer, (Dense, Conv2D, Conv1D, Activation)): + raise ValueError("Dropout layer can only be followed by Dense, Conv2D, Conv1D, or Activation.") + elif isinstance(previous_layer, Embedding): + if not isinstance(layer, (Conv1D, Flatten, Dense)): + raise ValueError("Embedding layer can only be followed by Conv1D, Flatten, or Dense.") + self.layers.append(layer) def compile(self, loss_function: LossFunction, optimizer: Optimizer, verbose: bool = False): @@ -88,8 +115,21 @@ def train_on_batch(self, x_batch: np.ndarray, y_batch: np.ndarray) -> float: self.backward_pass(error) return loss - def train(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size: int = None, - verbose: bool = True, metrics: list = None, random_state: int = None, validation_data: tuple = None): + def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size: int = None, + verbose: bool = True, metrics: list = None, random_state: int = None, validation_data: tuple = None): + """ + Fit the model to the training data. 
+ + Args: + x_train: Training data + y_train: Training labels + epochs: Number of epochs to train the model + batch_size: Number of samples per gradient update + verbose: Whether to print training progress + metrics: List of metrics to evaluate the model (functions from neuralnetlib.metrics module) + random_state: Random seed for shuffling the data + validation_data: Tuple of validation data and labels + """ for i in range(epochs): start_time = time.time() diff --git a/tests/test_model.py b/tests/test_model.py index fe04027..7c519ae 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -24,11 +24,11 @@ def setUp(self): self.y_test = rng.random((10, 20)) def test_model_train_on_batch(self): loss = self.model.train_on_batch(self.x_train[:10], self.y_train[:10]) self.assertIsInstance(loss, float) def test_model_train(self): - self.model.train(self.x_train, self.y_train, epochs=1, batch_size=10, verbose=False) + self.model.fit(self.x_train, self.y_train, epochs=1, batch_size=10, verbose=False) def test_model_evaluate(self): loss = self.model.evaluate(self.x_test, self.y_test)