From d12e2a79becf1f3d1176566a1f1f023fc141cd6f Mon Sep 17 00:00:00 2001 From: noahbruderer Date: Thu, 12 Dec 2024 19:24:18 +0100 Subject: [PATCH] Merge changes from biohackathon3; anndata changed to pydantic classes for now --- .../data/benchmark_api_calling_data.yaml | 42 ++- benchmark/test_api_calling.py | 15 +- biochatter/api_agent/__init__.py | 2 + .../generate_pydantic_classes_from_module.py | 14 +- biochatter/api_agent/scanpy_pp.py | 226 ++++++++++++++ biochatter/api_agent/scanpy_pp_reduced.py | 288 ++++++++++++++++++ poetry.lock | 41 +-- test/test_api_agent.py | 99 ++++++ 8 files changed, 685 insertions(+), 42 deletions(-) create mode 100644 biochatter/api_agent/scanpy_pp.py create mode 100644 biochatter/api_agent/scanpy_pp_reduced.py diff --git a/benchmark/data/benchmark_api_calling_data.yaml b/benchmark/data/benchmark_api_calling_data.yaml index 35853fec..a9530144 100644 --- a/benchmark/data/benchmark_api_calling_data.yaml +++ b/benchmark/data/benchmark_api_calling_data.yaml @@ -73,6 +73,18 @@ api_calling: expected: parts_of_query: ["https://bio.tools/api/t/", "\\?topic=", "[mM]etabolomics"] + - case: scanpy:tl:leiden + input: + prompt: + explicit_variable_names: "Perform Leiden clustering on the data with resolution 0.5." + expected: + parts_of_query: ["sc.tl.leiden\\(", "resolution=0.5", "\\)"] + - case: scanpy:tl:umap + input: + prompt: + explicit_variable_names: "Calculate UMAP embedding with minimum distance 0.3 and spread 1.0." + expected: + parts_of_query: ["sc.tl.umap\\(", "min_dist=0.3", "spread=1.0", "\\)"] - case: scanpy:pl:scatter input: prompt: @@ -82,7 +94,13 @@ api_calling: help_request: "Can you help me with making a scatter plot with n_genes_by_counts and total_counts?" expected: parts_of_query: - ["sc.pl.scatter\\(", "adata=adata", "n_genes_by_counts", "total_counts", "\\)"] + [ + "sc.pl.scatter\\(", + "adata=adata", + "n_genes_by_counts", + "total_counts", + "\\)", + ] - case: scanpy:pl:pca input: prompt: @@ -92,7 +110,13 @@ api_calling: help_request: "Can you help me with plotting the PCA embedding with n_genes_by_counts and total_counts as colors?" expected: parts_of_query: - ["sc.pl.pca\\(", "adata=adata", "n_genes_by_counts", "total_counts", "\\)"] + [ + "sc.pl.pca\\(", + "adata=adata", + "n_genes_by_counts", + "total_counts", + "\\)", + ] - case: scanpy:pl:tsne input: prompt: @@ -101,7 +125,8 @@ api_calling: general_question: "How can I plot a tsne with n_genes_by_counts as colors?" help_request: "Can you help me with plotting a tsne with n_genes_by_counts as colors?" expected: - parts_of_query: ["sc.pl.tsne\\(", "adata=adata", "n_genes_by_counts", "\\)"] + parts_of_query: + ["sc.pl.tsne\\(", "adata=adata", "n_genes_by_counts", "\\)"] - case: scanpy:pl:umap input: prompt: @@ -110,7 +135,8 @@ api_calling: general_question: "How can I plot a umap with n_genes_by_counts as colors?" help_request: "Can you help me with plotting a umap with n_genes_by_counts as colors?" expected: - parts_of_query: ["sc.pl.umap\\(", "adata=adata", "n_genes_by_counts", "\\)"] + parts_of_query: + ["sc.pl.umap\\(", "adata=adata", "n_genes_by_counts", "\\)"] - case: scanpy:pl:draw_graph input: prompt: @@ -119,7 +145,8 @@ api_calling: general_question: "How can I plot a force-directed graph with n_genes_by_counts as colors?" help_request: "Can you help me with plotting a force-directed graph with n_genes_by_counts as colors?" 
      expected:
-        parts_of_query: ["sc.pl.draw_graph\\(", "adata=adata", "n_genes_by_counts", "\\)"]
+        parts_of_query:
+          ["sc.pl.draw_graph\\(", "adata=adata", "n_genes_by_counts", "\\)"]
   - case: scanpy:pl:spatial
     input:
       prompt:
         exact_variable_names: "Plot spatial data with n_genes_by_counts as colors."
         abbreviations: "spatial data plt with n_genes_by_counts as colors."
         general_question: "How can I plot the spatial data with n_genes_by_counts as colors?"
         help_request: "Can you help me with plotting the spatial data with n_genes_by_counts as colors?"
-      expected:
-        parts_of_query: ["sc.pl.spatial\\(", "adata=adata", "n_genes_by_counts", "\\)"]
+      expected:
+        parts_of_query:
+          ["sc.pl.spatial\\(", "adata=adata", "n_genes_by_counts", "\\)"]
   - case: anndata:read:h5ad
     input:
       prompt:
diff --git a/benchmark/test_api_calling.py b/benchmark/test_api_calling.py
index 0abf21fc..eb0f2156 100644
--- a/benchmark/test_api_calling.py
+++ b/benchmark/test_api_calling.py
@@ -9,6 +9,7 @@
     OncoKBQueryBuilder,
     ScanpyPlQueryBuilder,
     ScanpyPlQueryBuilderReduced,
+    ScanpyTlQueryBuilder,
     AnnDataIOQueryBuilder,
     format_as_rest_call,
     format_as_python_call,
@@ -50,14 +51,12 @@ def run_test():
             builder = OncoKBQueryBuilder()
         elif "biotools" in yaml_data["case"]:
             builder = BioToolsQueryBuilder()
-        elif "scanpy:pl" in yaml_data["case"]:
-            builder = ScanpyPlQueryBuilder()
 
         parameters = builder.parameterise_query(
             question=yaml_data["input"]["prompt"],
             conversation=conversation,
         )
-        api_query = format_as_rest_call(parameters)
+        api_query = format_as_rest_call(parameters[0])
 
         score = []
         for expected_part in ensure_iterable(
@@ -81,6 +80,7 @@ def run_test():
         get_result_file_path(task),
     )
 
+
 def test_python_api_calling(
     model_name,
     test_data_api_calling,
@@ -108,12 +108,14 @@ def run_test():
             builder = ScanpyPlQueryBuilder()
         elif "anndata" in yaml_data["case"]:
             builder = AnnDataIOQueryBuilder()
+        elif "scanpy:tl" in yaml_data["case"]:
+            builder = ScanpyTlQueryBuilder()
 
         parameters = builder.parameterise_query(
             question=yaml_data["input"]["prompt"],
             conversation=conversation,
         )
-        method_call = format_as_python_call(parameters[0])
+        method_call = format_as_python_call(parameters[0]) if parameters else ""
 
         score = []
         for expected_part in ensure_iterable(
@@ -137,6 +139,7 @@ def run_test():
         get_result_file_path(task),
     )
 
+
 def test_python_api_calling_reduced(
     model_name,
     test_data_api_calling,
@@ -166,7 +169,7 @@ def run_test():
             conversation=conversation,
         )
 
-        method_call = format_as_python_call(parameters[0])
+        method_call = format_as_python_call(parameters[0]) if parameters else ""
 
         score = []
         for expected_part in ensure_iterable(
@@ -188,4 +191,4 @@ def run_test():
         f"{n_iterations}",
         yaml_data["hash"],
         get_result_file_path(task),
-    )
\ No newline at end of file
+    )
diff --git a/biochatter/api_agent/__init__.py b/biochatter/api_agent/__init__.py
index 005544ca..dc602665 100644
--- a/biochatter/api_agent/__init__.py
+++ b/biochatter/api_agent/__init__.py
@@ -18,6 +18,7 @@
 from .oncokb import OncoKBFetcher, OncoKBInterpreter, OncoKBQueryBuilder
 from .scanpy_pl import ScanpyPlQueryBuilder
 from .scanpy_pl_reduced import ScanpyPlQueryBuilder as ScanpyPlQueryBuilderReduced
+from .scanpy_pp_reduced import ScanpyPpQueryBuilder as ScanpyPpQueryBuilderReduced
 from .scanpy_tl import ScanpyTlQueryBuilder
 
 __all__ = [
@@ -38,6 +39,7 @@
     "OncoKBQueryBuilder",
     "ScanpyPlQueryBuilder",
     "ScanpyPlQueryBuilderReduced",
+    "ScanpyPpQueryBuilderReduced",
     "ScanpyTlQueryBuilder",
     "format_as_python_call",
     "format_as_rest_call",
diff --git a/biochatter/api_agent/generate_pydantic_classes_from_module.py b/biochatter/api_agent/generate_pydantic_classes_from_module.py
index aa34014b..0122bc04 100644
--- a/biochatter/api_agent/generate_pydantic_classes_from_module.py
+++ b/biochatter/api_agent/generate_pydantic_classes_from_module.py
@@ -21,8 +21,10 @@
 from docstring_parser import parse
 from langchain_core.pydantic_v1 import Field, create_model
+
+from biochatter.api_agent.abc import BaseAPIModel
+
 
 def generate_pydantic_classes(module: ModuleType) -> list[type[BaseAPIModel]]:
     """Generate Pydantic classes for each callable.
@@ -117,17 +119,11 @@ def generate_pydantic_classes(module: ModuleType) -> list[type[BaseAPIModel]]:
         fields[field_name] = (annotation, Field(**field_kwargs))
 
     # Create the Pydantic model
+
     tl_parameters_model = create_model(
         name,
         **fields,
-        __base__=BaseAPIModel
-    )
+        __base__=BaseAPIModel,
+    )
     classes_list.append(tl_parameters_model)
     return classes_list
-
-
-# Example usage:
-#import scanpy as sc
-#generated_classes = generate_pydantic_classes(sc.tl)
-#for func in generated_classes:
-#print(func.model_json_schema())
diff --git a/biochatter/api_agent/scanpy_pp.py b/biochatter/api_agent/scanpy_pp.py
new file mode 100644
index 00000000..f80ac51b
--- /dev/null
+++ b/biochatter/api_agent/scanpy_pp.py
@@ -0,0 +1,226 @@
+from typing import Collection, Literal
+from pydantic import BaseModel, Field
+
+class CalculateQCMetricsParams(BaseModel):
+    adata: str = Field(..., description="Annotated data matrix")
+    expr_type: str = Field('counts', description="Name of kind of values in X")
+    var_type: str = Field('genes', description="The kind of thing the variables are")
+    qc_vars: str = Field("", description="Keys for boolean columns of .var for control variables")
+    percent_top: str = Field("50,100,200,500", description="Ranks for library complexity assessment")
+    layer: str | None = Field(None, description="Layer to use for expression values")
+    use_raw: bool = Field(False, description="Use adata.raw.X instead of adata.X")
+    inplace: bool = Field(False, description="Place calculated metrics in adata's .obs and .var")
+    log1p: bool = Field(True, description="Compute log1p transformed annotations")
+    parallel: bool | None = Field(None, description="Parallel computation flag")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class FilterCellsParams(BaseModel):
+    data: str = Field(..., description="The (annotated) data matrix of shape n_obs × n_vars. Rows correspond to cells and columns to genes.")
+    min_counts: int | None = Field(None, description="Minimum number of counts required for a cell to pass filtering.")
+    min_genes: int | None = Field(None, description="Minimum number of genes expressed required for a cell to pass filtering.")
+    max_counts: int | None = Field(None, description="Maximum number of counts required for a cell to pass filtering.")
+    max_genes: int | None = Field(None, description="Maximum number of genes expressed required for a cell to pass filtering.")
+    inplace: bool = Field(True, description="Perform computation inplace or return result.")
+    copy: bool = Field(False, description="Whether to copy the data or modify it inplace.")
+
+    class Config:
+        arbitrary_types_allowed = True
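+
+# Illustrative sketch (hypothetical, not part of the module API): these models
+# are filled by an LLM tool call and rendered into a call string downstream,
+# e.g. with pydantic v2:
+#
+#     params = FilterCellsParams(data="adata", min_genes=200)
+#     kwargs = params.model_dump(exclude_none=True, exclude_defaults=True)
+#     args = ", ".join(f"{k}={v!r}" for k, v in kwargs.items())
+#     print(f"sc.pp.filter_cells({args})")
+#     # -> sc.pp.filter_cells(data='adata', min_genes=200)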
+
+
+class FilterGenesParams(BaseModel):
+    data: str = Field(..., description="An annotated data matrix of shape n_obs × n_vars. Rows correspond to cells and columns to genes.")
+    min_counts: int | None = Field(None, description="Minimum number of counts required for a gene to pass filtering.")
+    min_cells: int | None = Field(None, description="Minimum number of cells in which the gene is expressed required for the gene to pass filtering.")
+    max_counts: int | None = Field(None, description="Maximum number of counts allowed for a gene to pass filtering.")
+    max_cells: int | None = Field(None, description="Maximum number of cells in which the gene is expressed allowed for the gene to pass filtering.")
+    inplace: bool = Field(True, description="Perform computation inplace or return result.")
+    copy: bool = Field(False, description="Whether to return a copy of the data (not modifying the original).")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class HighlyVariableGenesParams(BaseModel):
+    adata: str = Field(..., description="Annotated data matrix of shape n_obs × n_vars. Rows correspond to cells and columns to genes.")
+    layer: str | None = Field(None, description="Use adata.layers[layer] for expression values instead of adata.X.")
+    n_top_genes: int | None = Field(None, description="Number of highly-variable genes to keep. Mandatory if flavor='seurat_v3'.")
+    min_mean: float = Field(0.0125, description="Minimum mean expression threshold for highly variable genes. Ignored if flavor='seurat_v3'.")
+    max_mean: float = Field(3, description="Maximum mean expression threshold for highly variable genes. Ignored if flavor='seurat_v3'.")
+    min_disp: float = Field(0.5, description="Minimum dispersion threshold for highly variable genes. Ignored if flavor='seurat_v3'.")
+    max_disp: float = Field(float('inf'), description="Maximum dispersion threshold for highly variable genes. Ignored if flavor='seurat_v3'.")
+    span: float = Field(0.3, description="Fraction of the data (cells) used in variance estimation for the loess model fit if flavor='seurat_v3'.")
+    n_bins: int = Field(20, description="Number of bins for binning the mean gene expression. Normalization is done per bin.")
+    flavor: Literal['seurat', 'cell_ranger', 'seurat_v3', 'seurat_v3_paper'] = Field('seurat', description="The method to use for identifying highly variable genes.")
+    subset: bool = Field(False, description="If True, subset to highly-variable genes, otherwise just indicate them.")
+    inplace: bool = Field(True, description="Whether to place calculated metrics in .var or return them.")
+    batch_key: str | None = Field(None, description="If specified, highly-variable genes are selected separately within each batch and merged.")
+    check_values: bool = Field(True, description="Whether to check if counts in selected layer are integers (relevant for flavor='seurat_v3').")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class Log1pParams(BaseModel):
+    data: str = Field(..., description="The (annotated) data matrix of shape n_obs × n_vars. Rows correspond to cells and columns to genes.")
+    base: float | None = Field(None, description="Base of the logarithm. Natural logarithm is used by default.")
+    copy: bool = Field(False, description="If True, a copy of the data is returned. Otherwise, the operation is done inplace.")
+    chunked: bool | None = Field(None, description="Process the data matrix in chunks, which will save memory. Applies only to AnnData.")
+    chunk_size: int | None = Field(None, description="Number of observations (n_obs) per chunk to process the data in.")
+    layer: str | None = Field(None, description="Entry of layers to transform.")
+    obsm: str | None = Field(None, description="Entry of obsm to transform.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class PCAParams(BaseModel):
+    data: str = Field(..., description="The (annotated) data matrix of shape n_obs × n_vars. Rows correspond to cells and columns to genes.")
+    n_comps: int | None = Field(None, description="Number of principal components to compute. Defaults to 50, or 1 - minimum dimension size of selected representation.")
+    layer: str | None = Field(None, description="If provided, which element of layers to use for PCA.")
+    zero_center: bool = Field(True, description="If True, compute standard PCA from covariance matrix. If False, omit zero-centering variables.")
+    svd_solver: str | None = Field(None, description="SVD solver to use. Options: 'auto', 'arpack', 'randomized', 'lobpcg', or 'tsqr'.")
+    random_state: int | None = Field(0, description="Change to use different initial states for the optimization.")
+    return_info: bool = Field(False, description="Only relevant when not passing an AnnData. Whether to return PCA info.")
+    mask_var: str | None = Field(None, description="To run PCA only on certain genes. Default is .var['highly_variable'] if available.")
+    use_highly_variable: bool | None = Field(None, description="Whether to use highly variable genes only, stored in .var['highly_variable']. Deprecated in 1.10.0.")
+    dtype: str = Field('float32', description="Numpy data type string to which to convert the result.")
+    chunked: bool = Field(False, description="If True, perform incremental PCA using sklearn IncrementalPCA or dask-ml IncrementalPCA.")
+    chunk_size: int | None = Field(None, description="Number of observations to include in each chunk. Required if chunked=True.")
+    copy: bool = Field(False, description="If True, a copy of the data is returned when AnnData is passed. Otherwise, the operation is done inplace.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class NormalizeTotalParams(BaseModel):
+    adata: str = Field(..., description="The annotated data matrix of shape n_obs × n_vars. Rows correspond to cells and columns to genes.")
+    target_sum: float | None = Field(None, description="Target sum after normalization. If None, each cell will have total counts equal to the median before normalization.")
+    exclude_highly_expressed: bool = Field(False, description="If True, exclude highly expressed genes from normalization computation.")
+    max_fraction: float = Field(0.05, description="If exclude_highly_expressed=True, consider a gene as highly expressed if it has more than max_fraction of the total counts in at least one cell.")
+    key_added: str | None = Field(None, description="Name of the field in adata.obs where the normalization factor is stored.")
+    layer: str | None = Field(None, description="Layer to normalize instead of X. If None, normalize X.")
+    inplace: bool = Field(True, description="Whether to update adata or return normalized copies of adata.X and adata.layers.")
+    copy: bool = Field(False, description="Whether to modify the copied input object. Not compatible with inplace=False.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class RegressOutParams(BaseModel):
+    adata: str = Field(..., description="The annotated data matrix.")
+    keys: str | Collection[str] = Field(..., description="Keys for observation annotation on which to regress. Can be a single key or a collection of keys.")
+    layer: str | None = Field(None, description="Layer to regress on, if provided.")
+    n_jobs: int | None = Field(None, description="Number of jobs for parallel computation. None means using default n_jobs.")
+    copy: bool = Field(False, description="If True, a copy of the data will be returned. Otherwise, modifies in-place.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class ScaleParams(BaseModel):
+    data: str = Field(..., description="The (annotated) data matrix of shape n_obs × n_vars. Rows correspond to cells and columns to genes.")
+    zero_center: bool = Field(True, description="If False, omit zero-centering variables, which allows to handle sparse input efficiently.")
+    max_value: float | None = Field(None, description="Clip (truncate) to this value after scaling. If None, do not clip.")
+    copy: bool = Field(False, description="Whether this function should be performed inplace.")
+    layer: str | None = Field(None, description="If provided, which element of layers to scale.")
+    obsm: str | None = Field(None, description="If provided, which element of obsm to scale.")
+    mask_obs: str | None = Field(None, description="Restrict the scaling to a certain set of observations. The mask is specified as a boolean array or a string referring to an array in obs.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class SubsampleParams(BaseModel):
+    data: str = Field(..., description="The (annotated) data matrix of shape n_obs × n_vars. Rows correspond to cells and columns to genes.")
+    fraction: float | None = Field(None, description="Subsample to this fraction of the number of observations.")
+    n_obs: int | None = Field(None, description="Subsample to this number of observations.")
+    random_state: int | None = Field(0, description="Random seed to change subsampling.")
+    copy: bool = Field(False, description="If an AnnData is passed, determines whether a copy is returned.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class DownsampleCountsParams(BaseModel):
+    adata: str = Field(..., description="Annotated data matrix.")
+    counts_per_cell: int | None = Field(None, description="Target total counts per cell. If a cell has more than ‘counts_per_cell’, it will be downsampled to this number. Can be an integer or integer ndarray with same length as number of observations.")
+    total_counts: int | None = Field(None, description="Target total counts. If the count matrix has more than total_counts, it will be downsampled to this number.")
+    random_state: int | None = Field(0, description="Random seed for subsampling.")
+    replace: bool = Field(False, description="Whether to sample the counts with replacement.")
+    copy: bool = Field(False, description="Determines whether a copy of adata is returned.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class RecipeZheng17Params(BaseModel):
+    adata: str = Field(..., description="Annotated data matrix.")
+    n_top_genes: int = Field(1000, description="Number of genes to keep.")
+    log: bool = Field(True, description="Take logarithm. If True, log-transform data after filtering.")
+    plot: bool = Field(False, description="Show a plot of the gene dispersion vs. mean relation.")
+    copy: bool = Field(False, description="Return a copy of adata instead of updating it.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class RecipeWeinreb17Params(BaseModel):
+    adata: str = Field(..., description="Annotated data matrix.")
+    log: bool = Field(True, description="Logarithmize data? If True, log-transform the data.")
+    mean_threshold: float = Field(0.01, description="Threshold for mean expression of genes.")
+    cv_threshold: float = Field(2, description="Threshold for coefficient of variation (CV) for gene dispersion.")
+    n_pcs: int = Field(50, description="Number of principal components to use.")
+    svd_solver: str = Field('randomized', description="SVD solver to use.")
+    random_state: int = Field(0, description="Random state for reproducibility of results.")
+    copy: bool = Field(False, description="Return a copy if True, else modifies the original AnnData.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class RecipeSeuratParams(BaseModel):
+    adata: str = Field(..., description="Annotated data matrix.")
+    log: bool = Field(True, description="Logarithmize data? If True, log-transform the data.")
+    plot: bool = Field(False, description="Show a plot of the gene dispersion vs. mean relation.")
+    copy: bool = Field(False, description="Return a copy if True, else modifies the original AnnData.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class CombatParams(BaseModel):
+    adata: str = Field(..., description="Annotated data matrix.")
+    key: str = Field('batch', description="Key to a categorical annotation from obs that will be used for batch effect removal.")
+    covariates: list[str] | None = Field(None, description="Additional covariates such as adjustment variables or biological conditions.")
+    inplace: bool = Field(True, description="Whether to replace adata.X or to return the corrected data.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class ScrubletParams(BaseModel):
+    adata: str = Field(..., description="Annotated data matrix (n_obs × n_vars).")
+    adata_sim: str | None = Field(None, description="Optional AnnData object from scrublet_simulate_doublets() with same number of vars as adata.")
+    batch_key: str | None = Field(None, description="Optional obs column name discriminating between batches.")
+    sim_doublet_ratio: float = Field(2.0, description="Number of doublets to simulate relative to the number of observed transcriptomes.")
+    expected_doublet_rate: float = Field(0.05, description="Estimated doublet rate for the experiment.")
+    stdev_doublet_rate: float = Field(0.02, description="Uncertainty in the expected doublet rate.")
+    synthetic_doublet_umi_subsampling: float = Field(1.0, description="Rate for sampling UMIs when creating synthetic doublets.")
+    knn_dist_metric: str = Field('euclidean', description="Distance metric used for nearest neighbor search.")
+    normalize_variance: bool = Field(True, description="Normalize the data such that each gene has a variance of 1.")
+    log_transform: bool = Field(False, description="Whether to log-transform the data prior to PCA.")
+    mean_center: bool = Field(True, description="If True, center the data such that each gene has a mean of 0.")
+    n_prin_comps: int = Field(30, description="Number of principal components used to embed the transcriptomes prior to KNN graph construction.")
+    use_approx_neighbors: bool = Field(False, description="Use approximate nearest neighbor method (annoy) for KNN classifier.")
+    get_doublet_neighbor_parents: bool = Field(False, description="If True, return parent transcriptomes that generated the doublet neighbors.")
+    n_neighbors: int | None = Field(None, description="Number of neighbors used to construct the KNN graph.")
+    threshold: float | None = Field(None, description="Doublet score threshold for calling a transcriptome a doublet.")
+    verbose: bool = Field(True, description="If True, log progress updates.")
+    copy: bool = Field(False, description="If True, return a copy of adata with Scrublet results added.")
+    random_state: int = Field(0, description="Initial state for doublet simulation and nearest neighbors.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class ScrubletSimulateDoubletsParams(BaseModel):
+    adata: str = Field(..., description="Annotated data matrix of shape n_obs × n_vars. Rows correspond to cells, columns to genes.")
+    layer: str | None = Field(None, description="Layer of adata where raw values are stored, or 'X' if values are in .X.")
+    sim_doublet_ratio: float = Field(2.0, description="Number of doublets to simulate relative to the number of observed transcriptomes.")
+    synthetic_doublet_umi_subsampling: float = Field(1.0, description="Rate for sampling UMIs when creating synthetic doublets. If 1.0, simply add UMIs from two randomly sampled transcriptomes.")
+    random_seed: int = Field(0, description="Random seed for reproducibility.")
+
+    class Config:
+        arbitrary_types_allowed = True
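+
+
+# A sketch of the generated alternative to these hand-written models, using the
+# helper shipped in this package (assumes scanpy is installed; the generated
+# classes are not guaranteed to match the curated fields above):
+#
+#     import scanpy as sc
+#
+#     from biochatter.api_agent.generate_pydantic_classes_from_module import (
+#         generate_pydantic_classes,
+#     )
+#
+#     generated_classes = generate_pydantic_classes(sc.pp)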
diff --git a/biochatter/api_agent/scanpy_pp_reduced.py b/biochatter/api_agent/scanpy_pp_reduced.py
new file mode 100644
index 00000000..d0107e84
--- /dev/null
+++ b/biochatter/api_agent/scanpy_pp_reduced.py
@@ -0,0 +1,288 @@
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Collection, Optional, Union
+
+from langchain_core.output_parsers import PydanticToolsParser
+from pydantic import BaseModel, Field
+
+from .abc import BaseAPIModel, BaseQueryBuilder, BaseTools
+
+if TYPE_CHECKING:
+    from biochatter.llm_connect import Conversation
+
+
+SCANPY_PP_QUERY_PROMPT = """
+Scanpy Preprocessing (scanpy.pp) Query Guide
+
+You are a world-class algorithm for creating queries in structured formats. Your task is to use the Python API of scanpy to answer questions about the scanpy.pp (preprocessing) module. All function calls should be prefixed with scanpy.pp; for example, to normalize the data, use scanpy.pp.normalize_total.
+
+Use the following documentation to craft queries for preprocessing tasks:
+Preprocessing Functions in scanpy.pp
+
+These are the primary preprocessing functions in scanpy.pp used for data cleaning, transformation, and filtering.
+
+    pp.filter_cells
+    Filters cells based on the minimum and maximum number of counts or genes.
+
+    pp.filter_genes
+    Filters genes based on the minimum and maximum number of counts or cells in which the gene is expressed.
+
+    pp.normalize_total
+    Normalizes the data by scaling each cell to a specified total count.
+
+    pp.log1p
+    Applies a natural logarithm transformation to the data (adds 1 before log transformation).
+
+    pp.regress_out
+    Removes the effects of unwanted sources of variation by regressing out specific variables.
+
+    pp.scale
+    Scales the data by zero-centering and (optionally) scaling each feature.
+
+    pp.subsample
+    Subsamples the data by randomly selecting a fraction of the observations or by a fixed number of observations.
+
+    pp.highly_variable_genes
+    Identifies and selects highly variable genes based on mean and dispersion criteria.
+
+    pp.calculate_qc_metrics
+    Computes quality control metrics such as the number of genes detected per cell, total counts per cell, and more.
+
+Specialized Preprocessing Methods
+
+These functions are used for specialized preprocessing workflows:
+
+    pp.recipe_zheng17
+    Implements a preprocessing recipe as described in Zheng et al., 2017 for single-cell RNA-seq data.
+
+    pp.recipe_weinreb17
+    Implements a preprocessing recipe as described in Weinreb et al., 2017 for single-cell RNA-seq data.
+
+    pp.recipe_seurat
+    Implements a preprocessing recipe for integration with the Seurat workflow.
+
+    pp.combat
+    Removes batch effects using the ComBat method.
+
+    pp.scrublet
+    Detects doublets in the dataset using the Scrublet method.
+
+Use the provided documentation to craft precise queries for any preprocessing needs in Scanpy. Ensure that your function call starts with scanpy.pp. and that you include relevant parameters based on the query.
+
+This prompt guides the user to query the scanpy.pp module for preprocessing tasks, assisting with the construction of specific preprocessing operations, such as filtering, normalization, scaling, and more.
+"""
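+
+# Example of the call format the prompt above asks the model to produce
+# (illustrative only; the string is assembled downstream from the parsed
+# tool call):
+#
+#     scanpy.pp.filter_cells(data=adata, min_genes=200)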
+
+
+class ScanpyPpFuncs(BaseTools):
+    tools_params = {}
+
+    tools_params["filter_cells"] = {
+        "data": (str, Field(..., description="The (annotated) data matrix.")),
+        "min_counts": (Optional[int], Field(None, description="Minimum counts per cell.")),
+        "min_genes": (Optional[int], Field(None, description="Minimum genes expressed in a cell.")),
+        "max_counts": (Optional[int], Field(None, description="Maximum counts per cell.")),
+        "max_genes": (Optional[int], Field(None, description="Maximum genes expressed in a cell.")),
+        "inplace": (bool, Field(True, description="Whether to modify the data in place."))
+    }
+
+    tools_params["filter_genes"] = {
+        "data": (str, Field(..., description="The (annotated) data matrix.")),
+        "min_counts": (Optional[int], Field(None, description="Minimum counts per gene.")),
+        "min_cells": (Optional[int], Field(None, description="Minimum number of cells expressing the gene.")),
+        "max_counts": (Optional[int], Field(None, description="Maximum counts per gene.")),
+        "max_cells": (Optional[int], Field(None, description="Maximum number of cells expressing the gene.")),
+        "inplace": (bool, Field(True, description="Whether to modify the data in place."))
+    }
+
+    tools_params["highly_variable_genes"] = {
+        "adata": (str, Field(..., description="Annotated data matrix.")),
+        "n_top_genes": (Optional[int], Field(None, description="Number of highly-variable genes to keep.")),
+        "min_mean": (float, Field(0.0125, description="Minimum mean expression for highly-variable genes.")),
+        "max_mean": (float, Field(3, description="Maximum mean expression for highly-variable genes.")),
+        "flavor": (str, Field('seurat', description="Method for identifying highly-variable genes.")),
+        "inplace": (bool, Field(True, description="Whether to place metrics in .var or return them."))
+    }
+
+    tools_params["log1p"] = {
+        "data": (str, Field(..., description="The data matrix.")),
+        "base": (Optional[float], Field(None, description="Base of the logarithm.")),
+        "copy": (bool, Field(False, description="If True, return a copy of the data.")),
+        "chunked": (Optional[bool], Field(None, description="Process data in chunks."))
+    }
+
+    tools_params["pca"] = {
+        "data": (str, Field(..., description="The (annotated) data matrix.")),
+        "n_comps": (Optional[int], Field(None, description="Number of principal components to compute.")),
+        "layer": (Optional[str], Field(None, description="Element of layers to use for PCA.")),
+        "zero_center": (bool, Field(True, description="Whether to zero-center the data.")),
+        "svd_solver": (Optional[str], Field(None, description="SVD solver to use.")),
+        "copy": (bool, Field(False, description="If True, return a copy of the data."))
+    }
+
+    tools_params["normalize_total"] = {
+        "adata": (str, Field(..., description="The annotated data matrix.")),
+        "target_sum": (Optional[float], Field(None, description="Target sum after normalization.")),
+        "exclude_highly_expressed": (bool, Field(False, description="Whether to exclude highly expressed genes.")),
+        "inplace": (bool, Field(True, description="Whether to update adata or return normalized data."))
+    }
+
+    tools_params["regress_out"] = {
+        "adata": (str, Field(..., description="The annotated data matrix.")),
+        "keys": (Union[str, Collection[str]], Field(..., description="Keys for regression.")),
+        "copy": (bool, Field(False, description="If True, return a copy of the data."))
+    }
+
+    tools_params["scale"] = {
+        "data": (str, Field(..., description="The data matrix.")),
+        "zero_center": (bool, Field(True, description="Whether to zero-center the data.")),
+        "copy": (bool, Field(False, description="Whether to perform operation inplace."))
+    }
+
+    tools_params["subsample"] = {
+        "data": (str, Field(..., description="The data matrix.")),
+        "fraction": (Optional[float], Field(None, description="Fraction of observations to subsample.")),
+        "n_obs": (Optional[int], Field(None, description="Number of observations to subsample.")),
+        "copy": (bool, Field(False, description="If True, return a copy of the data."))
+    }
+
+    tools_params["downsample_counts"] = {
+        "adata": (str, Field(..., description="The annotated data matrix.")),
+        "counts_per_cell": (Optional[Union[int, str]], Field(None, description="Target total counts per cell.")),
+        "replace": (bool, Field(False, description="Whether to sample with replacement.")),
+        "copy": (bool, Field(False, description="If True, return a copy of the data."))
+    }
+
+    tools_params["combat"] = {
+        "adata": (str, Field(..., description="The annotated data matrix.")),
+        "key": (str, Field('batch', description="Key for batch effect removal.")),
+        "inplace": (bool, Field(True, description="Whether to replace the data inplace."))
+    }
+
+    tools_params["scrublet"] = {
+        "adata": (str, Field(..., description="Annotated data matrix.")),
+        "sim_doublet_ratio": (float, Field(2.0, description="Number of doublets to simulate.")),
+        "threshold": (Optional[float], Field(None, description="Doublet score threshold.")),
+        "copy": (bool, Field(False, description="If True, return a copy of the data."))
+    }
+
+    tools_params["scrublet_simulate_doublets"] = {
+        "adata": (str, Field(..., description="Annotated data matrix.")),
+        "sim_doublet_ratio": (float, Field(2.0, description="Number of doublets to simulate.")),
+        "random_seed": (int, Field(0, description="Random seed for reproducibility."))
+    }
+    tools_params["calculate_qc_metrics"] = {
+        "adata": (str, Field(..., description="The annotated data matrix.")),
+        "expr_type": (str, Field('counts', description="Name of kind of values in X.")),
+        "var_type": (str, Field('genes', description="The kind of thing the variables are.")),
+        "qc_vars": (Collection[str],
Field((), + description="Keys for boolean columns of .var which identify variables you could want to control for (e.g., “ERCC” or “mito”).")), + "percent_top": (Collection[int], Field((50, 100, 200, 500), + description="List of ranks at which cumulative proportion of expression will be reported as a percentage.")), + "layer": (Optional[str], Field(None, + description="If provided, use adata.layers[layer] for expression values instead of adata.X.")), + "use_raw": ( + bool, Field(False, description="If True, use adata.raw.X for expression values instead of adata.X.")), + "inplace": (bool, Field(False, description="Whether to place calculated metrics in adata’s .obs and .var.")), + "log1p": (bool, Field(True, description="Set to False to skip computing log1p transformed annotations.")) + } + + tools_params["recipe_zheng17"] = { + "adata": (str, Field(..., description="The annotated data matrix.")), + "n_top_genes": (int, Field(1000, description="Number of genes to keep.")), + "log": (bool, Field(True, description="Take logarithm of the data.")), + "plot": (bool, Field(False, description="Show a plot of the gene dispersion vs. mean relation.")), + "copy": (bool, Field(False, description="Return a copy of adata instead of updating it.")) + } + + tools_params["recipe_weinreb17"] = { + "adata": (str, Field(..., description="The annotated data matrix.")), + "log": (bool, Field(True, description="Logarithmize the data?")), + "mean_threshold": (float, Field(0.01, description="Mean expression threshold for gene selection.")), + "cv_threshold": (float, Field(2, description="Coefficient of variation threshold for gene selection.")), + "n_pcs": (int, Field(50, description="Number of principal components to compute.")), + "svd_solver": (str, Field('randomized', description="SVD solver to use for PCA.")), + "random_state": (int, Field(0, description="Random seed for reproducibility.")), + "copy": (bool, Field(False, description="Return a copy if true, otherwise modifies the original adata object.")) + } + + tools_params["recipe_seurat"] = { + "adata": (str, Field(..., description="The annotated data matrix.")), + "log": (bool, Field(True, description="Logarithmize the data?")), + "plot": (bool, Field(False, description="Show a plot of the gene dispersion vs. mean relation.")), + "copy": (bool, Field(False, description="Return a copy if true, otherwise modifies the original adata object.")) + } + + def __init__(self, tools_params: dict = tools_params): + super().__init__() + self.tools_params = tools_params + +class ScanpyPpQueryBuilder(BaseQueryBuilder): + """A class for building a ScanpyPp query object.""" + + def create_runnable( + self, + query_parameters: list["BaseAPIModel"], + conversation: "Conversation", + ) -> Callable: + """Create a runnable object for executing queries. + + Create runnable using the LangChain `create_structured_output_runnable` + method. + + Args: + ---- + query_parameters: A Pydantic data model that specifies the fields of + the API that should be queried. + + conversation: A BioChatter conversation object. + + Returns: + ------- + A Callable object that can execute the query. + + """ + runnable = conversation.chat.bind_tools(query_parameters) + return runnable | PydanticToolsParser(tools=query_parameters) + + def parameterise_query( + self, + question: str, + conversation: "Conversation", + ) -> list["BaseModel"]: + """Generate a ScanpyPp query object. + + Generates the object based on the given question, prompt, and + BioChatter conversation. 
Uses a Pydantic model to define the API fields.
+        Creates a runnable that can be invoked on LLMs that are qualified to
+        parameterise functions.
+
+        Args:
+        ----
+            question (str): The question to be answered.
+
+            conversation: The conversation object used for parameterising the
+                ScanpyPpQuery.
+
+        Returns:
+        -------
+            list[BaseModel]: the parameterised query objects (Pydantic models)
+                returned by the LLM tool calls.
+
+        """
+        tool_maker = ScanpyPpFuncs()
+        tools = tool_maker.make_pydantic_tools()
+        runnable = self.create_runnable(
+            conversation=conversation, query_parameters=tools
+        )
+        scanpy_pp_call_obj = runnable.invoke(
+            question,
+        )
+        return scanpy_pp_call_obj
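+
+
+# Minimal usage sketch (assumes a configured BioChatter conversation whose chat
+# model supports tool binding; GptConversation arguments shown are indicative):
+#
+#     from biochatter.llm_connect import GptConversation
+#
+#     conversation = GptConversation(model_name="gpt-4", prompts={})
+#     conversation.set_api_key(api_key="...")
+#
+#     builder = ScanpyPpQueryBuilder()
+#     calls = builder.parameterise_query(
+#         question="Filter cells with at least 200 genes",
+#         conversation=conversation,
+#     )
+#     # `calls` is a list of parameterised Pydantic tool-call objects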
hash = "sha256:d261d7850c8367704874847d95febc698a950bf061c9475d4a8b7689adc4f7fa"}, - {file = "ruff-0.8.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1ca4e3a87496dc07d2427b7dd7ffa88a1e597c28dad65ae6433ecb9f2e4f022f"}, - {file = "ruff-0.8.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:729850feed82ef2440aa27946ab39c18cb4a8889c1128a6d589ffa028ddcfc22"}, - {file = "ruff-0.8.2-py3-none-win32.whl", hash = "sha256:ac42caaa0411d6a7d9594363294416e0e48fc1279e1b0e948391695db2b3d5b1"}, - {file = "ruff-0.8.2-py3-none-win_amd64.whl", hash = "sha256:2aae99ec70abf43372612a838d97bfe77d45146254568d94926e8ed5bbb409ea"}, - {file = "ruff-0.8.2-py3-none-win_arm64.whl", hash = "sha256:fb88e2a506b70cfbc2de6fae6681c4f944f7dd5f2fe87233a7233d888bad73e8"}, - {file = "ruff-0.8.2.tar.gz", hash = "sha256:b84f4f414dda8ac7f75075c1fa0b905ac0ff25361f42e6d5da681a465e0f78e5"}, + {file = "ruff-0.8.3-py3-none-linux_armv6l.whl", hash = "sha256:8d5d273ffffff0acd3db5bf626d4b131aa5a5ada1276126231c4174543ce20d6"}, + {file = "ruff-0.8.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e4d66a21de39f15c9757d00c50c8cdd20ac84f55684ca56def7891a025d7e939"}, + {file = "ruff-0.8.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c356e770811858bd20832af696ff6c7e884701115094f427b64b25093d6d932d"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c0a60a825e3e177116c84009d5ebaa90cf40dfab56e1358d1df4e29a9a14b13"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:75fb782f4db39501210ac093c79c3de581d306624575eddd7e4e13747e61ba18"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f26bc76a133ecb09a38b7868737eded6941b70a6d34ef53a4027e83913b6502"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:01b14b2f72a37390c1b13477c1c02d53184f728be2f3ffc3ace5b44e9e87b90d"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:53babd6e63e31f4e96ec95ea0d962298f9f0d9cc5990a1bbb023a6baf2503a82"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ae441ce4cf925b7f363d33cd6570c51435972d697e3e58928973994e56e1452"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7c65bc0cadce32255e93c57d57ecc2cca23149edd52714c0c5d6fa11ec328cd"}, + {file = "ruff-0.8.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5be450bb18f23f0edc5a4e5585c17a56ba88920d598f04a06bd9fd76d324cb20"}, + {file = "ruff-0.8.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8faeae3827eaa77f5721f09b9472a18c749139c891dbc17f45e72d8f2ca1f8fc"}, + {file = "ruff-0.8.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:db503486e1cf074b9808403991663e4277f5c664d3fe237ee0d994d1305bb060"}, + {file = "ruff-0.8.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6567be9fb62fbd7a099209257fef4ad2c3153b60579818b31a23c886ed4147ea"}, + {file = "ruff-0.8.3-py3-none-win32.whl", hash = "sha256:19048f2f878f3ee4583fc6cb23fb636e48c2635e30fb2022b3a1cd293402f964"}, + {file = "ruff-0.8.3-py3-none-win_amd64.whl", hash = "sha256:f7df94f57d7418fa7c3ffb650757e0c2b96cf2501a0b192c18e4fb5571dfada9"}, + {file = "ruff-0.8.3-py3-none-win_arm64.whl", hash = "sha256:fe2756edf68ea79707c8d68b78ca9a58ed9af22e430430491ee03e718b5e4936"}, + {file = "ruff-0.8.3.tar.gz", hash = "sha256:5e7558304353b84279042fc584a4f4cb8a07ae79b2bf3da1a7551d960b5626d3"}, ] [[package]] diff --git a/test/test_api_agent.py b/test/test_api_agent.py index 
--- a/test/test_api_agent.py
+++ b/test/test_api_agent.py
@@ -12,7 +12,10 @@
     BaseInterpreter,
     BaseQueryBuilder,
 )
+
 from biochatter.api_agent.anndata_agent import AnnDataIOQueryBuilder, ANNDATA_IO_QUERY_PROMPT
+from biochatter.api_agent.scanpy_pp_reduced import ScanpyPpQueryBuilder
+
 from biochatter.api_agent.api_agent import APIAgent
 from biochatter.api_agent.blast import (
     BLAST_QUERY_PROMPT,
@@ -543,6 +546,71 @@ def test_parameterise_query(self, mock_create_runnable):
         assert result == mock_query_obj
 
 
+class TestScanpyPpQueryBuilder:
+    @pytest.fixture
+    def mock_create_runnable(self):
+        with patch(
+            "biochatter.api_agent.scanpy_pp_reduced.ScanpyPpQueryBuilder.create_runnable",
+        ) as mock:
+            mock_runnable = MagicMock()
+            mock.return_value = mock_runnable
+            yield mock_runnable
+
+    @patch("biochatter.llm_connect.GptConversation")
+    def test_create_runnable(self, mock_conversation):
+        # Mock the list of Pydantic classes as a list of Mock objects
+        class MockTool1(BaseModel):
+            param1: str
+
+        class MockTool2(BaseModel):
+            param2: int
+
+        mock_generated_classes = [MockTool1, MockTool2]
+
+        # Mock the conversation object and LLM
+        mock_conversation_instance = mock_conversation.return_value
+        mock_llm = MagicMock()
+        mock_conversation_instance.chat = mock_llm
+
+        # Mock the LLM with tools
+        mock_llm_with_tools = MagicMock()
+        mock_llm.bind_tools.return_value = mock_llm_with_tools
+
+        # Mock the chain
+        mock_chain = MagicMock()
+        mock_llm_with_tools.__or__.return_value = mock_chain
+
+        # Act
+        builder = ScanpyPpQueryBuilder()
+        result = builder.create_runnable(
+            query_parameters=mock_generated_classes,
+            conversation=mock_conversation_instance,
+        )
+
+        # Assert: ScanpyPpQueryBuilder binds tools without a tool_choice
+        # constraint, unlike the AnnData builder
+        mock_llm.bind_tools.assert_called_once_with(mock_generated_classes)
+        mock_llm_with_tools.__or__.assert_called_once_with(
+            PydanticToolsParser(tools=mock_generated_classes),
+        )
+        # Verify the returned chain
+        assert result == mock_chain
+
+    def test_parameterise_query(self, mock_create_runnable):
+        # Arrange
+        query_builder = ScanpyPpQueryBuilder()
+        mock_conversation = MagicMock()
+        question = "I want to use scanpy pp to filter cells with at least 200 genes"
+        mock_query_obj = MagicMock()
+        mock_create_runnable.invoke.return_value = mock_query_obj
+
+        # Act
+        result = query_builder.parameterise_query(question, mock_conversation)
+
+        # Assert: the runnable is invoked with the bare question string
+        mock_create_runnable.invoke.assert_called_once_with(question)
+        assert result == mock_query_obj
+
+
 class TestScanpyTlQueryBuilder:
     @pytest.fixture
     def mock_create_runnable(self):