Batch vocab mapper #169

Merged · 4 commits · Feb 12, 2025
5 changes: 5 additions & 0 deletions .changeset/cold-drinks-sing.md
@@ -0,0 +1,5 @@
---
"apollo": minor
---

Add batching to the vocab mapper
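
For context, a minimal usage sketch of the new batched interface (not part of this changeset; the module path, the vectorstore stub, and the placeholder API key are illustrative assumptions based on the diff below):

import pandas as pd

from vocab_mapper.vocab_mapper import VocabMapper  # assumed module path

# Hypothetical stand-in for the real vector store; the real object is expected
# to expose search(query, search_kwargs={"k": 10}) and return dicts with
# "text", "metadata" and "score" keys.
class StubVectorstore:
    def search(self, query, search_kwargs=None):
        return []

loinc_df = pd.DataFrame({
    "LONG_COMMON_NAME": ["Hemoglobin [Mass/volume] in Blood"],
    "LOINC_NUM": ["718-7"],
})

mapper = VocabMapper(
    anthropic_api_key="<ANTHROPIC_API_KEY>",  # a valid key is needed for real calls
    vectorstore=StubVectorstore(),
    dataset=loinc_df,
    batch_size=25,            # inputs mapped per batch
    max_concurrent_calls=2,   # parallel Anthropic calls within a batch
)

inputs = [
    {"input_term": "HGB", "general_info": "CBC panel", "specific_info": ""},
    {"input_term": "WBC count"},  # general/specific info are optional
]

# Each element pairs the original input row with its mapping dict,
# which contains "expanded_terms", "shortlist" and "final_selection".
results = mapper.map_terms(inputs)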
14 changes: 0 additions & 14 deletions services/vocab_mapper/tools.py

This file was deleted.

267 changes: 185 additions & 82 deletions services/vocab_mapper/vocab_mapper.py
@@ -1,10 +1,12 @@
import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict

import anthropic
import pandas as pd
import logging

from vocab_mapper.prompts import *
from vocab_mapper.tools import process_inputs
from vocab_mapper.dataset_tools import format_google_sheets_input, format_google_sheets_output
from util import create_logger, ApolloError

@@ -16,12 +18,20 @@ class VocabMapper:
def __init__(self,
anthropic_api_key: str,
vectorstore,
dataset: pd.DataFrame):
dataset: pd.DataFrame,
batch_size: int = 30,
max_concurrent_calls: int = 25):
"""Initialize the vocab mapper."""
self.client = anthropic.Anthropic(api_key=anthropic_api_key)
self.vectorstore = vectorstore
self.dataset = dataset
self.loinc_num_dict = dict(zip(dataset.LONG_COMMON_NAME, dataset.LOINC_NUM))
self.batch_size = batch_size
self.max_concurrent_calls = max_concurrent_calls

def _batch_iterator(self, items, batch_size):
for i in range(0, len(items), batch_size):
yield items[i:i + batch_size]

def _call_llm(self, system_prompt: str, user_prompt: str) -> str:
"""Helper method to make LLM calls."""
@@ -44,93 +54,175 @@ def _call_llm(self, system_prompt: str, user_prompt: str) -> str:
)
return message.content[0].text

def get_expanded_terms(self, input_text: str, general_info: str, specific_info: str) -> str:
"""Step 1: Get expanded list of possible terms."""
user_prompt = EXPANSION_USER_PROMPT.format(
input_text=input_text,
general_info=general_info,
specific_info=specific_info
)
return self._call_llm(EXPANSION_SYSTEM_PROMPT, user_prompt)

def search_database(self, expanded_terms: str) -> list:
"""Step 2: Search the database for the expanded terms."""
# Vector search
vector_results = []
for guess in expanded_terms.split("\n"):
results = self.vectorstore.search(guess, search_kwargs={"k": 10})
vector_results.extend(results)

# Keyword search
keyword_results = []
for guess in expanded_terms.split("\n"):
def _call_llm_batch(self, system_prompt: str, user_prompts: List[str]) -> List[str]:
"""Process a batch of LLM inputs concurrently."""
with ThreadPoolExecutor(max_workers=self.max_concurrent_calls) as executor:
futures = [
executor.submit(self._call_llm, system_prompt, prompt)
for prompt in user_prompts
]
return [future.result() for future in futures]

def _get_expanded_terms(self, inputs: List[Dict[str, str]]) -> List[str]:
"""Process a batch of inputs for term expansion."""
user_prompts = [
EXPANSION_USER_PROMPT.format(
input_text=input_data["input_term"],
general_info=input_data.get("general_info", ""),
specific_info=input_data.get("specific_info", "")
)
for input_data in inputs
]
return self._call_llm_batch(EXPANSION_SYSTEM_PROMPT, user_prompts)

def _search_database(self, expanded_terms: str) -> list:
"""Search the database for expanded terms."""

def process_term(guess):
results = []
# Vector search
vector_results = self.vectorstore.search(guess, search_kwargs={"k": 10})
results.extend(vector_results)

return results

# Process all terms from a single input's expanded_terms in parallel
terms = expanded_terms.split("\n")
with ThreadPoolExecutor(max_workers=8) as executor:
results = list(executor.map(process_term, terms))

return [item for sublist in results for item in sublist]

def _search_database_batch(self, expanded_terms_list: List[str]) -> List[list]:
with ThreadPoolExecutor(max_workers=4) as executor:
futures = [
executor.submit(self._search_database, expanded_terms)
for expanded_terms in expanded_terms_list
]
return [future.result() for future in futures]

def _keyword_search(self, expanded_terms: str) -> list:
"""Search the dataset with a keyword-based search."""

def process_term(guess):
results = []

# Keyword search
matches = self.dataset[
self.dataset["LONG_COMMON_NAME"].str.lower().str.contains(guess.lower())
self.dataset["LONG_COMMON_NAME_LOWER"].str.contains(guess.lower())
].LONG_COMMON_NAME.to_list()[:100]
keyword_results.extend([{

results.extend([{
"text": json.dumps({
"LONG_COMMON_NAME": s,
"LOINC_NUM": self.loinc_num_dict.get(s)
}),
"metadata": {},
"score": None
} for s in matches])

return keyword_results + vector_results

def get_shortlist(self, input_text: str, general_info: str, specific_info: str,
expanded_terms: str, search_results: list) -> str:
"""Step 3: Get a shortlist of the best matches."""
user_prompt = SHORTLIST_USER_PROMPT.format(
input_text=input_text,
general_info=general_info,
specific_info=specific_info,
expanded_terms=expanded_terms,
search_results=search_results
)
return self._call_llm(SHORTLIST_SYSTEM_PROMPT, user_prompt)

def get_final_selection(self, input_text: str, general_info: str, specific_info: str,
expanded_terms: str, search_results: list) -> str:
"""Step 4: Get the best match."""
user_prompt = FINAL_SELECTION_USER_PROMPT.format(
input_text=input_text,
general_info=general_info,
specific_info=specific_info,
expanded_terms=expanded_terms,
search_results=search_results
)
return self._call_llm(FINAL_SELECTION_SYSTEM_PROMPT, user_prompt)

def map_term(self, input_term: str, general_info: str = "", specific_info: str = "") -> dict:
"""Map an input term using the target dataset."""
# Step 1: Get expanded terms
expanded_terms = self.get_expanded_terms(input_term, general_info, specific_info)

# Step 2: Search database
search_results = self.search_database(expanded_terms)

# Step 3: Get shortlist of best terms
shortlist = self.get_shortlist(
input_term, general_info, specific_info,
expanded_terms, search_results
return results

# Process all terms from a single input's expanded_terms in parallel
terms = expanded_terms.split("\n")
with ThreadPoolExecutor(max_workers=6) as executor:
results = list(executor.map(process_term, terms))

return [item for sublist in results for item in sublist]

def _get_shortlist(self, inputs: List[Dict], expanded_terms_list: List[str],
search_results_list: List[list]) -> List[str]:
"""Process a batch of inputs for shortlisting."""
user_prompts = [
SHORTLIST_USER_PROMPT.format(
input_text=input_data["input_term"],
general_info=input_data.get("general_info", ""),
specific_info=input_data.get("specific_info", ""),
expanded_terms=expanded_terms,
search_results=search_results
)
for input_data, expanded_terms, search_results
in zip(inputs, expanded_terms_list, search_results_list)
]
return self._call_llm_batch(SHORTLIST_SYSTEM_PROMPT, user_prompts)

def _get_final_selection(self, inputs: List[Dict], expanded_terms_list: List[str],
search_results_list: List[list]) -> List[str]:
"""Process a batch of inputs for final selection."""
user_prompts = [
FINAL_SELECTION_USER_PROMPT.format(
input_text=input_data["input_term"],
general_info=input_data.get("general_info", ""),
specific_info=input_data.get("specific_info", ""),
expanded_terms=expanded_terms,
search_results=search_results
)
for input_data, expanded_terms, search_results
in zip(inputs, expanded_terms_list, search_results_list)
]
return self._call_llm_batch(FINAL_SELECTION_SYSTEM_PROMPT, user_prompts)

def _map_terms_batch(self, inputs: List[Dict]) -> List[Dict]:
"""Map a batch of input terms using the target dataset."""
# Step 1: Get expanded terms for the batch
logger.info(f"Generating search terms based on inputs")
expanded_terms_list = self._get_expanded_terms(inputs)

# Step 2: Search database for all expanded terms
logger.info(f"Searching vector database")
database_results_list = self._search_database_batch(expanded_terms_list)
logger.info(f"Searching by keywords")
keyword_results_list = [self._keyword_search(expanded_terms) for expanded_terms in expanded_terms_list]
search_results_list = [keyword + vector for keyword, vector in zip(keyword_results_list, database_results_list)]

# Step 3: Get shortlists of best terms for the batch
logger.info(f"Generating a shortlist of best target terms")
shortlist_list = self._get_shortlist(
inputs, expanded_terms_list, search_results_list
)
# Step 4: Select the best term (from the full search results)
final_selection = self.get_final_selection(
input_term, general_info, specific_info,
expanded_terms, search_results

# Step 4: Select the best terms for the batch
logger.info(f"Generating the best target term")
final_selection_list = self._get_final_selection(
inputs, expanded_terms_list, search_results_list
)

# Return all results
return {
"expanded_terms": expanded_terms,
# "search_results": search_results, # large - can add back in for debugging
"shortlist": shortlist,
"final_selection": final_selection
}
return [
{
"expanded_terms": expanded_terms,
"shortlist": shortlist,
"final_selection": final_selection
}
for expanded_terms, shortlist, final_selection
in zip(expanded_terms_list, shortlist_list, final_selection_list)
]

def _preprocess_dataset(self):
"""Preprocess dataset to keep only necessary columns and add lowercased names for search."""
# Keep only needed columns
self.dataset = self.dataset[['LONG_COMMON_NAME', 'LOINC_NUM']]
# Add lowercase column
self.dataset['LONG_COMMON_NAME_LOWER'] = self.dataset.LONG_COMMON_NAME.str.lower()

def map_terms(self, input_data):
"""Process a list of inputs in batches."""
logger.info(f"Preprocessing dataset")
self._preprocess_dataset()

logger.info(f"Starting mapping")
results = []
for batch in self._batch_iterator(input_data, self.batch_size):
batch_results = self._map_terms_batch(batch)
results.extend([
{
'input': input_row,
'mapping': mapping
}
for input_row, mapping in zip(batch, batch_results)
])
logger.info(f"Finished mapping")
return results


def main(data):
"""Map vocab with the VocabMapper using Google Sheets input data. Format output for Google Sheets."""
@@ -164,18 +256,29 @@ def main(data):
logger.error(msg)
raise ApolloError(500, f"Missing API keys: {', '.join(missing_keys)}", type="BAD_REQUEST")

# Initialize mapper
loinc_df = load_dataset("awacke1/LOINC-Clinical-Terminology")
loinc_df = pd.DataFrame(loinc_df['train'])
# Get dataset
logger.info(f"Getting the dataset")
os.makedirs("tmp", exist_ok=True)

if os.path.exists("tmp/loinc_dataset.csv"):
loinc_df = pd.read_csv("tmp/loinc_dataset.csv")
else:
loinc_df = pd.DataFrame(load_dataset("awacke1/LOINC-Clinical-Terminology")['train'])
loinc_df.to_csv("tmp/loinc_dataset.csv", index=False)

vectorstore = loinc_store.connect_loinc()

# Initialize mapper
mapper = VocabMapper(
anthropic_api_key=ANTHROPIC_API_KEY,
vectorstore=vectorstore,
dataset=loinc_df
dataset=loinc_df,
batch_size=25,
max_concurrent_calls=2
)

# Process the inputs
mapping_results = process_inputs(input_data, mapper)
mapping_results = mapper.map_terms(input_data)
logger.info(f"Mapping results: {mapping_results}")

# Format results back to Google Sheets format