-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] add pipeline configuration/structure #3
Changes from 18 commits
6fe2bfb
bbbfc9d
60c0978
89d45f3
5e8bafc
90bee39
cb16fc9
58bc727
cfff8bb
79bfdcc
0625124
c26ef1b
e907624
85b20dd
ddd5b67
23bb537
c99c83f
a78f241
cdbdec2
d1e2a31
cd5bb83
99095fb
ef0b25d
0839a6e
585fc21
5cdfc6d
306e9ec
693cb76
c3b5767
8e7152f
20af580
5cff6be
194e9b1
6f45fba
b6e26b0
08e534a
e8108fd
c366e61
e1fcd2b
01a70f0
ce537b8
44ad3c6
35c09aa
8c5237f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +0,0 @@ | ||
from .run import __main__ as run | ||
|
||
__all__ = ['run'] | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,96 +1,51 @@ | ||
""" Extract participant demographics from HTML files. """ | ||
import os | ||
|
||
from publang.extract import extract_from_text | ||
from openai import OpenAI | ||
from pathlib import Path | ||
import json | ||
import pandas as pd | ||
import logging | ||
|
||
from . import prompts | ||
from .clean import clean_predictions | ||
from .clean import clean_prediction | ||
|
||
from ns_pipelines.pipeline import IndependentPipeline | ||
|
||
def extract(extraction_model, extraction_client, docs, output_dir, prompt_set='', **extract_kwargs): | ||
|
||
def extract(extraction_model, extraction_client, text, prompt_set='', **extract_kwargs): | ||
extract_kwargs.pop('search_query', None) | ||
|
||
# Extract | ||
predictions = extract_from_text( | ||
docs['body'].to_list(), | ||
model=extraction_model, client=extraction_client, | ||
text, | ||
model=extraction_model, | ||
client=extraction_client, | ||
**extract_kwargs | ||
) | ||
|
||
# Add PMCID to predictions | ||
for i, pred in enumerate(predictions): | ||
if not pred: | ||
logging.warning(f"No prediction for document {docs['pmid'].iloc[i]}") | ||
continue | ||
pred['pmid'] = int(docs['pmid'].iloc[i]) | ||
if not predictions: | ||
logging.warning("No predictions found.") | ||
return None, None | ||
|
||
clean_preds = clean_predictions(predictions) | ||
clean_preds = clean_prediction(predictions) | ||
|
||
return predictions, clean_preds | ||
|
||
|
||
def _load_client(model_name): | ||
if 'gpt' in model_name: | ||
client = OpenAI(api_key=os.getenv('MYOPENAI_API_KEY')) | ||
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason I had it this way was because if the environment variable is set to So I wanted to have the option to not pass that key to OpenAI. Specifically, this is for when you want to use the OpenAI client for another API (such as OpenRouter). So what I would do is add which API key to use as a configuration parameter, and in the production environment name it something else that is not |
||
|
||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In principle we can run this using other API keys and hence other model names, so perhaps let's not worry about validation here |
||
raise ValueError(f"Model {model_name} not supported") | ||
|
||
return client | ||
|
||
|
||
def _load_prompt_config(prompt_set): | ||
return getattr(prompts, prompt_set) | ||
|
||
def _save_predictions(predictions, clean_preds, extraction_model, prompt_set, output_dir): | ||
short_model_name = extraction_model.split('/')[-1] | ||
outname = f"{prompt_set}_{short_model_name}" | ||
predictions_path = output_dir / f'{outname}.json' | ||
clean_predictions_path = output_dir / f'{outname}_clean.csv' | ||
|
||
json.dump(predictions, predictions_path.open('w')) | ||
|
||
clean_preds.to_csv( | ||
clean_predictions_path, index=False | ||
) | ||
|
||
def __main__(extraction_model, docs_path, prompt_set, output_dir=None, **kwargs): | ||
""" Run the participant demographics extraction pipeline. | ||
|
||
Args: | ||
extraction_model (str): The model to use for extraction. | ||
docs_path (str): The path to the csv file containing the documents. | ||
prompt_set (str): The prompt set to use for the extraction. | ||
output_dir (str): The directory to save the output files. | ||
**kwargs: Additional keyword arguments to pass to the extraction function. | ||
""" | ||
|
||
docs = pd.read_csv(docs_path) | ||
|
||
extraction_client = _load_client(extraction_model) | ||
|
||
prompt_config = _load_prompt_config(prompt_set) | ||
if kwargs is not None: | ||
prompt_config.update(kwargs) | ||
|
||
output_dir = Path(output_dir) | ||
|
||
predictions, clean_preds = extract( | ||
extraction_model, extraction_client, docs, | ||
**prompt_config | ||
) | ||
|
||
if output_dir is not None: | ||
_save_predictions(predictions, clean_preds, extraction_model, prompt_set, output_dir) | ||
|
||
return predictions, clean_preds | ||
|
||
|
||
def ParticipantDemographics(IndependentPipeline): | ||
class ParticipantDemographicsExtraction(IndependentPipeline): | ||
"""Participant demographics extraction pipeline.""" | ||
|
||
_version = "1.0.0" | ||
|
@@ -100,7 +55,8 @@ def ParticipantDemographics(IndependentPipeline): | |
def __init__( | ||
self, | ||
extraction_model, | ||
prompt_set, inputs=("text",), | ||
prompt_set, | ||
inputs=("text",), | ||
input_sources=("pubget", "ace"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I would do is add the key as part of the Later on, we could define a base class There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i.e. the For now just make the key a config parameter and rename the value of the key something else. |
||
**kwargs | ||
): | ||
|
@@ -117,15 +73,16 @@ def _run(self, study_inputs, n_cpus=1): | |
if self.kwargs is not None: | ||
prompt_config.update(self.kwargs) | ||
|
||
with open(study_inputs["text"]) as f: | ||
text = f.read() | ||
|
||
predictions, clean_preds = extract( | ||
self.extraction_model, | ||
extraction_client, | ||
study_inputs["text"], | ||
text, | ||
prompt_set=self.prompt_set, | ||
**prompt_config | ||
) | ||
|
||
# Save predictions | ||
|
||
return {"predictions": predictions, "clean_predictions": clean_preds} |
This file was deleted.
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder, given that we now have a class inside `run.py` with a `run()` method, if we should rename `run.py` to something else? `pipeline.py`?