From c4b984e14186f954a08a38a4d844938e578044cf Mon Sep 17 00:00:00 2001
From: mamu
Date: Wed, 22 May 2024 14:15:16 +0900
Subject: [PATCH] Work evo merge on load

---
 package/samplers/evo_merge/example.py | 11 +++++------
 package/samplers/evo_merge/sampler.py |  2 +-
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/package/samplers/evo_merge/example.py b/package/samplers/evo_merge/example.py
index e98163df..6819c054 100644
--- a/package/samplers/evo_merge/example.py
+++ b/package/samplers/evo_merge/example.py
@@ -10,13 +10,12 @@
 from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
 
 import optuna
+import optunahub
 
-from package.samplers.evo_merge.sampler import EvoMergeSampler
-from package.samplers.evo_merge.trial import EvoMergeTrial
-
-# EvoMergeSampler = optunahub.load_module("samplers/evo_merge").EvoMergeSampler
-# EvoMergeTrial = optunahub.load_module("samplers/evo_merge").EvoMergeTrial
+module = optunahub.load_module("samplers/evo_merge")
+EvoMergeSampler = module.EvoMergeSampler
+EvoMergeTrial = module.EvoMergeTrial
 
 
 TEMPLATE = "質問に答えなさい。質問: {question} 回答: "
@@ -39,7 +38,7 @@ def eval_jaqket(llm_chain: LLMChain) -> int:
         out = llm_chain.run(question=problem["question"])
         if len(out.strip()) != 0:
             out = out.strip().split()[0].strip()
-        if problem["answer_number"] in out:
+        if str(problem["answer_number"]) in out:
             correct += 1
     return correct
 
diff --git a/package/samplers/evo_merge/sampler.py b/package/samplers/evo_merge/sampler.py
index 156c1876..8e3a580f 100644
--- a/package/samplers/evo_merge/sampler.py
+++ b/package/samplers/evo_merge/sampler.py
@@ -106,7 +106,7 @@ def sample_model(self, study: Study, trial: EvoMergeTrial) -> BaseLLM:
 def load_model(model_id: str) -> BaseLLM:
     bnbconf = BitsAndBytesConfig(load_in_4bit=True)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.frm_pretrained(model_id, quantization_config=bnbconf)
+    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnbconf)
     llm = HuggingFacePipeline(
         pipeline=pipeline(
             "text-generation",
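
Note on the loading change: the example now resolves EvoMergeSampler and
EvoMergeTrial through the OptunaHub registry instead of a repository-relative
import. A minimal sketch of that pattern, assuming optunahub is installed and
the samplers/evo_merge package is published on OptunaHub:

    import optunahub

    # load_module fetches the registry package and returns it as a module
    # object whose public attributes (here, the sampler and trial classes)
    # can be bound locally, just as the patched example does.
    module = optunahub.load_module("samplers/evo_merge")
    EvoMergeSampler = module.EvoMergeSampler
    EvoMergeTrial = module.EvoMergeTrial

The str() coercion in eval_jaqket matters because answer_number is presumably
an integer in the JAQKET records: with a str on the right-hand side, Python's
"x in out" raises TypeError when x is an int, so the membership test only
works once the number is converted to a string.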