Skip to content

Commit

Permalink
minor syntax fixes in integration tests (#351)
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-alaiacano authored May 22, 2023
1 parent 239cdad commit a73cddd
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 7 deletions.
4 changes: 2 additions & 2 deletions examples/Serving-Ranking-Models-With-Merlin-Systems.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1026,8 +1026,8 @@
"\n",
"# read in data for request\n",
"batch = df_lib.read_parquet(\n",
" os.path.join(original_data_path,\"valid\", \"part.0.parquet\"), num_rows=3, columns=workflow.input_schema.column_names\n",
")\n",
" os.path.join(original_data_path,\"valid\", \"part.0.parquet\"), columns=workflow.input_schema.column_names\n",
").head(3)\n",
"batch"
]
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import pytest

from testbook import testbook

from tests.conftest import REPO_ROOT
Expand Down Expand Up @@ -97,9 +96,8 @@ def test_example_04_exporting_ranking_models(tb):
# read in data for request
batch = df_lib.read_parquet(
os.path.join("/tmp/data/", "valid", "part.0.parquet"),
num_rows=3,
columns=workflow.input_schema.column_names,
)
).head(3)
batch = batch.drop(columns="click")
outputs = tb.ref("output_cols")
from merlin.dataloader.tf_utils import configure_tensorflow
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/t4r/test_pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@


def test_serve_t4r_with_torchscript(tmpdir):

# ===========================================
# Generate training data
# ===========================================
Expand Down Expand Up @@ -107,7 +106,7 @@ def test_serve_t4r_with_torchscript(tmpdir):
if name in input_schema.column_names:
dtype = input_schema[name].dtype

df_cols[name] = tensor.cpu().numpy().astype(dtype)
df_cols[name] = tensor.cpu().numpy().astype(dtype.name)
if len(tensor.shape) > 1:
df_cols[name] = list(df_cols[name])

Expand Down

0 comments on commit a73cddd

Please sign in to comment.