Skip to content

Commit

Permalink
Merge changes from biohackathon3; anndata changed to pydantic classes…
Browse files Browse the repository at this point in the history
… for now
  • Loading branch information
noahbruderer committed Dec 12, 2024
1 parent 171749c commit d12e2a7
Show file tree
Hide file tree
Showing 8 changed files with 685 additions and 42 deletions.
42 changes: 35 additions & 7 deletions benchmark/data/benchmark_api_calling_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ api_calling:
expected:
parts_of_query:
["https://bio.tools/api/t/", "\\?topic=", "[mM]etabolomics"]
- case: scanpy:tl:leiden
input:
prompt:
explicit_variable_names: "Perform Leiden clustering on the data with resolution 0.5."
expected:
parts_of_query: ["sc.tl.leiden\\(", "resolution=0.5", "\\)"]
- case: scanpy:tl:umap
input:
prompt:
explicit_variable_names: "Calculate UMAP embedding with minimum distance 0.3 and spread 1.0."
expected:
parts_of_query: ["sc.tl.umap\\(", "min_dist=0.3", "spread=1.0", "\\)"]
- case: scanpy:pl:scatter
input:
prompt:
Expand All @@ -82,7 +94,13 @@ api_calling:
help_request: "Can you help me with making a scatter plot with n_genes_by_counts and total_counts?"
expected:
parts_of_query:
["sc.pl.scatter\\(", "adata=adata", "n_genes_by_counts", "total_counts", "\\)"]
[
"sc.pl.scatter\\(",
"adata=adata",
"n_genes_by_counts",
"total_counts",
"\\)",
]
- case: scanpy:pl:pca
input:
prompt:
Expand All @@ -92,7 +110,13 @@ api_calling:
help_request: "Can you help me with plotting the PCA embedding with n_genes_by_counts and total_counts as colors?"
expected:
parts_of_query:
["sc.pl.pca\\(", "adata=adata", "n_genes_by_counts", "total_counts", "\\)"]
[
"sc.pl.pca\\(",
"adata=adata",
"n_genes_by_counts",
"total_counts",
"\\)",
]
- case: scanpy:pl:tsne
input:
prompt:
Expand All @@ -101,7 +125,8 @@ api_calling:
general_question: "How can I plot a tsne with n_genes_by_counts as colors?"
help_request: "Can you help me with plotting a tsne with n_genes_by_counts as colors?"
expected:
parts_of_query: ["sc.pl.tsne\\(", "adata=adata", "n_genes_by_counts", "\\)"]
parts_of_query:
["sc.pl.tsne\\(", "adata=adata", "n_genes_by_counts", "\\)"]
- case: scanpy:pl:umap
input:
prompt:
Expand All @@ -110,7 +135,8 @@ api_calling:
general_question: "How can I plot a umap with n_genes_by_counts as colors?"
help_request: "Can you help me with plotting a umap with n_genes_by_counts as colors?"
expected:
parts_of_query: ["sc.pl.umap\\(", "adata=adata", "n_genes_by_counts", "\\)"]
parts_of_query:
["sc.pl.umap\\(", "adata=adata", "n_genes_by_counts", "\\)"]
- case: scanpy:pl:draw_graph
input:
prompt:
Expand All @@ -119,16 +145,18 @@ api_calling:
general_question: "How can I plot a force-directed graph with n_genes_by_counts as colors?"
help_request: "Can you help me with plotting a force-directed graph with n_genes_by_counts as colors?"
expected:
parts_of_query: ["sc.pl.draw_graph\\(", "adata=adata", "n_genes_by_counts", "\\)"]
parts_of_query:
["sc.pl.draw_graph\\(", "adata=adata", "n_genes_by_counts", "\\)"]
- case: scanpy:pl:spatial
input:
prompt:
specific: "plot the spatial data colored by n_genes_by_counts."
abbreviations: "spatial data plt with n_genes_by_counts as colors."
general_question: "How can I plot the spatial data with n_genes_by_counts as colors?"
help_request: "Can you help me with plotting the spatial data with n_genes_by_counts as colors?"
expected:
parts_of_query: ["sc.pl.spatial\\(", "adata=adata", "n_genes_by_counts", "\\)"]
expected:
parts_of_query:
["sc.pl.spatial\\(", "adata=adata", "n_genes_by_counts", "\\)"]
- case: anndata:read:h5ad
input:
prompt:
Expand Down
15 changes: 9 additions & 6 deletions benchmark/test_api_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
OncoKBQueryBuilder,
ScanpyPlQueryBuilder,
ScanpyPlQueryBuilderReduced,
ScanpyTlQueryBuilder,
AnnDataIOQueryBuilder,
format_as_rest_call,
format_as_python_call,
Expand Down Expand Up @@ -50,14 +51,12 @@ def run_test():
builder = OncoKBQueryBuilder()
elif "biotools" in yaml_data["case"]:
builder = BioToolsQueryBuilder()
elif "scanpy:pl" in yaml_data["case"]:
builder = ScanpyPlQueryBuilder()
parameters = builder.parameterise_query(
question=yaml_data["input"]["prompt"],
conversation=conversation,
)

api_query = format_as_rest_call(parameters)
api_query = format_as_rest_call(parameters[0])

score = []
for expected_part in ensure_iterable(
Expand All @@ -81,6 +80,7 @@ def run_test():
get_result_file_path(task),
)


def test_python_api_calling(
model_name,
test_data_api_calling,
Expand Down Expand Up @@ -108,12 +108,14 @@ def run_test():
builder = ScanpyPlQueryBuilder()
elif "anndata" in yaml_data["case"]:
builder = AnnDataIOQueryBuilder()
elif "scanpy:tl" in yaml_data["case"]:
builder = ScanpyTlQueryBuilder()
parameters = builder.parameterise_query(
question=yaml_data["input"]["prompt"],
conversation=conversation,
)

method_call = format_as_python_call(parameters[0])
method_call = format_as_python_call(parameters[0]) if parameters else ""

score = []
for expected_part in ensure_iterable(
Expand All @@ -137,6 +139,7 @@ def run_test():
get_result_file_path(task),
)


def test_python_api_calling_reduced(
model_name,
test_data_api_calling,
Expand Down Expand Up @@ -166,7 +169,7 @@ def run_test():
conversation=conversation,
)

method_call = format_as_python_call(parameters[0])
method_call = format_as_python_call(parameters[0]) if parameters else ""

score = []
for expected_part in ensure_iterable(
Expand All @@ -188,4 +191,4 @@ def run_test():
f"{n_iterations}",
yaml_data["hash"],
get_result_file_path(task),
)
)
2 changes: 2 additions & 0 deletions biochatter/api_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .oncokb import OncoKBFetcher, OncoKBInterpreter, OncoKBQueryBuilder
from .scanpy_pl import ScanpyPlQueryBuilder
from .scanpy_pl_reduced import ScanpyPlQueryBuilder as ScanpyPlQueryBuilderReduced
from .scanpy_pp_reduced import ScanpyPpQueryBuilder as ScanpyPpQueryBuilderReduced
from .scanpy_tl import ScanpyTlQueryBuilder

__all__ = [
Expand All @@ -38,6 +39,7 @@
"OncoKBQueryBuilder",
"ScanpyPlQueryBuilder",
"ScanpyPlQueryBuilderReduced",
"ScanpyPpQueryBuilderReduced",
"ScanpyTlQueryBuilder",
"format_as_python_call",
"format_as_rest_call",
Expand Down
14 changes: 5 additions & 9 deletions biochatter/api_agent/generate_pydantic_classes_from_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

from docstring_parser import parse
from langchain_core.pydantic_v1 import Field, create_model

from biochatter.api_agent.abc import BaseAPIModel


def generate_pydantic_classes(module: ModuleType) -> list[type[BaseAPIModel]]:
"""Generate Pydantic classes for each callable.
Expand Down Expand Up @@ -117,17 +119,11 @@ def generate_pydantic_classes(module: ModuleType) -> list[type[BaseAPIModel]]:
fields[field_name] = (annotation, Field(**field_kwargs))

# Create the Pydantic model

tl_parameters_model = create_model(
name,
**fields,
__base__=BaseAPIModel
)
__base__=BaseAPIModel,
)
classes_list.append(tl_parameters_model)
return classes_list


# Example usage:
#     import scanpy as sc
#     generated_classes = generate_pydantic_classes(sc.tl)
#     for func in generated_classes:
#         print(func.model_json_schema())
Loading

0 comments on commit d12e2a7

Please sign in to comment.