
Inconsistent validation data handling in Keras 3 for Language Model fine-tuning #20748

Open

che-shr-cat opened this issue Jan 10, 2025 · 1 comment

Issue Description

When fine-tuning language models in Keras 3, there are inconsistencies in how validation data should be provided. The documentation suggests validation_data should be in (x, y) format, but the actual requirements are unclear and the behavior differs between training and validation phases.

Current Behavior & Problems

Issue 1: Raw text arrays are not accepted for validation

train_texts = ["text1", "text2", ...]
val_texts = ["val1", "val2", ...]

# This fails with ValueError:
model.fit(
    train_texts,
    validation_data=val_texts
)

# Error:
ValueError: Data is expected to be in format `x`, `(x,)`, `(x, y)`, or `(x, y, sample_weight)`, found: ("text1", "text2", ...)

Issue 2: Pre-tokenized validation fails

# Trying to provide tokenized data
# (tokenizer and pad_sequence here are user-defined helpers):
import numpy as np

val_tokenized = [tokenizer(text) for text in val_texts]
val_padded = np.array([pad_sequence(seq, max_len) for seq in val_tokenized])
val_input = val_padded[:, :-1]
val_target = val_padded[:, 1:]

model.fit(
    train_texts,
    validation_data=(val_input, val_target)
)

# Error:
TypeError: Input 'input' of 'SentencepieceTokenizeOp' Op has type int64 that does not match expected type of string.

The error suggests the tokenizer is being applied again to already-tokenized data. I understand there is the preprocessor=None parameter, but I don't want to preprocess the training data manually.
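
For context, my understanding is that the preprocessor=None route would look roughly like the sketch below. This is untested; the class names follow the KerasHub/KerasNLP guides, and sequence_length=128 is an arbitrary choice:

import tensorflow as tf
import keras_hub

# Detach preprocessing from the model and apply the same pipeline to both
# splits, so training and validation data are handled identically.
preprocessor = keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
    "gemma_2b_en", sequence_length=128
)
model = keras_hub.models.GemmaCausalLM.from_preset(
    "gemma_2b_en", preprocessor=None
)

train_ds = (
    tf.data.Dataset.from_tensor_slices(train_texts)
    .batch(2)
    .map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
)
val_ds = (
    tf.data.Dataset.from_tensor_slices(val_texts)
    .batch(2)
    .map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
)

model.fit(train_ds, validation_data=val_ds)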

Working Solution (But Needs Documentation)

The working approach is to provide prompt-completion pairs:

# Prepare validation data as prompts and expected outputs
val_inputs = [format_prompt(text) for text in val_input_texts]
val_outputs = [format_output(text) for text in val_output_texts]
val_inputs = np.array(val_inputs)
val_outputs = np.array(val_outputs)

model.fit(
    train_texts,
    validation_data=(val_inputs, val_outputs)
)
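
For concreteness, the formatting helpers above might look something like the following. These are purely illustrative placeholders; the actual template depends on the task:

def format_prompt(text):
    # Hypothetical instruction-style template; adjust to your task.
    return f"Instruction:\n{text}\n\nResponse:\n"

def format_output(text):
    # Here the expected completion is passed through unchanged.
    return text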

Expected Behavior

  1. The documentation should clearly state that validation data for language models should be provided as prompt-completion pairs
  2. The validation data handling should be consistent with how training data is processed
  3. It should be clear whether token shifting is handled internally or needs to be done manually (see the toy sketch below)
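
My understanding, which the docs should confirm, is that the causal LM preprocessor performs the one-token shift internally, roughly like this toy sketch (the token IDs are made up):

import numpy as np

# A hypothetical tokenized batch; real IDs would come from the tokenizer.
token_ids = np.array([[2, 15, 37, 42, 9, 1]])

x = token_ids[:, :-1]  # model input:   [2, 15, 37, 42, 9]
y = token_ids[:, 1:]   # target labels: [15, 37, 42, 9, 1]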

Environment

  • Keras Version: 3.x
  • Python Version: 3.10
  • Model: Gemma LLM (but likely affects other LLMs too)

Additional Context

While there is a working solution using prompt-completion pairs, this differs from traditional language model training where each token predicts the next token. The documentation should clarify this architectural choice and explain the proper way to provide validation data.

mehtamansi29 (Collaborator) commented Jan 15, 2025

Hi @che-shr-cat -

Thanks for reporting the issue. To fit the model with raw texts, you can wrap them in a tf.data.Dataset as shown below; this runs without error. A gist with the same example is attached here.

import tensorflow as tf
import keras_hub

train_texts = ["text1", "text2"]
val_texts = ["val1", "val2"]

# Wrap the raw strings in batched tf.data.Datasets; the model's attached
# preprocessor then tokenizes both splits in the same way.
train_dataset = tf.data.Dataset.from_tensor_slices(train_texts).batch(2)
val_dataset = tf.data.Dataset.from_tensor_slices(val_texts).batch(2)

model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
model.fit(train_dataset, validation_data=val_dataset)

As for the pre-tokenized validation data: with the gemma_2b_en model this will raise an error, because per the Gemma 2 model card the expected input is a text string only.

Please let me know if you are using a different model and still facing the error.
