[ENH] add pipeline configuration/structure #3

Merged (44 commits) on Nov 25, 2024
Commits
6fe2bfb  adding testing (jdkent, Oct 22, 2024)
bbbfc9d  add dependent pipeline (jdkent, Oct 24, 2024)
60c0978  mark pipeline as (in)dependent (jdkent, Oct 24, 2024)
89d45f3  wip: start modifying the existing pipeline (jdkent, Oct 24, 2024)
5e8bafc  merge in new changes (jdkent, Oct 24, 2024)
90bee39  Restructure package (adelavega, Oct 30, 2024)
cb16fc9  add filter_inputs function (jdkent, Oct 30, 2024)
58bc727  Refactor init logic to dataclasses (adelavega, Oct 30, 2024)
cfff8bb  Both group and independent can use the same function name ('function'… (adelavega, Oct 30, 2024)
79bfdcc  group_function to function (adelavega, Oct 30, 2024)
0625124  _hash_attrs instead (adelavega, Oct 30, 2024)
c26ef1b  Set default _hash_attrs (adelavega, Oct 30, 2024)
e907624  refactor based on feedback (jdkent, Oct 31, 2024)
85b20dd  add pipeline name to output path (jdkent, Oct 31, 2024)
ddd5b67  wip: modify readme (jdkent, Oct 31, 2024)
23bb537  fix merge (jdkent, Oct 31, 2024)
c99c83f  add tests dependencies (jdkent, Oct 31, 2024)
a78f241  add test for participant demographics (jdkent, Nov 14, 2024)
cdbdec2  opensource data (jdkent, Nov 15, 2024)
d1e2a31  remove old functions (jdkent, Nov 15, 2024)
cd5bb83  commit the cassette (jdkent, Nov 15, 2024)
99095fb  add dependencies (jdkent, Nov 15, 2024)
ef0b25d  allow installable pyproject (jdkent, Nov 15, 2024)
0839a6e  move test directory and remove top level __init__ (jdkent, Nov 15, 2024)
585fc21  try underscores (jdkent, Nov 15, 2024)
5cdfc6d  Revert "allow installable pyproject" (jdkent, Nov 15, 2024)
306e9ec  Revert "Revert "allow installable pyproject"" (jdkent, Nov 15, 2024)
693cb76  Revert "try underscores" (jdkent, Nov 15, 2024)
c3b5767  Revert "move test directory and remove top level __init__" (jdkent, Nov 15, 2024)
8e7152f  remove init (jdkent, Nov 15, 2024)
20af580  remove old files (jdkent, Nov 15, 2024)
5cff6be  switch to version 5 (jdkent, Nov 15, 2024)
194e9b1  use editable install (jdkent, Nov 15, 2024)
6f45fba  trigger variable (jdkent, Nov 15, 2024)
b6e26b0  add fake key (jdkent, Nov 15, 2024)
08e534a  Update ns_pipelines/word_count/run.py (jdkent, Nov 16, 2024)
e8108fd  Update ns_pipelines/participant_demographics/run.py (jdkent, Nov 16, 2024)
c366e61  Update ns_pipelines/word_count/run.py (jdkent, Nov 18, 2024)
e1fcd2b  Update ns_pipelines/word_count/run.py (jdkent, Nov 18, 2024)
01a70f0  Update ns_pipelines/participant_demographics/run.py (jdkent, Nov 18, 2024)
ce537b8  Update ns_pipelines/word_count/run.py (jdkent, Nov 19, 2024)
44ad3c6  Update ns_pipelines/word_count/run.py (jdkent, Nov 19, 2024)
35c09aa  change the names (jdkent, Nov 21, 2024)
8c5237f  work with .keys file (jdkent, Nov 22, 2024)
Files changed (changes from 18 commits)
6 changes: 4 additions & 2 deletions .github/workflows/test.yml

@@ -21,12 +21,14 @@ jobs:
       uses: actions/checkout@v2

     - name: Set up Python
-      uses: actions/setup-python@v4
+      uses: actions/setup-python@v5
       with:
         python-version: '3.8'

     - name: Install dependencies
-      run: pip install .[tests]
+      run: pip install -e .[tests,participant_demographics,word_count]

     - name: Test with pytest
+      env:
+        OPENAI_API_KEY: "fake_key"
       run: pytest
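Why a dummy key is enough here: the tests extra (see the pyproject.toml diff below) adds pytest-recording and vcrpy, and the "commit the cassette" commit stores recorded OpenAI responses, so CI replays those recordings instead of making live API calls. A minimal sketch of a cassette-backed test under that assumption; the test name and body are hypothetical, not taken from this PR:

```python
import os

import pytest


@pytest.mark.vcr()  # pytest-recording marker: replay the committed cassette instead of calling OpenAI
def test_extraction_runs_offline():
    # In CI the key is the dummy value set above; a live request would fail
    # authentication, so a passing test implies the cassette was replayed.
    assert os.getenv("OPENAI_API_KEY") == "fake_key"
```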
Empty file removed: __init__.py
3 changes: 0 additions & 3 deletions ns_pipelines/participant_demographics/__init__.py

@@ -1,3 +0,0 @@
-from .run import __main__ as run
-
-__all__ = ['run']
47 changes: 23 additions & 24 deletions ns_pipelines/participant_demographics/clean.py

@@ -2,49 +2,48 @@
 import numpy as np


-def clean_predictions(predictions):
-    # Clean known issues with GPT demographics predictions
-    predictions = [p for p in predictions if p and "groups" in p]
+def clean_prediction(prediction):
+    # Clean known issues with GPT demographics prediction

     meta_keys = ["pmid", "rank", "start_char", "end_char", "id"]
-    meta_keys = [k for k in meta_keys if k in predictions[0]]
+    meta_keys = [k for k in meta_keys if k in prediction]

     # Convert JSON to DataFrame
-    predictions = pd.json_normalize(
-        predictions, record_path=["groups"],
+    prediction = pd.json_normalize(
+        prediction, record_path=["groups"],
         meta=meta_keys
     )

-    predictions.columns = predictions.columns.str.replace(' ', '_')
+    prediction.columns = prediction.columns.str.replace(' ', '_')

-    predictions = predictions.fillna(value=np.nan)
-    predictions["group_name"] = predictions["group_name"].fillna("healthy")
+    prediction = prediction.fillna(value=np.nan)
+    prediction["group_name"] = prediction["group_name"].fillna("healthy")

     # Drop rows where count is NA
-    predictions = predictions[~pd.isna(predictions["count"])]
+    prediction = prediction[~pd.isna(prediction["count"])]

     # Set group_name to healthy if no diagnosis
-    predictions.loc[
-        (predictions["group_name"] != "healthy") & (pd.isna(predictions["diagnosis"])),
+    prediction.loc[
+        (prediction["group_name"] != "healthy") & (pd.isna(prediction["diagnosis"])),
         "group_name",
     ] = "healthy"

     # If no male count, substract count from female count columns
-    ix_male_miss = (pd.isna(predictions["male_count"])) & ~(
-        pd.isna(predictions["female_count"])
+    ix_male_miss = (pd.isna(prediction["male_count"])) & ~(
+        pd.isna(prediction["female_count"])
     )
-    predictions.loc[ix_male_miss, "male_count"] = (
-        predictions.loc[ix_male_miss, "count"]
-        - predictions.loc[ix_male_miss, "female_count"]
+    prediction.loc[ix_male_miss, "male_count"] = (
+        prediction.loc[ix_male_miss, "count"]
+        - prediction.loc[ix_male_miss, "female_count"]
     )

     # Same for female count
-    ix_female_miss = (pd.isna(predictions["female_count"])) & ~(
-        pd.isna(predictions["male_count"])
+    ix_female_miss = (pd.isna(prediction["female_count"])) & ~(
+        pd.isna(prediction["male_count"])
     )
-    predictions.loc[ix_female_miss, "female_count"] = (
-        predictions.loc[ix_female_miss, "count"]
-        - predictions.loc[ix_female_miss, "male_count"]
+    prediction.loc[ix_female_miss, "female_count"] = (
+        prediction.loc[ix_female_miss, "count"]
+        - prediction.loc[ix_female_miss, "male_count"]
     )

-    return predictions
+    return {"groups": prediction.to_dict(orient="records")}
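For orientation, a small sketch of how the refactored clean_prediction might be called on a single document's prediction; the input field names are inferred from the columns the function touches and from record_path=["groups"], so treat the exact schema as an assumption:

```python
from ns_pipelines.participant_demographics.clean import clean_prediction

# Hypothetical single-document prediction; keys mirror the columns used above.
prediction = {
    "pmid": 12345,
    "groups": [
        {"group_name": "patients", "diagnosis": "ADHD", "count": 30,
         "male_count": None, "female_count": 14},
        {"group_name": None, "diagnosis": None, "count": 28,
         "male_count": 13, "female_count": 15},
    ],
}

cleaned = clean_prediction(prediction)
# cleaned == {"groups": [...]}: one record per group, with the missing male_count
# filled in as count - female_count and the unnamed group labelled "healthy".
```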
81 changes: 19 additions & 62 deletions ns_pipelines/participant_demographics/run.py
@@ -1,96 +1,51 @@
""" Extract participant demographics from HTML files. """
Review comment (Member): I wonder, given that we now have a class inside run.py with a run() method, whether we should rename run.py to something else? pipeline.py?

import os

from publang.extract import extract_from_text
from openai import OpenAI
from pathlib import Path
import json
import pandas as pd
import logging

from . import prompts
from .clean import clean_predictions
from .clean import clean_prediction

from ns_pipelines.pipeline import IndependentPipeline

def extract(extraction_model, extraction_client, docs, output_dir, prompt_set='', **extract_kwargs):

def extract(extraction_model, extraction_client, text, prompt_set='', **extract_kwargs):
extract_kwargs.pop('search_query', None)

# Extract
predictions = extract_from_text(
docs['body'].to_list(),
model=extraction_model, client=extraction_client,
text,
model=extraction_model,
client=extraction_client,
**extract_kwargs
)

# Add PMCID to predictions
for i, pred in enumerate(predictions):
if not pred:
logging.warning(f"No prediction for document {docs['pmid'].iloc[i]}")
continue
pred['pmid'] = int(docs['pmid'].iloc[i])
if not predictions:
logging.warning("No predictions found.")
return None, None

clean_preds = clean_predictions(predictions)
clean_preds = clean_prediction(predictions)

return predictions, clean_preds


def _load_client(model_name):
if 'gpt' in model_name:
client = OpenAI(api_key=os.getenv('MYOPENAI_API_KEY'))
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
Review comment (Member): The reason I had it this way was that if the environment variable is named OPENAI_API_KEY it is automatically picked up by the OpenAI client, so passing it explicitly is not necessary. I wanted the option of not passing that key to OpenAI, specifically for when you want to use the OpenAI client with another API (such as OpenRouter). So what I would do is add which API key to use as a configuration parameter, and in the production environment name it something other than OPENAI_API_KEY.

else:
Review comment (Member): In principle we can run this using other API keys and hence other model names, so perhaps let's not worry about validation here.

raise ValueError(f"Model {model_name} not supported")

return client


def _load_prompt_config(prompt_set):
return getattr(prompts, prompt_set)

def _save_predictions(predictions, clean_preds, extraction_model, prompt_set, output_dir):
short_model_name = extraction_model.split('/')[-1]
outname = f"{prompt_set}_{short_model_name}"
predictions_path = output_dir / f'{outname}.json'
clean_predictions_path = output_dir / f'{outname}_clean.csv'

json.dump(predictions, predictions_path.open('w'))

clean_preds.to_csv(
clean_predictions_path, index=False
)

def __main__(extraction_model, docs_path, prompt_set, output_dir=None, **kwargs):
""" Run the participant demographics extraction pipeline.

Args:
extraction_model (str): The model to use for extraction.
docs_path (str): The path to the csv file containing the documents.
prompt_set (str): The prompt set to use for the extraction.
output_dir (str): The directory to save the output files.
**kwargs: Additional keyword arguments to pass to the extraction function.
"""

docs = pd.read_csv(docs_path)

extraction_client = _load_client(extraction_model)

prompt_config = _load_prompt_config(prompt_set)
if kwargs is not None:
prompt_config.update(kwargs)

output_dir = Path(output_dir)

predictions, clean_preds = extract(
extraction_model, extraction_client, docs,
**prompt_config
)

if output_dir is not None:
_save_predictions(predictions, clean_preds, extraction_model, prompt_set, output_dir)

return predictions, clean_preds


def ParticipantDemographics(IndependentPipeline):
class ParticipantDemographicsExtraction(IndependentPipeline):
"""Participant demographics extraction pipeline."""

_version = "1.0.0"
@@ -100,7 +55,8 @@ def ParticipantDemographics(IndependentPipeline):
def __init__(
self,
extraction_model,
prompt_set, inputs=("text",),
prompt_set,
inputs=("text",),
input_sources=("pubget", "ace"),
Review comment (Member): What I would do is add the key as part of the __init__. Later on, we could define a base class OpenAIPipeline that sets up the client for the subclass automatically and knows which parameters to expect.

Review comment (Member): i.e. the _load_client function could be part of this new parent class and is always called. For now it's fine though, we can cross that bridge later. For now just make the key a config parameter and rename the value of the key to something else.

**kwargs
):
@@ -117,15 +73,16 @@ def _run(self, study_inputs, n_cpus=1):
if self.kwargs is not None:
prompt_config.update(self.kwargs)

with open(study_inputs["text"]) as f:
text = f.read()

predictions, clean_preds = extract(
self.extraction_model,
extraction_client,
study_inputs["text"],
text,
prompt_set=self.prompt_set,
**prompt_config
)

# Save predictions

return {"predictions": predictions, "clean_predictions": clean_preds}
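To make the new structure concrete, here is a hedged sketch of how ParticipantDemographicsExtraction might be instantiated and run, folding in the review suggestion above to make the API-key environment-variable name a configuration parameter. The api_key_env_var argument and the concrete values are illustrative assumptions, not part of this diff; only extraction_model, prompt_set, inputs, and input_sources appear in the signature shown here.

```python
from ns_pipelines.participant_demographics.run import ParticipantDemographicsExtraction

pipeline = ParticipantDemographicsExtraction(
    extraction_model="gpt-4o-mini",        # assumed model name; must contain "gpt" for _load_client
    prompt_set="base_prompt",              # assumed attribute defined in prompts.py
    inputs=("text",),
    input_sources=("pubget", "ace"),
    # api_key_env_var="NS_OPENAI_KEY",     # reviewer suggestion: make the key name configurable
)

# IndependentPipeline subclasses implement _run(study_inputs, n_cpus=1); study_inputs["text"]
# is a path to one study's extracted text (normally supplied by the pipeline runner).
results = pipeline._run({"text": "/path/to/study/text.txt"})
# results == {"predictions": ..., "clean_predictions": {"groups": [...]}}
```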
9 changes: 7 additions & 2 deletions pyproject.toml

@@ -22,8 +22,8 @@ participant_demographics = [
     "pandas",
     "numpy",
     "pydantic",
-    "publang",
-    "openai",
+    "publang @ git+https://github.com/adelavega/publang.git",
+    "openai"
 ]
 umls_disease = [
     "pandas",
@@ -39,10 +39,15 @@ word_count = [

 tests = [
     "pytest",
+    "pytest-recording",
+    "vcrpy",
 ]

 [tool.hatch.version]
 source = "vcs"

+[tool.hatch.build.hooks.vcs]
+version-file = "ns_pipelines/_version.py"
+
 [tool.hatch.metadata]
 allow-direct-references = true
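With source = "vcs" and the new build hook, the package version is derived from git metadata at build or install time and written to ns_pipelines/_version.py. A minimal sketch of how that generated file would typically be consumed; the attribute name follows hatch-vcs/setuptools-scm conventions and is an assumption, since the generated file itself is not part of this diff:

```python
# ns_pipelines/_version.py is generated by the hatch-vcs build hook at install time.
from ns_pipelines._version import __version__

print(__version__)  # e.g. "0.0.1.dev44+g8c5237f" for an untagged checkout
```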
2 changes: 0 additions & 2 deletions requirements.txt

This file was deleted.

16 changes: 0 additions & 16 deletions setup.py.old

This file was deleted.
