Skip to content

Commit

Permalink
LLM tests + structure outputs (#60)
Browse files Browse the repository at this point in the history
* Move to structured output

* Add LLM tests

* fix test

* Address comments

* Update Quorum/tests/conftest.py

Co-authored-by: yoav-el-certora <[email protected]>

* Update Quorum/tests/conftest.py

Co-authored-by: yoav-el-certora <[email protected]>

* Address comments

* Fix first test

* Fix all tests

---------

Co-authored-by: yoav-el-certora <[email protected]>
  • Loading branch information
nivcertora and yoav-el-certora authored Jan 2, 2025
1 parent 5ef29c5 commit e7ff52a
Show file tree
Hide file tree
Showing 8 changed files with 1,092 additions and 15 deletions.
14 changes: 9 additions & 5 deletions Quorum/entry_points/ipfs_validator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from Quorum.utils.chain_enum import Chain
import Quorum.utils.arg_validations as arg_valid
from Quorum.apis.block_explorers.chains_api import ChainAPI
from Quorum.llm.chains.ipfs_validation_chain import IPFSValidationChain
import Quorum.utils.config as config
import Quorum.utils.arg_validations as arg_valid
import Quorum.utils.pretty_printer as pp

from pathlib import Path
import argparse
Expand Down Expand Up @@ -72,10 +73,13 @@ def main():
answer = ipfs_validation_chain.execute(
prompt_templates = args.prompt_templates, ipfs=ipfs, payload=payload
)

# Output the LLM's response
print(answer)


if answer.incompatibilities:
pp.pretty_print("Found incompatibilities:", pp.Colors.FAILURE)
for incompatibility in answer.incompatibilities:
pp.pretty_print(incompatibility, pp.Colors.FAILURE)
else:
pp.pretty_print("LLM found no incompatibilities. Please Check manually.", pp.Colors.WARNING)

if __name__ == '__main__':
main()
31 changes: 26 additions & 5 deletions Quorum/llm/chains/ipfs_validation_chain.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
from Quorum.llm.jinja_utils import render_prompt
from Quorum.llm.chains.cached_llm import CachedLLM
from typing import Optional
from pydantic import BaseModel, Field

from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from langchain_core.output_parsers import StrOutputParser

from Quorum.llm.jinja_utils import render_prompt
from Quorum.llm.chains.cached_llm import CachedLLM

class Incompatibility(BaseModel):
    """
    Incompatibility is a Pydantic model that represents a mismatch between the IPFS and Solidity payloads.
    """
    # NOTE(review): this model is reachable from the schema handed to
    # `with_structured_output`, so the docstring and the field descriptions
    # below are presumably forwarded to the LLM - keep their wording stable
    # unless a prompt change is intended (confirm against the LangChain docs).

    # All four fields are declared with `...` as the default, i.e. no
    # default value is supplied for any of them.
    subject: str = Field(
        ...,
        description="The subject of the incompatibility (e.g. disagreement between IPFS and Solidity).",
    )
    subject_in_ipfs: str = Field(
        ...,
        description="The subject details as described in the IPFS payload.",
    )
    subject_in_solidity: str = Field(
        ...,
        description="The subject details as described in the Solidity payload.",
    )
    description: str = Field(
        ...,
        description="A detailed description of the incompatibility.",
    )

class IncompatibilityArray(BaseModel):
    """
    IncompatibilityArray is a Pydantic model that represents a list of incompatibilities between the IPFS and Solidity payloads.
    """
    # NOTE(review): the docstring and description are presumably part of the
    # schema sent to the LLM via `with_structured_output` - confirm before
    # rewording them.

    # Defaults to None (not an empty list) when the field is omitted, so
    # consumers should rely on truthiness rather than on len().
    incompatibilities: Optional[list[Incompatibility]] = Field(
        description="A list of incompatibilities between the IPFS and Solidity payloads.",
        default=None,
    )

class IPFSValidationChain(CachedLLM):
"""
Expand All @@ -24,6 +43,8 @@ def __init__(self):
prompt templates for execution.
"""
super().__init__()

self.structured_llm = self.llm.with_structured_output(IncompatibilityArray)

# Define the workflow for the IPFS validation chain
workflow = StateGraph(state_schema=MessagesState)
Expand All @@ -43,7 +64,7 @@ def __call_model(self, state: MessagesState) -> MessagesState:
response = self.llm.invoke(messages)
return {"messages": response}

def execute(self, prompt_templates: list[str], ipfs: str, payload: str, thread_id: int = 1) -> str:
def execute(self, prompt_templates: list[str], ipfs: str, payload: str, thread_id: int = 1) -> IncompatibilityArray:
"""
Executes the IPFS validation workflow by rendering prompts, interacting with the LLM,
and retrieving the final validation report.
Expand Down Expand Up @@ -74,4 +95,4 @@ def execute(self, prompt_templates: list[str], ipfs: str, payload: str, thread_i
config={"configurable": {"thread_id": f"{thread_id}"}},
)

return StrOutputParser().parse(history["messages"][-1].content)
return self.structured_llm.invoke([h.content for h in history["messages"]])
37 changes: 32 additions & 5 deletions Quorum/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import pytest
import shutil
import json5 as json
from pathlib import Path
from typing import Generator

from Quorum.apis.block_explorers.source_code import SourceCode
import Quorum.utils.config as config

from pathlib import Path
import shutil

from typing import Generator


RESOURCES_DIR = Path(__file__).parent / 'resources'
EXPECTED_DIR = Path(__file__).parent / 'expected'
Expand Down Expand Up @@ -40,3 +39,31 @@ def tmp_cache() -> Generator[Path, None, None]:
cache.mkdir()
yield cache
shutil.rmtree(cache)


@pytest.fixture(scope="module")
def load_ipfs_validation_chain_inputs() -> tuple[str, str]:
    """
    Read the two input files for the IPFS validation chain tests.

    Both files live under ``RESOURCES_DIR / "llm" / "ipfs_validation_chain"``
    and are decoded as UTF-8.

    Returns:
        A ``(ipfs_content, source_code)`` tuple: the text of ``ipfs.txt``
        followed by the text of ``source_code.sol``.
    """
    inputs_dir = RESOURCES_DIR / "llm" / "ipfs_validation_chain"

    # The generator is consumed left to right by the unpacking, so the IPFS
    # text is read first and the Solidity source second.
    ipfs_text, solidity_text = (
        (inputs_dir / file_name).read_text(encoding="utf-8")
        for file_name in ("ipfs.txt", "source_code.sol")
    )
    return ipfs_text, solidity_text

@pytest.fixture
def expected_first_deposit_results():
    """
    Load the expected output for the first-deposit chain test.

    Parses ``EXPECTED_DIR / 'test_llm' / 'first_deposit_chain.json'`` with the
    module-level ``json`` name and returns the parsed content.

    Returns:
        The parsed content of the expected-results file.
    """
    expected_path = EXPECTED_DIR / 'test_llm' / 'first_deposit_chain.json'
    # Decode explicitly as UTF-8, like the other resource-loading fixtures in
    # this module; without it open() falls back to the locale's preferred
    # encoding, which differs between platforms.
    with open(expected_path, encoding="utf-8") as f:
        return json.load(f)

@pytest.fixture
def first_deposit_chain_input():
    """
    Provide the Solidity source used as input for the first-deposit chain test.

    Returns:
        The UTF-8 decoded text of
        ``RESOURCES_DIR / "llm" / "first_deposit_chain" / "source_code.sol"``.
    """
    solidity_file = RESOURCES_DIR / "llm" / "first_deposit_chain" / "source_code.sol"
    return solidity_file.read_text(encoding="utf-8")
39 changes: 39 additions & 0 deletions Quorum/tests/expected/test_llm/first_deposit_chain.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"listings": [
{
"asset_symbol": "USDT",
"asset_address": "0xdAC17F958D2ee523a2206206994597C13D831ec7",
"supply_seed_amount": null,
"supply_indicator": false,
"approve_indicator": false
},
{
"asset_symbol": "WBTC",
"asset_address": "0x2260FAC5E5542a773Aa44fBCfeDf7C193bc2C599",
"supply_seed_amount": null,
"supply_indicator": false,
"approve_indicator": false
},
{
"asset_symbol": "WETH",
"asset_address": "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2",
"supply_seed_amount": null,
"supply_indicator": false,
"approve_indicator": false
},
{
"asset_symbol": "YFI",
"asset_address": "0x0bc529c00C6401aEF6D220BE8C6Ea1667F6Ad93e",
"supply_seed_amount": null,
"supply_indicator": false,
"approve_indicator": false
},
{
"asset_symbol": "ZRX",
"asset_address": "0xE41d2489571d322189246DaFA5ebDe1F4699F498",
"supply_seed_amount": null,
"supply_indicator": false,
"approve_indicator": false
}
]
}
Loading

0 comments on commit e7ff52a

Please sign in to comment.