Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add proof verification through api #41

Merged
merged 2 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 11 additions & 25 deletions giza_actions/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from giza import API_HOST
from giza.client import AgentsClient, EndpointsClient, JobsClient, ProofsClient
from giza.schemas.agents import Agent, AgentList, AgentUpdate
from giza.schemas.jobs import Job, JobCreate, JobList
from giza.schemas.jobs import Job, JobList
from giza.schemas.proofs import Proof
from giza.utils.enums import JobKind, JobSize, JobStatus
from giza.utils.enums import JobKind, JobStatus
from requests import HTTPError

from giza_actions.model import GizaModel
Expand Down Expand Up @@ -363,9 +363,7 @@ def _verify(self):
return

self._wait_for_proof(self._jobs_client, self._timeout, self._poll_interval)
self._verify_job = self._start_verify_job(self._jobs_client)
self._wait_for_verify(self._jobs_client, self._timeout, self._poll_interval)
self.verified = True
self.verified = self._verify_proof(self._endpoint_client)

def _wait_for_proof(
self, client: JobsClient, timeout: int = 600, poll_interval: int = 10
Expand All @@ -378,29 +376,17 @@ def _wait_for_proof(
self._endpoint_id, self._proof_job.request_id
)

def _start_verify_job(self, client: JobsClient) -> Job:
def _verify_proof(self, client: EndpointsClient) -> bool:
"""
Start the verify job.
Verify the proof.
"""
job_create = JobCreate(
size=JobSize.S,
framework=self._framework,
model_id=self._model_id,
version_id=self._version_id,
proof_id=self._proof.id,
kind=JobKind.VERIFY,
verify_result = client.verify_proof(
self._endpoint_id,
self._proof.id,
)
verify_job = client.create(job_create, trace=None)
logger.info(f"Verify job created with ID {verify_job.id}")
return verify_job

def _wait_for_verify(
self, client: JobsClient, timeout: int = 600, poll_interval: int = 10
):
"""
Wait for the verify job to finish.
"""
self._wait_for(self._verify_job, client, timeout, poll_interval, JobKind.VERIFY)
logger.info(f"Verify result is {verify_result.verification}")
logger.info(f"Verify time is {verify_result.verification_time}")
return True

def _wait_for(
self,
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "giza-actions"
version = "0.3.0"
version = "0.3.1"

description = "A Python SDK for Giza platform"
authors = [
Expand Down
48 changes: 14 additions & 34 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ape.exceptions import NetworkError
from giza.schemas.jobs import Job, JobList
from giza.schemas.proofs import Proof
from giza.schemas.verify import VerifyResponse

from giza_actions.agent import AgentResult, ContractHandler, GizaAgent

Expand All @@ -21,6 +22,12 @@ def get_proof(self, *args, **kwargs):
id=1, job_id=1, created_date="2022-01-01T00:00:00Z", request_id="123"
)

def verify_proof(self, *args, **kwargs):
return VerifyResponse(
verification=True,
verification_time=1,
)


class JobsClientStub:
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -235,11 +242,8 @@ def test_agentresult__get_proof_job():


@patch("giza_actions.agent.AgentResult._wait_for_proof")
@patch("giza_actions.agent.AgentResult._start_verify_job")
@patch("giza_actions.agent.AgentResult._wait_for_verify")
def test_agentresult__verify(
mock_wait_for_verify, mock_start_verify_job, mock_wait_for_proof
):
@patch("giza_actions.agent.AgentResult._verify_proof", return_value=True)
def test_agentresult__verify(mock_verify, mock_wait_for_proof):
result = AgentResult(
input=[],
result=[1],
Expand All @@ -252,8 +256,7 @@ def test_agentresult__verify(

assert result.verified is True
mock_wait_for_proof.assert_called_once()
mock_start_verify_job.assert_called_once()
mock_wait_for_verify.assert_called_once()
mock_verify.assert_called_once()


@patch("giza_actions.agent.AgentResult._wait_for")
Expand All @@ -273,43 +276,20 @@ def test_agentresult__wait_for_proof(mock_wait_for):
mock_wait_for.assert_called_once()


def test_agentresult__start_verify_job():
agent = Mock()
agent.framework = "CAIRO"
agent.model_id = 1
agent.version_id = 1

def test_agentresult__verify_proof():
result = AgentResult(
input=[],
result=[1],
request_id="123",
agent=agent,
agent=Mock(),
endpoint_client=EndpointsClientStub(),
)

# Add a dummy proof to the result so we can verify it
result._proof = Proof(
id=1, job_id=1, created_date="2022-01-01T00:00:00Z", request_id="123"
)

job = result._start_verify_job(JobsClientStub())

assert job.id == 1
assert job.size == "S"
assert job.status == "COMPLETED"


@patch("giza_actions.agent.AgentResult._wait_for")
def test_agentresult__wait_for_verify(mock_wait_for):
result = AgentResult(
input=[],
result=[1],
request_id="123",
agent=Mock(),
endpoint_client=EndpointsClientStub(),
)

result._wait_for_verify(JobsClientStub())
mock_wait_for.assert_called_once()
assert result._verify_proof(EndpointsClientStub())


def test_agentresult__wait_for_job_completed():
Expand Down
Loading