Commit

Merging changes from biohackthon3 and anndata classes remain pydantic forms for now
noahbruderer committed Dec 12, 2024
1 parent 838729a commit 171749c
Showing 11 changed files with 151 additions and 45 deletions.
48 changes: 42 additions & 6 deletions benchmark/data/benchmark_api_calling_data.yaml
@@ -76,23 +76,59 @@ api_calling:
   - case: scanpy:pl:scatter
     input:
       prompt:
-        exact_variable_names: "Make a scatter plot of n_genes_by_counts vs total_counts."
+        specific: "Make a scatter plot with axis being n_genes_by_counts vs total_counts."
+        abbreviations: "scatter plt with x-axis = n_genes_by_counts and y-axis = total_counts."
+        general_question: "How can I make a scatter plot with n_genes_by_counts and total_counts?"
+        help_request: "Can you help me with making a scatter plot with n_genes_by_counts and total_counts?"
     expected:
       parts_of_query:
-        ["sc.pl.scatter\\(", "n_genes_by_counts", "total_counts", "\\)"]
+        ["sc.pl.scatter\\(", "adata=adata", "n_genes_by_counts", "total_counts", "\\)"]
   - case: scanpy:pl:pca
     input:
       prompt:
-        explicit_variable_names: "plot the PCA of the data colored by n_genes_by_counts and total_counts."
+        specific: "plot the PCA embedding colored by n_genes_by_counts and total_counts"
+        abbreviations: "plt the PC emb with n_genes_by_counts and total_counts as colors."
+        general_question: "How can I plot the PCA embedding with n_genes_by_counts and total_counts as colors?"
+        help_request: "Can you help me with plotting the PCA embedding with n_genes_by_counts and total_counts as colors?"
     expected:
       parts_of_query:
-        ["sc.pl.pca\\(", "n_genes_by_counts", "total_counts", "\\)"]
+        ["sc.pl.pca\\(", "adata=adata", "n_genes_by_counts", "total_counts", "\\)"]
   - case: scanpy:pl:tsne
     input:
       prompt:
-        explicit_variable_names: "plot the tsne embeddding of the data colored by n_genes_by_counts."
+        specific: "plot a tsne colored by n_genes_by_counts."
+        abbreviations: "tsne plt with n_genes_by_counts as colors."
+        general_question: "How can I plot a tsne with n_genes_by_counts as colors?"
+        help_request: "Can you help me with plotting a tsne with n_genes_by_counts as colors?"
     expected:
-      parts_of_query: ["sc.pl.tsne\\(", "n_genes_by_counts", "\\)"]
+      parts_of_query: ["sc.pl.tsne\\(", "adata=adata", "n_genes_by_counts", "\\)"]
+  - case: scanpy:pl:umap
+    input:
+      prompt:
+        specific: "plot a umap colored by number of n_genes_by_counts."
+        abbreviations: "umap plt with n_genes_by_counts as colors."
+        general_question: "How can I plot a umap with n_genes_by_counts as colors?"
+        help_request: "Can you help me with plotting a umap with n_genes_by_counts as colors?"
+    expected:
+      parts_of_query: ["sc.pl.umap\\(", "adata=adata", "n_genes_by_counts", "\\)"]
+  - case: scanpy:pl:draw_graph
+    input:
+      prompt:
+        specific: "plot a force-directed graph colored by n_genes_by_counts."
+        abbreviations: "force-directed plt with n_genes_by_counts as colors."
+        general_question: "How can I plot a force-directed graph with n_genes_by_counts as colors?"
+        help_request: "Can you help me with plotting a force-directed graph with n_genes_by_counts as colors?"
+    expected:
+      parts_of_query: ["sc.pl.draw_graph\\(", "adata=adata", "n_genes_by_counts", "\\)"]
+  - case: scanpy:pl:spatial
+    input:
+      prompt:
+        specific: "plot the spatial data colored by n_genes_by_counts."
+        abbreviations: "spatial data plt with n_genes_by_counts as colors."
+        general_question: "How can I plot the spatial data with n_genes_by_counts as colors?"
+        help_request: "Can you help me with plotting the spatial data with n_genes_by_counts as colors?"
+    expected:
+      parts_of_query: ["sc.pl.spatial\\(", "adata=adata", "n_genes_by_counts", "\\)"]
   - case: anndata:read:h5ad
     input:
       prompt:
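Each updated case now expects adata=adata to appear in the generated call; the parts_of_query entries are regular-expression fragments that must all be found in the call string the model produces. A minimal sketch of that check (the method_call string here is hypothetical; the benchmark derives it from the query builder and format_as_python_call):

import re

# Expected fragments for the scanpy:pl:umap case, as in the YAML above.
expected_parts = ["sc.pl.umap\\(", "adata=adata", "n_genes_by_counts", "\\)"]

# Hypothetical model output standing in for the builder's formatted call.
method_call = 'sc.pl.umap(adata=adata, color="n_genes_by_counts")'

score = [bool(re.search(part, method_call)) for part in expected_parts]
print(score)                    # [True, True, True, True]
print(sum(score) / len(score))  # 1.0, every expected fragment was found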
63 changes: 62 additions & 1 deletion benchmark/test_api_calling.py
@@ -4,7 +4,15 @@
 import pytest
 
 from biochatter._misc import ensure_iterable
-from biochatter.api_agent import BioToolsQueryBuilder, OncoKBQueryBuilder, ScanpyPlQueryBuilder, AnnDataIOQueryBuilder, format_as_rest_call, format_as_python_call
+from biochatter.api_agent import (
+    BioToolsQueryBuilder,
+    OncoKBQueryBuilder,
+    ScanpyPlQueryBuilder,
+    ScanpyPlQueryBuilderReduced,
+    AnnDataIOQueryBuilder,
+    format_as_rest_call,
+    format_as_python_call,
+)
 
 from .benchmark_utils import (
     get_result_file_path,
@@ -128,3 +136,56 @@ def run_test():
         yaml_data["hash"],
         get_result_file_path(task),
     )
+
+def test_python_api_calling_reduced(
+    model_name,
+    test_data_api_calling,
+    conversation,
+    multiple_testing,
+):
+    """Test the Python API calling capability with reduced Scanpy plotting class."""
+    task = f"{inspect.currentframe().f_code.co_name.replace('test_', '')}"
+    yaml_data = test_data_api_calling
+
+    skip_if_already_run(
+        model_name=model_name,
+        task=task,
+        md5_hash=yaml_data["hash"],
+    )
+
+    if "scanpy:pl" not in yaml_data["case"]:
+        pytest.skip(
+            "Function to be tested is not a Scanpy plotting API",
+        )
+
+    def run_test():
+        conversation.reset()  # needs to be reset for each test
+        builder = ScanpyPlQueryBuilderReduced()
+        parameters = builder.parameterise_query(
+            question=yaml_data["input"]["prompt"],
+            conversation=conversation,
+        )
+
+        method_call = format_as_python_call(parameters[0])
+
+        score = []
+        for expected_part in ensure_iterable(
+            yaml_data["expected"]["parts_of_query"],
+        ):
+            if re.search(expected_part, method_call):
+                score.append(True)
+            else:
+                score.append(False)
+
+        return calculate_bool_vector_score(score)
+
+    mean_score, max, n_iterations = multiple_testing(run_test)
+
+    write_results_to_file(
+        model_name,
+        yaml_data["case"],
+        f"{mean_score}/{max}",
+        f"{n_iterations}",
+        yaml_data["hash"],
+        get_result_file_path(task),
+    )
4 changes: 3 additions & 1 deletion biochatter/api_agent/__init__.py
@@ -5,7 +5,7 @@
 """
 
 from .abc import BaseFetcher, BaseInterpreter, BaseQueryBuilder
-from .anndata import AnnDataIOQueryBuilder
+from .anndata_agent import AnnDataIOQueryBuilder
 from .api_agent import APIAgent
 from .bio_tools import BioToolsFetcher, BioToolsInterpreter, BioToolsQueryBuilder
 from .blast import (
@@ -17,6 +17,7 @@
 from .formatters import format_as_python_call, format_as_rest_call
 from .oncokb import OncoKBFetcher, OncoKBInterpreter, OncoKBQueryBuilder
 from .scanpy_pl import ScanpyPlQueryBuilder
+from .scanpy_pl_reduced import ScanpyPlQueryBuilder as ScanpyPlQueryBuilderReduced
 from .scanpy_tl import ScanpyTlQueryBuilder
 
 __all__ = [
@@ -36,6 +37,7 @@
     "OncoKBInterpreter",
     "OncoKBQueryBuilder",
     "ScanpyPlQueryBuilder",
+    "ScanpyPlQueryBuilderReduced",
     "ScanpyTlQueryBuilder",
     "format_as_python_call",
     "format_as_rest_call",
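With the aliased re-export in place, both builders are importable from the package namespace; ScanpyPlQueryBuilderReduced is simply the scanpy_pl_reduced module's ScanpyPlQueryBuilder under a new name. A quick check, assuming a biochatter installation that includes this commit:

# Package-level imports resolve to two distinct classes despite the shared original name.
from biochatter.api_agent import ScanpyPlQueryBuilder, ScanpyPlQueryBuilderReduced

# Equivalent direct imports, bypassing the alias.
from biochatter.api_agent.scanpy_pl import ScanpyPlQueryBuilder as FullBuilder
from biochatter.api_agent.scanpy_pl_reduced import ScanpyPlQueryBuilder as ReducedBuilder

assert ScanpyPlQueryBuilder is FullBuilder
assert ScanpyPlQueryBuilderReduced is ReducedBuilder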
22 changes: 11 additions & 11 deletions biochatter/api_agent/abc.py
@@ -8,7 +8,7 @@
 from collections.abc import Callable
 
 from langchain_core.prompts import ChatPromptTemplate
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, create_model, ConfigDict
 
 from biochatter.llm_connect import Conversation
 
@@ -167,13 +167,13 @@ class BaseAPIModel(BaseModel):
         None,
         description="Unique identifier for the model instance",
     )
     method_name: str = Field(..., description="Name of the method to be executed")
 
-    class Config:
-        """BaseModel class configuration.
-        Ensures the model can be extended without strict type checking on
-        inherited fields.
-        """
-
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+class BaseTools():
+    """Abstract base class for tools."""
+    def make_pydantic_tools(self) -> list[BaseAPIModel]:
+        """Use pydantic's create_model to create a list of pydantic tools from a dictionary of parameters."""
+        tools = []
+        for func_name, tool_params in self.tools_params.items():
+            tools.append(create_model(func_name, **tool_params, __base__=BaseAPIModel))
+        return tools
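BaseAPIModel now uses pydantic v2's ConfigDict instead of the legacy inner Config class, and the new BaseTools helper stamps out tool schemas with create_model(..., __base__=BaseAPIModel). A self-contained sketch of that pattern; the DemoAPIModel base and the tools_params dictionary are illustrative stand-ins, not biochatter's own definitions:

from pydantic import BaseModel, ConfigDict, Field, create_model


class DemoAPIModel(BaseModel):
    """Stand-in for BaseAPIModel: shared fields plus relaxed type checking."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    method_name: str = Field(..., description="Name of the method to be executed")


# Hypothetical parameter spec: field name -> (annotation, Field definition).
tools_params = {
    "scatter": {
        "method_name": (str, Field(default="sc.pl.scatter", description="NEVER CHANGE")),
        "x": (str | None, Field(default=None, description="x-axis column in adata.obs")),
        "y": (str | None, Field(default=None, description="y-axis column in adata.obs")),
    },
}

# Mirror of BaseTools.make_pydantic_tools: one dynamically created model per tool.
tools = [
    create_model(name, **params, __base__=DemoAPIModel)
    for name, params in tools_params.items()
]

scatter = tools[0](x="total_counts", y="n_genes_by_counts")
print(scatter.model_dump())
# {'method_name': 'sc.pl.scatter', 'x': 'total_counts', 'y': 'n_genes_by_counts'}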
biochatter/api_agent/{anndata.py → anndata_agent.py}
@@ -104,7 +104,7 @@ class MapAnnData(BaseAPIModel):
 class ReadH5AD(BaseAPIModel):
     """Read .h5ad-formatted hdf5 file."""
 
-    title: str = Field(default="io.read_h5ad", description="NEVER CHANGE")
+    method_name: str = Field(default="io.read_h5ad", description="NEVER CHANGE")
     filename: str = Field(default="dummy.h5ad", description="Path to the .h5ad file")
     backed: str | None = Field(
         default=None,
@@ -127,7 +127,7 @@ class ReadH5AD(BaseAPIModel):
 class ReadZarr(BaseAPIModel):
     """Read from a hierarchical Zarr array store."""
 
-    title: str = Field(default="io.read_zarr", description="NEVER CHANGE")
+    method_name: str = Field(default="io.read_zarr", description="NEVER CHANGE")
     filename: str = Field(
         default="placeholder.zarr",
         description="Path or URL to the Zarr store",
@@ -137,7 +137,7 @@ class ReadZarr(BaseAPIModel):
 class ReadCSV(BaseAPIModel):
     """Read .csv file."""
 
-    title: str = Field(default="io.read_csv", description="NEVER CHANGE")
+    method_name: str = Field(default="io.read_csv", description="NEVER CHANGE")
     filename: str = Field(
         default="placeholder.csv",
         description="Path to the .csv file",
@@ -155,7 +155,7 @@ class ReadCSV(BaseAPIModel):
 class ReadExcel(BaseAPIModel):
     """Read .xlsx (Excel) file."""
 
-    title: str = Field(default="io.read_excel", description="NEVER CHANGE")
+    method_name: str = Field(default="io.read_excel", description="NEVER CHANGE")
     filename: str = Field(
         default="placeholder.xlsx",
         description="Path to the .xlsx file",
@@ -170,15 +170,15 @@ class ReadExcel(BaseAPIModel):
 class ReadHDF(BaseAPIModel):
     """Read .h5 (hdf5) file."""
 
-    title: str = Field(default="io.read_hdf", description="NEVER CHANGE")
+    method_name: str = Field(default="io.read_hdf", description="NEVER CHANGE")
     filename: str = Field(default="placeholder.h5", description="Path to the .h5 file")
     key: str | None = Field(None, description="Group key within the .h5 file")
 
 
 class ReadLoom(BaseAPIModel):
     """Read .loom-formatted hdf5 file."""
 
-    title: str = Field(default="io.read_loom", description="NEVER CHANGE")
+    method_name: str = Field(default="io.read_loom", description="NEVER CHANGE")
     filename: str = Field(
         default="placeholder.loom",
         description="Path to the .loom file",
@@ -199,7 +199,7 @@ class ReadLoom(BaseAPIModel):
 class ReadMTX(BaseAPIModel):
     """Read .mtx file."""
 
-    title: str = Field(default="io.read_mtx", description="NEVER CHANGE")
+    method_name: str = Field(default="io.read_mtx", description="NEVER CHANGE")
     filename: str = Field(
         default="placeholder.mtx",
         description="Path to the .mtx file",
@@ -210,7 +210,7 @@ class ReadMTX(BaseAPIModel):
 class ReadText(BaseAPIModel):
     """Read .txt, .tab, .data (text) file."""
 
-    title: str = Field(default="io.read_text", description="NEVER CHANGE")
+    method_name: str = Field(default="io.read_text", description="NEVER CHANGE")
     filename: str = Field(
         default="placeholder.txt",
         description="Path to the text file",
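The rename from title to method_name across these I/O models is what lets the Python formatter know which function a parameter model stands for. The helper below is not biochatter's format_as_python_call, just a minimal sketch of the idea with a simplified stand-in model (its field set and the pbmc3k.h5ad filename are assumptions for illustration):

from pydantic import BaseModel, Field


class ReadH5ADDemo(BaseModel):
    """Simplified stand-in for the ReadH5AD parameter model."""

    method_name: str = Field(default="io.read_h5ad", description="NEVER CHANGE")
    filename: str = Field(default="dummy.h5ad", description="Path to the .h5ad file")
    backed: str | None = Field(default=None, description="Optional backed mode")


def format_call(model: BaseModel) -> str:
    """Render a parameter model as a Python call string (sketch only)."""
    params = model.model_dump(exclude_none=True)
    method = params.pop("method_name")
    args = ", ".join(f"{key}={value!r}" for key, value in params.items())
    return f"{method}({args})"


print(format_call(ReadH5ADDemo(filename="pbmc3k.h5ad")))
# io.read_h5ad(filename='pbmc3k.h5ad')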
2 changes: 1 addition & 1 deletion biochatter/api_agent/formatters.py
@@ -3,7 +3,7 @@
 from urllib.parse import urlencode
 
 from .abc import BaseAPIModel, BaseModel
-from .anndata import MapAnnData
+from .anndata_agent import MapAnnData
 
 
 def format_as_rest_call(model: BaseModel) -> str:
21 changes: 16 additions & 5 deletions biochatter/api_agent/generate_pydantic_classes_from_module.py
@@ -20,10 +20,10 @@
 from typing import Any
 
 from docstring_parser import parse
-from langchain_core.pydantic_v1 import BaseModel, Field, create_model
+from langchain_core.pydantic_v1 import Field, create_model
+from biochatter.api_agent.abc import BaseAPIModel
 
 
-def generate_pydantic_classes(module: ModuleType) -> list[type[BaseModel]]:
+def generate_pydantic_classes(module: ModuleType) -> list[type[BaseAPIModel]]:
     """Generate Pydantic classes for each callable.
     For each callable (function/method) in a given module. Extracts parameters
@@ -52,7 +52,7 @@ def generate_pydantic_classes(module: ModuleType) -> list[type[BaseModel]]:
     required.
     """
-    base_attributes = set(dir(BaseModel))
+    base_attributes = set(dir(BaseAPIModel))
     classes_list = []
 
     for name, func in inspect.getmembers(module, inspect.isfunction):
@@ -117,6 +117,17 @@ def generate_pydantic_classes(module: ModuleType) -> list[type[BaseModel]]:
             fields[field_name] = (annotation, Field(**field_kwargs))
 
         # Create the Pydantic model
-        tl_parameters_model = create_model(name, **fields)
+        tl_parameters_model = create_model(
+            name,
+            **fields,
+            __base__=BaseAPIModel
+        )
         classes_list.append(tl_parameters_model)
     return classes_list
+
+
+# Example usage:
+# import scanpy as sc
+# generated_classes = generate_pydantic_classes(sc.tl)
+# for func in generated_classes:
+#     print(func.model_json_schema())
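generate_pydantic_classes walks a module, turns each callable's parameters into pydantic fields, and now derives every generated class from BaseAPIModel via __base__. A stripped-down sketch of the same idea for a single function, using plain pydantic and inspect (no docstring parsing, and a local DemoBase instead of biochatter's base class):

import inspect
from typing import Any

from pydantic import BaseModel, Field, create_model


class DemoBase(BaseModel):
    """Local stand-in for BaseAPIModel."""


def model_for(func) -> type[DemoBase]:
    """Build a pydantic model whose fields mirror func's parameters."""
    fields = {}
    for param_name, param in inspect.signature(func).parameters.items():
        annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
        default = param.default if param.default is not inspect.Parameter.empty else ...
        fields[param_name] = (annotation, Field(default=default))
    return create_model(func.__name__, **fields, __base__=DemoBase)


def example_plot(adata: str, color: str = "n_genes_by_counts", show: bool = True) -> None:
    """Toy function standing in for a scanpy tool."""


ExamplePlot = model_for(example_plot)
print(ExamplePlot(adata="adata").model_dump())
# {'adata': 'adata', 'color': 'n_genes_by_counts', 'show': True}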
7 changes: 2 additions & 5 deletions biochatter/api_agent/scanpy_tl.py
@@ -2,17 +2,14 @@
 
 from collections.abc import Callable
 from types import ModuleType
-from typing import TYPE_CHECKING
 
 from langchain_core.output_parsers import PydanticToolsParser
 
 from biochatter.llm_connect import Conversation
 
 from .abc import BaseAPIModel, BaseQueryBuilder
 from .generate_pydantic_classes_from_module import generate_pydantic_classes
 
-if TYPE_CHECKING:
-    from biochatter.llm_connect import Conversation
+from biochatter.llm_connect import Conversation
 
 SCANPY_QUERY_PROMPT = """
 You are a world class algorithm for creating queries in structured formats. Your
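The hunk above drops the TYPE_CHECKING guard so that Conversation is available at runtime rather than only to static type checkers. A generic illustration of the difference, using a standard-library class rather than anything biochatter-specific:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed at runtime.
    from decimal import Decimal


def double(value: str) -> "Decimal":
    # The string annotation above is fine, but calling Decimal at runtime needs a
    # real import; that is the same reason scanpy_tl.py now imports Conversation
    # unconditionally.
    from decimal import Decimal
    return Decimal(value) * 2


print(double("1.5"))  # prints 3.0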
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -94,3 +94,7 @@ target-version = "py310"
 
 [tool.ruff.lint.pycodestyle]
 max-doc-length = 80
+
+[tool.ruff.lint.per-file-ignores]
+"test/*" = ["ANN001", "ANN201", "D100", "D101", "D102", "D103", "D104", "I001", "S101"]
+"benchmark/*" = ["ANN001", "ANN201", "D100", "D101", "D102", "D103", "D104", "I001", "S101"]
4 changes: 2 additions & 2 deletions test/test_api_agent.py
@@ -12,7 +12,7 @@
     BaseInterpreter,
     BaseQueryBuilder,
 )
-from biochatter.api_agent.anndata import ANNDATA_IO_QUERY_PROMPT, AnnDataIOQueryBuilder
+from biochatter.api_agent.anndata_agent import AnnDataIOQueryBuilder, ANNDATA_IO_QUERY_PROMPT
 from biochatter.api_agent.api_agent import APIAgent
 from biochatter.api_agent.blast import (
     BLAST_QUERY_PROMPT,
@@ -481,7 +481,7 @@ class TestAnndataIOQueryBuilder:
     @pytest.fixture
     def mock_create_runnable(self):
         with patch(
-            "biochatter.api_agent.anndata.AnnDataIOQueryBuilder.create_runnable",
+            "biochatter.api_agent.anndata_agent.AnnDataIOQueryBuilder.create_runnable",
         ) as mock:
             mock_runnable = MagicMock()
             mock.return_value = mock_runnable
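Because the module was renamed, the test now has to patch the builder where it is defined: biochatter.api_agent.anndata_agent. The sketch below reproduces the fixture's pattern against a local stand-in class (DemoQueryBuilder and its return value are hypothetical; the real test patches "biochatter.api_agent.anndata_agent.AnnDataIOQueryBuilder.create_runnable"):

from unittest.mock import MagicMock, patch


class DemoQueryBuilder:
    """Stand-in for AnnDataIOQueryBuilder, used only for this illustration."""

    def create_runnable(self, **kwargs):
        raise RuntimeError("would normally build a LangChain runnable")

    def parameterise_query(self, question, conversation=None):
        runnable = self.create_runnable(conversation=conversation)
        return runnable.invoke(question)


# Replace the method on the class; patching the dotted path does the same thing
# for the real builder in its defining module.
with patch.object(DemoQueryBuilder, "create_runnable") as mock_create:
    mock_runnable = MagicMock()
    mock_runnable.invoke.return_value = ["io.read_h5ad(filename='dummy.h5ad')"]
    mock_create.return_value = mock_runnable

    result = DemoQueryBuilder().parameterise_query("read my h5ad file")
    print(result)  # ["io.read_h5ad(filename='dummy.h5ad')"]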
5 changes: 0 additions & 5 deletions test/test_llm_connect.py
@@ -1,8 +1,3 @@
-# ruff: noqa: S101 # Use of assert detected
-# ruff: noqa: ANN201 # No docstring in public function
-# ruff: noqa: D103 # Missing docstring in public function
-# ruff: noqa: D100 # Missing docstring in public module
-
 import os
 from unittest.mock import MagicMock, Mock, mock_open, patch
 
