Adapt masked_language_modeling.py script to be backend-agnostic (#2054)
* adapting the script masked_language_modeling.py

* refactoring the script

* refactoring continues

* improved_implementation

* improved_implementation

* addressing last comments

* removing warnings
Humbulani1234 authored Mar 4, 2025
1 parent ba5b116 commit 2676192
Showing 3 changed files with 1,248 additions and 260 deletions.
98 changes: 53 additions & 45 deletions examples/nlp/ipynb/masked_language_modeling.ipynb
@@ -62,20 +62,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"os.environ[\"KERAS_BACKEND\"] = \"torch\" # or jax, or tensorflow\n",
"\n",
"import keras_hub\n",
"\n",
"import keras\n",
"import tensorflow as tf\n",
"from keras import layers\n",
"from keras.layers import TextVectorization\n",
"\n",
"from dataclasses import dataclass\n",
"import pandas as pd\n",
"import numpy as np\n",
@@ -95,7 +97,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -130,7 +132,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -142,7 +144,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -159,7 +161,6 @@
"\n",
"\n",
"def get_data_from_text_files(folder_name):\n",
"\n",
" pos_files = glob.glob(\"aclImdb/\" + folder_name + \"/pos/*.txt\")\n",
" pos_texts = get_text_list_from_files(pos_files)\n",
" neg_files = glob.glob(\"aclImdb/\" + folder_name + \"/neg/*.txt\")\n",
@@ -177,7 +178,7 @@
"train_df = get_data_from_text_files(\"train\")\n",
"test_df = get_data_from_text_files(\"test\")\n",
"\n",
"all_data = train_df.append(test_df)"
"all_data = pd.concat([train_df, test_df], ignore_index=True)"
]
},
{
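The `pd.concat` change above is not just style: `DataFrame.append` was deprecated in pandas 1.4 and removed in pandas 2.0. A small sketch of the replacement on toy frames (column names mirror the IMDB frames built here):

import pandas as pd

train_df = pd.DataFrame({"review": ["good movie"], "sentiment": [1]})
test_df = pd.DataFrame({"review": ["dull movie"], "sentiment": [0]})

# ignore_index=True renumbers rows 0..n-1 instead of keeping both original indices
all_data = pd.concat([train_df, test_df], ignore_index=True)
print(len(all_data))  # 2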
@@ -203,12 +204,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"# For data pre-processing and tf.data.Dataset\n",
"import tensorflow as tf\n",
"\n",
"\n",
"def custom_standardization(input_data):\n",
" lowercase = tf.strings.lower(input_data)\n",
@@ -276,9 +280,9 @@
" # Set input to [MASK] which is the last token for the 90% of tokens\n",
" # This means leaving 10% unchanged\n",
" inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)\n",
" encoded_texts_masked[\n",
" inp_mask_2mask\n",
" ] = mask_token_id # mask token is the last in the dict\n",
" encoded_texts_masked[inp_mask_2mask] = (\n",
" mask_token_id # mask token is the last in the dict\n",
" )\n",
"\n",
" # Set 10% to a random token\n",
" inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)\n",
@@ -312,10 +316,8 @@
" config.BATCH_SIZE\n",
")\n",
"\n",
"# Build dataset for end to end model input (will be used at the end)\n",
"test_raw_classifier_ds = tf.data.Dataset.from_tensor_slices(\n",
" (test_df.review.values, y_test)\n",
").batch(config.BATCH_SIZE)\n",
"# Dataset for end to end model input (will be used at the end)\n",
"test_raw_classifier_ds = test_df\n",
"\n",
"# Prepare data for masked language model\n",
"x_all_review = encode(all_data.review.values)\n",
@@ -345,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -389,26 +391,14 @@
"\n",
"\n",
"class MaskedLanguageModel(keras.Model):\n",
" def train_step(self, inputs):\n",
" if len(inputs) == 3:\n",
" features, labels, sample_weight = inputs\n",
" else:\n",
" features, labels = inputs\n",
" sample_weight = None\n",
"\n",
" with tf.GradientTape() as tape:\n",
" predictions = self(features, training=True)\n",
" loss = loss_fn(labels, predictions, sample_weight=sample_weight)\n",
" def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):\n",
"\n",
" # Compute gradients\n",
" trainable_vars = self.trainable_variables\n",
" gradients = tape.gradient(loss, trainable_vars)\n",
"\n",
" # Update weights\n",
" self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
"\n",
" # Compute our own metrics\n",
" loss = loss_fn(y, y_pred, sample_weight)\n",
" loss_tracker.update_state(loss, sample_weight=sample_weight)\n",
" return keras.ops.sum(loss)\n",
"\n",
" def compute_metrics(self, x, y, y_pred, sample_weight):\n",
"\n",
" # Return a dict mapping metric names to current value\n",
" return {\"loss\": loss_tracker.result()}\n",
@@ -505,7 +495,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -530,7 +520,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -596,24 +586,41 @@
"When you want to deploy a model, it's best if it already includes its preprocessing\n",
"pipeline, so that you don't have to reimplement the preprocessing logic in your\n",
"production environment. Let's create an end-to-end model that incorporates\n",
"the `TextVectorization` layer, and let's evaluate. Our model will accept raw strings\n",
"as input."
"the `TextVectorization` layer inside evalaute method, and let's evaluate. We will pass raw strings as input."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"# We create a custom Model to override the evaluate method so\n",
"# that it first pre-process text data\n",
"class ModelEndtoEnd(keras.Model):\n",
"\n",
" def evaluate(self, inputs):\n",
" features = encode(inputs.review.values)\n",
" labels = inputs.sentiment.values\n",
" test_classifier_ds = (\n",
" tf.data.Dataset.from_tensor_slices((features, labels))\n",
" .shuffle(1000)\n",
" .batch(config.BATCH_SIZE)\n",
" )\n",
" return super().evaluate(test_classifier_ds)\n",
"\n",
" # Build the model\n",
" def build(self, input_shape):\n",
" self.built = True\n",
"\n",
"\n",
"def get_end_to_end(model):\n",
" inputs_string = keras.Input(shape=(1,), dtype=\"string\")\n",
" indices = vectorize_layer(inputs_string)\n",
" outputs = model(indices)\n",
" end_to_end_model = keras.Model(inputs_string, outputs, name=\"end_to_end_model\")\n",
" inputs = classifer_model.inputs[0]\n",
" outputs = classifer_model.outputs\n",
" end_to_end_model = ModelEndtoEnd(inputs, outputs, name=\"end_to_end_model\")\n",
" optimizer = keras.optimizers.Adam(learning_rate=config.LR)\n",
" end_to_end_model.compile(\n",
" optimizer=optimizer, loss=\"binary_crossentropy\", metrics=[\"accuracy\"]\n",
@@ -622,6 +629,7 @@
"\n",
"\n",
"end_to_end_classification_model = get_end_to_end(classifer_model)\n",
"# Pass raw text dataframe to the model\n",
"end_to_end_classification_model.evaluate(test_raw_classifier_ds)"
]
}
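The design choice above is worth spelling out: `keras.Input(shape=(1,), dtype="string")`, which the old version used, only works on the TensorFlow backend, so the preprocessing moves out of the graph and into the overridden `evaluate`. A usage sketch, assuming `get_end_to_end`, `classifer_model`, and `test_raw_classifier_ds` from the surrounding cells:

end_to_end = get_end_to_end(classifer_model)

# test_raw_classifier_ds is the raw test DataFrame; evaluate() encodes and batches it
loss, accuracy = end_to_end.evaluate(test_raw_classifier_ds)
print(f"test accuracy: {accuracy:.3f}")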
@@ -630,7 +638,7 @@
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "mlm_and_finetune_with_bert",
"name": "masked_language_modeling",
"private_outputs": false,
"provenance": [],
"toc_visible": true
@@ -655,4 +663,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
71 changes: 40 additions & 31 deletions examples/nlp/masked_language_modeling.py
@@ -5,7 +5,7 @@
Last modified: 2024/03/15
Description: Implement a Masked Language Model (MLM) with BERT and fine-tune it on the IMDB Reviews dataset.
Accelerator: GPU
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
"""

"""
@@ -46,12 +46,14 @@

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["KERAS_BACKEND"] = "torch" # or jax, or tensorflow

import keras_hub

import keras
import tensorflow as tf
from keras import layers
from keras.layers import TextVectorization

from dataclasses import dataclass
import pandas as pd
import numpy as np
@@ -117,7 +119,7 @@ def get_data_from_text_files(folder_name):
train_df = get_data_from_text_files("train")
test_df = get_data_from_text_files("test")

all_data = train_df.append(test_df)
all_data = pd.concat([train_df, test_df], ignore_index=True)

"""
## Dataset preparation
@@ -135,6 +137,9 @@ def get_data_from_text_files(folder_name):
It masks 15% of all input tokens in each sequence at random.
"""

# For data pre-processing and tf.data.Dataset
import tensorflow as tf


def custom_standardization(input_data):
lowercase = tf.strings.lower(input_data)
@@ -238,10 +243,8 @@ def get_masked_input_and_labels(encoded_texts):
config.BATCH_SIZE
)

# Build dataset for end to end model input (will be used at the end)
test_raw_classifier_ds = tf.data.Dataset.from_tensor_slices(
(test_df.review.values, y_test)
).batch(config.BATCH_SIZE)
# Dataset for end to end model input (will be used at the end)
test_raw_classifier_ds = test_df

# Prepare data for masked language model
x_all_review = encode(all_data.review.values)
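Note that the script keeps `tf.data` for its input pipeline even though the backend may be torch or jax: in Keras 3, `fit`, `evaluate`, and `predict` accept `tf.data.Dataset` objects under any backend, so TensorFlow here is used purely for preprocessing. A standalone sketch of that combination:

import os

os.environ["KERAS_BACKEND"] = "torch"  # works with jax or tensorflow too

import keras
import numpy as np
import tensorflow as tf  # only used to build the input pipeline

features = np.random.rand(32, 4).astype("float32")
labels = np.random.rand(32, 1).astype("float32")
ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
model.fit(ds, epochs=1, verbose=0)  # tf.data feeds a torch-backend model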
@@ -301,26 +304,14 @@ def bert_module(query, key, value, i):


class MaskedLanguageModel(keras.Model):
def train_step(self, inputs):
if len(inputs) == 3:
features, labels, sample_weight = inputs
else:
features, labels = inputs
sample_weight = None

with tf.GradientTape() as tape:
predictions = self(features, training=True)
loss = loss_fn(labels, predictions, sample_weight=sample_weight)

# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):

# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))

# Compute our own metrics
loss = loss_fn(y, y_pred, sample_weight)
loss_tracker.update_state(loss, sample_weight=sample_weight)
return keras.ops.sum(loss)

def compute_metrics(self, x, y, y_pred, sample_weight):

# Return a dict mapping metric names to current value
return {"loss": loss_tracker.result()}
@@ -475,16 +466,33 @@ def create_classifier_bert_model():
When you want to deploy a model, it's best if it already includes its preprocessing
pipeline, so that you don't have to reimplement the preprocessing logic in your
production environment. Let's create an end-to-end model that incorporates
the `TextVectorization` layer, and let's evaluate. Our model will accept raw strings
as input.
the `TextVectorization` layer inside its `evaluate` method, and evaluate it. We will pass raw strings as input.
"""


# We create a custom Model to override the evaluate method so
# that it first pre-processes the text data
class ModelEndtoEnd(keras.Model):

def evaluate(self, inputs):
features = encode(inputs.review.values)
labels = inputs.sentiment.values
test_classifier_ds = (
tf.data.Dataset.from_tensor_slices((features, labels))
.shuffle(1000)
.batch(config.BATCH_SIZE)
)
return super().evaluate(test_classifier_ds)

# Build the model
def build(self, input_shape):
self.built = True


def get_end_to_end(model):
inputs_string = keras.Input(shape=(1,), dtype="string")
indices = vectorize_layer(inputs_string)
outputs = model(indices)
end_to_end_model = keras.Model(inputs_string, outputs, name="end_to_end_model")
    inputs = model.inputs[0]
    outputs = model.outputs
end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
optimizer = keras.optimizers.Adam(learning_rate=config.LR)
end_to_end_model.compile(
optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
@@ -493,4 +501,5 @@ def get_end_to_end(model):


end_to_end_classification_model = get_end_to_end(classifer_model)
# Pass raw text dataframe to the model
end_to_end_classification_model.evaluate(test_raw_classifier_ds)
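The wrapper only overrides `evaluate`, so scoring a brand-new review still goes through the same `encode` helper. A hypothetical sketch (the review string is invented, and it assumes the classifier ends in a single sigmoid unit, as the `binary_crossentropy` loss suggests):

# Hypothetical single-review inference, reusing the encode() helper defined above
sample = pd.DataFrame({"review": ["A surprisingly moving film."]})
tokens = encode(sample.review.values)
probs = classifer_model.predict(tokens)
print("positive probability:", float(probs[0][0]))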
