diff --git a/examples/quickstart/steps/model_evaluator.py b/examples/quickstart/steps/model_evaluator.py index f66cd0b4dd3..fc8dac00132 100644 --- a/examples/quickstart/steps/model_evaluator.py +++ b/examples/quickstart/steps/model_evaluator.py @@ -13,14 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + import torch from datasets import Dataset from transformers import ( T5ForConditionalGeneration, ) -from zenml import log_model_metadata, step +from zenml import get_step_context, log_metadata, step from zenml.logger import get_logger logger = get_logger(__name__) @@ -50,4 +50,11 @@ def evaluate_model( avg_loss = total_loss / num_batches print(f"Average loss on the dataset: {avg_loss}") - log_model_metadata({"Average Loss": avg_loss}) + step_context = get_step_context() + + if step_context.model: + log_metadata( + metadata={"Average Loss": avg_loss}, + model_name=step_context.model.name, + model_version=step_context.model.version, + ) diff --git a/examples/quickstart/steps/model_tester.py b/examples/quickstart/steps/model_tester.py index d271601bd04..72d68ed7d57 100644 --- a/examples/quickstart/steps/model_tester.py +++ b/examples/quickstart/steps/model_tester.py @@ -21,7 +21,7 @@ T5TokenizerFast, ) -from zenml import log_model_metadata, step +from zenml import get_step_context, log_metadata, step from zenml.logger import get_logger from .data_loader import PROMPT @@ -70,4 +70,11 @@ def test_model( sentence_without_prompt: decoded_output } - log_model_metadata({"Example Prompts": test_collection}) + step_context = get_step_context() + + if step_context.model: + log_metadata( + metadata={"Example Prompts": test_collection}, + model_name=step_context.model.name, + model_version=step_context.model.version, + )