From ba9473dd6a38f9440d0d03d59877d9e1afeee514 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20Can=20Durak?= <36421093+bcdurak@users.noreply.github.com> Date: Wed, 27 Nov 2024 14:32:23 +0100 Subject: [PATCH] updating the quickstart (#3188) --- examples/quickstart/steps/model_evaluator.py | 13 ++++++++++--- examples/quickstart/steps/model_tester.py | 11 +++++++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/examples/quickstart/steps/model_evaluator.py b/examples/quickstart/steps/model_evaluator.py index f66cd0b4dd3..fc8dac00132 100644 --- a/examples/quickstart/steps/model_evaluator.py +++ b/examples/quickstart/steps/model_evaluator.py @@ -13,14 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + import torch from datasets import Dataset from transformers import ( T5ForConditionalGeneration, ) -from zenml import log_model_metadata, step +from zenml import get_step_context, log_metadata, step from zenml.logger import get_logger logger = get_logger(__name__) @@ -50,4 +50,11 @@ def evaluate_model( avg_loss = total_loss / num_batches print(f"Average loss on the dataset: {avg_loss}") - log_model_metadata({"Average Loss": avg_loss}) + step_context = get_step_context() + + if step_context.model: + log_metadata( + metadata={"Average Loss": avg_loss}, + model_name=step_context.model.name, + model_version=step_context.model.version, + ) diff --git a/examples/quickstart/steps/model_tester.py b/examples/quickstart/steps/model_tester.py index d271601bd04..72d68ed7d57 100644 --- a/examples/quickstart/steps/model_tester.py +++ b/examples/quickstart/steps/model_tester.py @@ -21,7 +21,7 @@ T5TokenizerFast, ) -from zenml import log_model_metadata, step +from zenml import get_step_context, log_metadata, step from zenml.logger import get_logger from .data_loader import PROMPT @@ -70,4 +70,11 @@ def test_model( sentence_without_prompt: decoded_output } - log_model_metadata({"Example Prompts": test_collection}) + step_context = get_step_context() + + if step_context.model: + log_metadata( + metadata={"Example Prompts": test_collection}, + model_name=step_context.model.name, + model_version=step_context.model.version, + )