Local Microservice for CheckList in Explainability api #436
base: explainability_api
@@ -0,0 +1,179 @@
# add functionality to process json to be injected into the db
import json
import logging
import requests
import itertools

from typing import List


def create_query(skill, test_cases: List):
    """
    Creates the queries and makes them suitable for sending to the skill for prediction.

    Args:
        skill: input skill for which the checklist tests are run
        test_cases (list): test cases as a list

    Returns:
        model_inputs (dict): the prediction requests plus the gold answers and the
        test metadata (test_type, capability, test_name) for each test case
    """
    skill_type = skill["skill_type"]
    base_model = skill["default_skill_args"].get("base_model")
    adapter = skill["default_skill_args"].get("adapter")
    # extract all tests
    all_tests = [tests["test_cases"] for tests in test_cases]
    questions, contexts, answers = list(), list(), list()

    test_type = list(itertools.chain.from_iterable([[test["test_type"]] * len(test["test_cases"])
Reviewer: I also like list comprehension in Python, but here you loop three times over the same list.

Author: Thank you! I sucked right there :D
                                                    for test in test_cases]))
    capability = list(itertools.chain.from_iterable([[test["capability"]] * len(test["test_cases"])
                                                    for test in test_cases]))
    test_name = list(itertools.chain.from_iterable([[test["test_name"]] * len(test["test_cases"])
                                                    for test in test_cases]))
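Following the reviewer's point about traversing `test_cases` three times, the same three lists can be built in a single pass. A sketch only (the helper name is hypothetical, not part of the PR):

```python
def collect_metadata(test_cases):
    """Gather per-case test_type, capability and test_name in one pass
    over test_cases, repeating each field once per contained test case."""
    test_type, capability, test_name = [], [], []
    for test in test_cases:
        n = len(test["test_cases"])
        test_type.extend([test["test_type"]] * n)
        capability.extend([test["capability"]] * n)
        test_name.extend([test["test_name"]] * n)
    return test_type, capability, test_name
```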

    for tests in all_tests:
        questions.append([query["question"] for query in tests])
        # list of list for mcq else list
        contexts.append([query["context"] if skill_type != "multiple-choice"
                         else query["context"] + "\n" + "\n".join(query["options"])
                         for query in tests])
        answers.extend([query.get("answer") if "answer" in query.keys() else query.get("prediction_before_change")
                        for query in tests])

    # TODO
    # send batch to the skill query endpoint

    prediction_requests = list()
    # create the prediction request
    for idx in range(len(questions)):
        for question, context in zip(questions[idx], contexts[idx]):
            request = dict()
            request["num_results"] = 1
            request["user_id"] = "ukp"
            request["skill_args"] = {"base_model": base_model, "adapter": adapter, "context": context}
            request["query"] = question
            prediction_requests.append(request)

    model_inputs = dict()
    model_inputs["request"] = prediction_requests
    model_inputs["answers"] = answers
    model_inputs["test_type"] = test_type
    model_inputs["capability"] = capability
    model_inputs["test_name"] = test_name
    # logger.info("inputs:", model_inputs)
Reviewer: remove commented code.

Reviewer: Use comments only when you feel that some part of the code is hard to understand and commenting will improve readability.

    return model_inputs


def predict(model_inputs: dict, skill_id: str) -> list:
    """
    Runs the skill on every prediction request.

    Args:
        model_inputs (dict): input for the model inference
        skill_id (str): id of the skill for which the predictions need to be run

    Returns:
        The model predictions, one record per test case with a success flag
    """
    model_outputs = list()
    try:
        headers = {'Content-type': 'application/json'}
        skill_query_url = f"https://square.ukp-lab.de/api/skill-manager/skill/{skill_id}/query"  # note: I hardcoded the SQuARE URL here
Reviewer: The url should come from the env variable.

Author: This part of the code was provided by Haritz. We did not ask him why he hardcoded the url, but I guess he has some reasons? @HaritzPuerto

Reviewer: He's on vacation right now. I think it is safe to use SQUARE_API_URL. The reason we want to make this configurable is to be able to deploy to different environments. BTW, the URL contains the protocol, so something like this should work:

    import os
    # ...
    skill_query_url = f"{os.environ['SQUARE_API_URL']}/skill-manager/skill/{skill_id}/query"

Author: Okay great, could you please change the code to that? :D Thanks
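A sketch of the configurable URL this thread converges on, with a fallback so local runs still work when the variable is unset (the fallback default and the helper name are assumptions, not part of the PR):

```python
import os


def skill_query_url(skill_id, env=os.environ):
    """Build the skill query URL from SQUARE_API_URL, falling back to the
    SQuARE host hardcoded in this PR when the variable is unset."""
    base = env.get("SQUARE_API_URL", "https://square.ukp-lab.de/api")
    return f"{base}/skill-manager/skill/{skill_id}/query"
```

Passing the environment mapping as a parameter keeps the helper testable without mutating `os.environ`.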

        model_predictions = list()
        # i = 0
        for request in model_inputs["request"]:
            response = requests.post(skill_query_url, data=json.dumps(request), headers=headers)
            predictions = response.json()
            model_predictions.append(predictions["predictions"][0]["prediction_output"]["output"])
            # i += 1
            # if i == 10:
            #     break

        # calculate success rate
        success_rate = [pred == gold for pred, gold in zip(model_predictions, model_inputs["answers"])]
        # overall_success_rate = sum(success_rate) / len(success_rate)

||
for test_type, capability, test_name, request, answer, prediction, success in zip( | ||
model_inputs["test_type"], | ||
model_inputs["capability"], | ||
model_inputs["test_name"], | ||
model_inputs["request"], | ||
model_inputs["answers"], | ||
model_predictions, | ||
success_rate | ||
): | ||
model_outputs.append( | ||
{ | ||
"skill_id": skill_id, | ||
"test_type": test_type, | ||
"capability": capability, | ||
"test_name": test_name, | ||
"question": request["query"], | ||
"context": request["skill_args"]["context"], | ||
"answer": answer, | ||
"prediction": prediction, | ||
"success": success | ||
} | ||
) | ||
# print(model_outputs) | ||
except Exception as ex: | ||
logging.info(ex) | ||
return model_outputs | ||
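The `# TODO: send batch to the skill query endpoint` note above flags the sequential request loop. If the endpoint only accepts single queries, one interim option is to issue them concurrently; a sketch, where `post_fn` is a hypothetical wrapper around the `requests.post` call above:

```python
from concurrent.futures import ThreadPoolExecutor


def predict_concurrently(prediction_requests, post_fn, max_workers=8):
    """Run post_fn over all prediction requests in a thread pool,
    preserving input order in the returned list."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(post_fn, prediction_requests))
```

`ThreadPoolExecutor.map` keeps results in submission order, so the later `zip` against answers and metadata would still line up.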


def test_name_analysis(model_outputs):
    types_of_tests = {}
    for element in list(set([result["test_type"] for result in model_outputs])):
        types_of_tests[element] = dict()
    for test in types_of_tests.keys():
        test_names = list(set([result["test_name"] for result in model_outputs if result["test_type"] == test]))
        for name in test_names:
            successful = 0
            failure = 0
            for result in model_outputs:
                if result["test_name"] == name:
                    if result["success"]:
                        successful += 1
                    else:
                        failure += 1
            types_of_tests[test][name] = dict({"successful": successful, "failure": failure})
    return [types_of_tests]


def capability_analysis(model_outputs):
    types_of_tests = {}
    for element in list(set([result["test_type"] for result in model_outputs])):
        types_of_tests[element] = dict()
    for test in types_of_tests.keys():
        test_capabilities = list(set([result["capability"] for result in model_outputs if result["test_type"] == test]))
        for cap in test_capabilities:
            successful = 0
            failure = 0
            for result in model_outputs:
                if result["capability"] == cap:
                    if result["success"]:
                        successful += 1
                    else:
                        failure += 1
            types_of_tests[test][cap] = dict({"successful": successful, "failure": failure})
    return [types_of_tests]


def test_type_analysis(model_outputs):
    types_of_tests = {}
    for element in list(set([result["test_type"] for result in model_outputs])):
        successful = 0
        failure = 0
        for result in model_outputs:
            if result['test_type'] == element:
                if result['success']:
                    successful += 1
                else:
                    failure += 1
        types_of_tests[element] = dict({"successful": successful, "failure": failure})
    return [types_of_tests]
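`test_name_analysis` and `capability_analysis` differ only in the field they group by, and each scans `model_outputs` once per distinct value. A single-pass generalisation is sketched below (a refactoring suggestion, not part of the PR; note it counts a name only within its own test_type, whereas the loops above match the name across all test types):

```python
from collections import defaultdict


def analyse_by(model_outputs, key):
    """Group results by test_type, then by the given key ("test_name" or
    "capability"), counting successes and failures in one pass."""
    counts = defaultdict(lambda: defaultdict(lambda: {"successful": 0, "failure": 0}))
    for result in model_outputs:
        bucket = counts[result["test_type"]][result[key]]
        bucket["successful" if result["success"] else "failure"] += 1
    # match the [ {test_type: {key: counts}} ] shape of the functions above
    return [{test_type: dict(inner) for test_type, inner in counts.items()}]
```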
@@ -0,0 +1,78 @@
import logging

import requests
import json
from fastapi import APIRouter

import checklist

router = APIRouter()


@router.get('/')
Reviewer: remove this endpoint.
def read_root():
    return {"Hello": "World"}


@router.get('/checklist/{skill_id}', name="run checklist")
def run_checklist(skill_id: str, checklist_name: str, n_tests: int = None, test_analysis: str = None) -> list:
    """
    Runs a CheckList test suite against a skill.

    :param skill_id: skill id
    :param checklist_name: name of the checklist
    :param n_tests: run only the first n_tests test cases
    :param test_analysis: how to aggregate the CheckList results: 'test_type',
        'capability', or 'test_name' (None returns the raw outputs)
    :return: output
    """
    # tested with this skill_id: 63cdbd06a8b0d566ef20cb54 - although performance is poor
    assert checklist_name is not None

    checklist_path_dict = {
        'extractive': '../checklists/extractive_model_tests.json',
        'boolean': '../checklists/boolean_model_tests.json',
        'abstractive': '../checklists/abstractive_models_tests.json',
        'multiple_choice': '../checklists/multiple_choice_model_tests.json',
        'open_domain': '../checklists/open_domain_models_tests.json',
        'open_domain_bioasq': '../checklists/open_domain_models_bioasq_tests.json'
    }

    checklist_path = checklist_path_dict[checklist_name]
    with open(checklist_path) as f:
        checklist_tests = json.load(f)

    try:
        skill_response = requests.get(f'https://square.ukp-lab.de/api/skill-manager/skill/{skill_id}')
        skill = skill_response.json()
        skill_id = skill["id"]
        # skill_type = skill["skill_type"]

        test_cases = checklist_tests['tests']
        model_inputs = checklist.create_query(skill, test_cases)

        if n_tests is not None:
            model_inputs['request'] = model_inputs["request"][:n_tests]  # if all would be too much
        model_outputs = checklist.predict(model_inputs, skill_id)

        if test_analysis is None:
            output_return = model_outputs
        # Analysis result
        else:
            if test_analysis == 'test_type':
                output_return = checklist.test_type_analysis(model_outputs)
            elif test_analysis == 'capability':
                output_return = checklist.capability_analysis(model_outputs)
            elif test_analysis == 'test_name':
                output_return = checklist.test_name_analysis(model_outputs)

        # assert output_return is not list

        # saves output as json
        with open('temp_result/temp_result.json', 'w') as f:
            json.dump(output_return, f, indent=4)

        return output_return

    except Exception as e:
        print(e)
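For local testing of the endpoint, a small client-side sketch (the helper and its defaults are hypothetical; the actual call would be `requests.get(url, params=params)` against wherever uvicorn is serving):

```python
def build_checklist_request(base_url, skill_id, checklist_name,
                            n_tests=None, test_analysis=None):
    """Assemble the URL and query parameters for GET /checklist/{skill_id},
    omitting the optional parameters when they are not set."""
    params = {"checklist_name": checklist_name}
    if n_tests is not None:
        params["n_tests"] = n_tests
    if test_analysis is not None:
        params["test_analysis"] = test_analysis
    return f"{base_url}/checklist/{skill_id}", params
```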
@@ -0,0 +1,11 @@
from fastapi import FastAPI

import checklist_api

app = FastAPI()

app.include_router(checklist_api.router)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", reload=True, timeout_keep_alive=200)  # for dev purposes
@@ -0,0 +1,44 @@
Reviewer: do we need this file checked-in?

Author: I think so, or else saving the result will throw an error.

Reviewer: When using something like:

    from pathlib import Path
    output_file = Path(os.environ["OUTPUT_DIR"]) / "result.json"
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_text(json.dumps(result))

the output directory is created on demand. How does the service further need this result?

Author: I think you can change it however you think best fits. I'm thinking maybe you can simply remove the parts about a new directory and just open a new json file. Haritz just told us to save the results locally somewhere so we can use it to analyse it locally. So I just saved it and returned it through the api. I'm so sorry if I'm asking too much of you :'(

[
    {
        "MFT": {
            "Temporal": {
                "successful": 0,
                "failure": 30
            },
            "Coref": {
                "successful": 0,
                "failure": 80
            },
            "Negation": {
                "successful": 0,
                "failure": 40
            },
            "Fairness": {
                "successful": 0,
                "failure": 10
            },
            "SRL": {
                "successful": 0,
                "failure": 100
            },
            "Taxonomy": {
                "successful": 0,
                "failure": 242
            },
            "Vocabulary": {
                "successful": 0,
                "failure": 70
            }
        },
        "INV": {
            "Robustness": {
                "successful": 0,
                "failure": 20
            },
            "NER": {
                "successful": 0,
                "failure": 100
            }
        }
    }
]
Reviewer: remove commented code.