From c0e24508b7c5709181e7edbc1a108848401002e6 Mon Sep 17 00:00:00 2001 From: B-Step62 Date: Wed, 25 Sep 2024 15:24:36 +0900 Subject: [PATCH] Use Databricks SDK for OAuth Signed-off-by: B-Step62 --- libs/databricks/poetry.lock | 2 +- libs/databricks/pyproject.toml | 1 + .../integration_tests/test_chat_models.py | 9 +++ .../integration_tests/test_vectorstore.py | 66 +++++-------------- 4 files changed, 26 insertions(+), 52 deletions(-) diff --git a/libs/databricks/poetry.lock b/libs/databricks/poetry.lock index 5e24bba..99e552d 100644 --- a/libs/databricks/poetry.lock +++ b/libs/databricks/poetry.lock @@ -3370,4 +3370,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "ac775fe07702a5575f43107a336951b68fcec470b977dfeb4d668db0ee71d3a9" +content-hash = "c6485b9664d292281f293ca0b3ee5a19a625224eff2146e4ff348f4afc45a2bc" diff --git a/libs/databricks/pyproject.toml b/libs/databricks/pyproject.toml index 82759b6..253213d 100644 --- a/libs/databricks/pyproject.toml +++ b/libs/databricks/pyproject.toml @@ -51,6 +51,7 @@ codespell = "^2.2.6" optional = true [tool.poetry.group.test_integration.dependencies] +databricks-sdk = "^0.32.3" langchain = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/langchain" } langgraph = "^0.2.27" pytest-timeout = "^2.3.1" diff --git a/libs/databricks/tests/integration_tests/test_chat_models.py b/libs/databricks/tests/integration_tests/test_chat_models.py index 6c995de..f4e05dd 100644 --- a/libs/databricks/tests/integration_tests/test_chat_models.py +++ b/libs/databricks/tests/integration_tests/test_chat_models.py @@ -322,4 +322,13 @@ def chatbot(state: State): {"messages": [("user", "Subtract 5 from it")]}, config={"configurable": {"thread_id": "1"}}, ) + + # Interestingly, the agent sometimes mistakes the subtraction for addition :( + # In such a case, the agent asks for a retry, so we need one more step. + if "Let me try again." 
in response["messages"][-1].content: + response = graph.invoke( + {"messages": [("user", "Ok, try again")]}, + config={"configurable": {"thread_id": "1"}}, + ) + assert "40" in response["messages"][-1].content diff --git a/libs/databricks/tests/integration_tests/test_vectorstore.py b/libs/databricks/tests/integration_tests/test_vectorstore.py index 8bf72ce..b497835 100644 --- a/libs/databricks/tests/integration_tests/test_vectorstore.py +++ b/libs/databricks/tests/integration_tests/test_vectorstore.py @@ -10,10 +10,11 @@ """ import os -import time +from datetime import timedelta import pytest -import requests +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.jobs import RunLifecycleStateV2State, TerminationTypeType @pytest.mark.timeout(3600) @@ -23,62 +24,25 @@ def test_vectorstore(): because the setup is too complex to run within a single python file. Thereby, this test simply triggers the workflow by calling the REST API. """ - required_env_vars = ["DATABRICKS_HOST", "DATABRICKS_TOKEN", "VS_TEST_JOB_ID"] - for var in required_env_vars: - assert os.getenv(var), f"Please set the environment variable {var}." - - test_endpoint = os.getenv("DATABRICKS_HOST") test_job_id = os.getenv("VS_TEST_JOB_ID") - headers = { - "Authorization": f"Bearer {os.getenv('DATABRICKS_TOKEN')}", - } + if not test_job_id: + raise RuntimeError("Please set the environment variable VS_TEST_JOB_ID") + + w = WorkspaceClient() # Check if there is any ongoing job run - response = requests.get( - f"{test_endpoint}/api/2.1/jobs/runs/list", - json={ - "job_id": test_job_id, - "active_only": True, - }, - headers=headers, - ) - no_active_run = len(response.json().get("runs", [])) == 0 + run_list = list(w.jobs.list_runs(job_id=test_job_id, active_only=True)) + no_active_run = len(run_list) == 0 assert no_active_run, "There is an ongoing job run. Please wait for it to complete." 
# Trigger the workflow - # TODO: We are going to replace this with the Databricks SDK once the vector store - # class is also migrated to the SDK. - response = requests.post( - f"{test_endpoint}/api/2.1/jobs/run-now", - json={ - "job_id": test_job_id, - }, - headers=headers, + response = w.jobs.run_now(job_id=test_job_id) + job_url = ( + f"{os.getenv('DATABRICKS_HOST')}/jobs/{test_job_id}/runs/{response.run_id}" ) - - assert response.status_code == 200, "Failed to trigger the workflow." - - job_url = f"{test_endpoint}/jobs/{test_job_id}/runs/{response.json()['run_id']}" print(f"Started the job at {job_url}") # noqa: T201 # Wait for the job to complete - while True: - response = requests.get( - f"{test_endpoint}/api/2.1/jobs/runs/get", - json={ - "run_id": response.json()["run_id"], - }, - headers=headers, - ) - - assert response.status_code == 200, "Failed to get the job status." - - status = response.json()["status"] - if status["state"] == "TERMINATED": - if status["termination_details"]["type"] == "SUCCESS": - break - else: - assert False, "Job failed. Please check the logs in the workspace." - - time.sleep(60) - print("Job is still running...") # noqa: T201 + result = response.result(timeout=timedelta(seconds=3600)) + assert result.status.state == RunLifecycleStateV2State.TERMINATED + assert result.status.termination_details.type == TerminationTypeType.SUCCESS