Welcome to the InterroLang wiki!
This is an overview of how to integrate another desired dataset into TalkToModel (TTM), using BoolQ as an example.
```
git clone git@github.com:dylan-slack/TalkToModel.git
```
Get into the TTM directory and run these commands:
```
conda create -n ttm python=3.9
conda activate ttm
```
Then you should install all dependencies:
```
pip install -r requirements.txt
pip install datasets
```
First, put the configuration file into the /configs folder as /configs/boolq.gin:
```
##########################################
# The boolq dataset conversation config
##########################################

# for few-shot, e.g., "EleutherAI/gpt-neo-2.7B"
ExplainBot.parsing_model_name = "EleutherAI/gpt-neo-2.7B"

# set skip_prompts to true for quicker startup with fine-tuned models;
# make sure to set it to false when using few-shot models
ExplainBot.skip_prompts = False

ExplainBot.t5_config = "./parsing/t5/gin_configs/t5-large.gin"
ExplainBot.seed = 0

ExplainBot.background_dataset_file_path = "./data/boolq_train.csv"
ExplainBot.model_file_path = "./data/boolq_model"
ExplainBot.dataset_file_path = "./data/boolq_validation.csv"

ExplainBot.name = "boolq"
ExplainBot.dataset_index_column = "idx"
ExplainBot.target_variable_name = "label"
ExplainBot.categorical_features = None
ExplainBot.numerical_features = None
ExplainBot.remove_underscores = False
ExplainBot.prompt_metric = "cosine"
ExplainBot.prompt_ordering = "ascending"

# Prompt params
Prompts.prompt_cache_size = 1_000_000
Prompts.prompt_cache_location = "./cache/boolq-prompts.pkl"
Prompts.max_values_per_feature = 2
Prompts.sentence_transformer_model_name = "all-mpnet-base-v2"
Prompts.prompt_folder = "./explain/prompts"
Prompts.num_per_knn_prompt_template = 1
Prompts.num_prompt_template = 7

# Explanation params
Explanation.max_cache_size = 1_000_000

# MegaExplainer params
MegaExplainer.cache_location = "./cache/boolq-mega-explainer-tabular.pkl"
MegaExplainer.use_selection = False

# Conversation params
Conversation.class_names = {1: "True", 0: "False"}

# Dataset description
DatasetDescription.dataset_objective = "predict the answer to yes/no questions based on text passages"
DatasetDescription.dataset_description = "Boolean question answering (yes/no)"
DatasetDescription.model_description = "DistilBERT"
```
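Optionally, you can check that the new config parses before wiring it into TTM. A minimal sketch, assuming the gin-config package from requirements.txt is installed; `skip_unknown=True` tolerates configurables that are only registered once the TTM modules are imported:
```python
import gin

# Parse the new config to catch syntax errors early. skip_unknown=True
# skips bindings for configurables (ExplainBot, Prompts, ...) that are
# only registered inside the TTM codebase.
gin.parse_config_file("./configs/boolq.gin", skip_unknown=True)
print("boolq.gin parsed without errors")
```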
Then change the global config in global_config.gin:
```
GlobalArgs.config = "./configs/boolq.gin"
```
Then you should export the datasets as CSV files:
```python
from datasets import load_dataset

load_dataset("super_glue", "boolq", split="validation").to_csv("data/boolq_validation.csv")
load_dataset("super_glue", "boolq", split="train").to_csv("data/boolq_train.csv")
```
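To verify the export, you can inspect the CSVs; a quick sketch assuming pandas is available (SuperGLUE BoolQ rows carry question, passage, idx, and label columns):
```python
import pandas as pd

# Quick sanity check: the config expects an "idx" index column and a
# "label" target column in both CSVs.
val = pd.read_csv("data/boolq_validation.csv")
print(val.columns.tolist())       # expect: ['question', 'passage', 'idx', 'label']
print(val[["idx", "label"]].head())
```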
Additionally, you should download the model from https://huggingface.co/andi611/distilbert-base-uncased-qa-boolq/tree/main and put it under /data as ./data/boolq_model (the path that ExplainBot.model_file_path in the config points to).
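Instead of downloading the files by hand, a minimal sketch using huggingface_hub (assuming it is installed, e.g. via pip install huggingface_hub):
```python
from huggingface_hub import snapshot_download

# Download the fine-tuned BoolQ classifier into the directory the
# gin config points at (ExplainBot.model_file_path).
snapshot_download(
    repo_id="andi611/distilbert-base-uncased-qa-boolq",
    local_dir="./data/boolq_model",
)
```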
In /explain/logic.py:
- Add load_hf_model() (a hypothetical sketch of the TransformerModel wrapper follows this list):
```python
@gin.configurable
def load_hf_model(model_id):
    """Loads a (local) Hugging Face model from a directory containing
    a pytorch_model.bin file and a config.json file."""
    return TransformerModel(model_id)
    # Alternatively: transformers.AutoModel.from_pretrained(model_id)
```
- Comment out the call to load_explanations:
```python
# Load the explanations
# self.load_explanations(background_dataset=background_dataset)
```
- Change the else branch of load_model():
```python
else:
    model = load_hf_model(filepath)
    self.conversation.add_var('model', model, 'model')
```
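For reference, here is a hypothetical sketch of what the TransformerModel wrapper used above could look like; the actual class in the InterroLang codebase may differ:
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class TransformerModel:
    """Hypothetical wrapper: loads a sequence-classification checkpoint
    and exposes a simple predict() interface."""

    def __init__(self, model_id):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
        self.model.eval()

    def predict(self, texts):
        # Tokenize a batch of strings and return predicted class ids.
        inputs = self.tokenizer(texts, padding=True, truncation=True,
                                return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        return logits.argmax(dim=-1).tolist()
```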
Finally, start the app:
```
python flask_app.py
```