feat: end of string optimization
ganler committed May 2, 2024
1 parent b07be70 commit f0395c0
Showing 10 changed files with 67 additions and 16 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -38,7 +38,7 @@ pip install -r requirements.txt
## 🏁 Search Needle Function (SNF)

Search Needle Function is the first RepoQA task, which exercises LLMs' ability in **long-context code understanding and retrieval**.
Its corresponding real-life application is to perform precise code search from user intent rather than simple keyword match.

> [!Important]
>
@@ -47,7 +47,7 @@ Its corresponding real-life application is to perform precise code search from u
> 2. An NL description of the needle function without revealing keywords like function names
> 3. An instruction to retrieve the described function
>
> The evaluator passes a test if the searched function is syntactically closest to the ground-truth compared against
> other functions (systematically parsed by `treesitter`) and the similarity is greater than a user-defined threshold (0.8 by default).
You can run the SNF evaluation using various backends.
@@ -56,7 +56,7 @@ You can run the SNF evaluation using various backends.
>
> All evaluations can be performed in just one command.
>
-> As a reference of evaluation time, it takes 30 minutes to evaluate a 7B model using two A6000s.
+> As a reference of evaluation time, it takes one hour to evaluate a 7B model using two A6000s.
### OpenAI Compatible Servers

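The README excerpt above describes the SNF pass criterion: the retrieved function must be syntactically closest to the ground truth among all `treesitter`-parsed functions, and the similarity must exceed the threshold (0.8 by default). Below is a minimal sketch of that acceptance rule, assuming candidate function bodies are already extracted as strings and using `difflib` as a stand-in similarity metric; names such as `is_retrieval_correct` are illustrative, not RepoQA's actual implementation.

```python
# Sketch of the SNF acceptance rule described above (illustrative only).
from difflib import SequenceMatcher
from typing import List


def similarity(a: str, b: str) -> float:
    # Character-level similarity in [0, 1]; stand-in for RepoQA's metric.
    return SequenceMatcher(None, a, b).ratio()


def is_retrieval_correct(
    retrieved: str,
    ground_truth: str,
    other_functions: List[str],
    threshold: float = 0.8,
) -> bool:
    # Pass only if the retrieved snippet is more similar to the ground truth
    # than to any other parsed function, and the similarity clears the threshold.
    score = similarity(retrieved, ground_truth)
    best_other = max(
        (similarity(retrieved, fn) for fn in other_functions), default=0.0
    )
    return score >= threshold and score >= best_other
```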
2 changes: 1 addition & 1 deletion repoqa/provider/anthropic.py
@@ -17,7 +17,7 @@ def __init__(self, model):
self.client = Client(api_key=os.getenv("ANTHROPIC_KEY"))

def generate_reply(
-self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
+self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
) -> List[str]:
assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
replies = []
2 changes: 1 addition & 1 deletion repoqa/provider/base.py
@@ -9,6 +9,6 @@
class BaseProvider(ABC):
@abstractmethod
def generate_reply(
-self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
+self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
) -> List[str]:
...
2 changes: 1 addition & 1 deletion repoqa/provider/google.py
@@ -18,7 +18,7 @@ def __init__(self, model):
self.client = genai.GenerativeModel(model)

def generate_reply(
-self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
+self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
) -> List[str]:
assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
replies = make_auto_request(
28 changes: 24 additions & 4 deletions repoqa/provider/hf.py
@@ -5,10 +5,11 @@
from typing import List

import torch
from stop_sequencer import StopSequencer
from transformers import AutoModelForCausalLM, AutoTokenizer

from repoqa.provider.base import BaseProvider
-from repoqa.provider.request import construct_message_list
+from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq


class HfProvider(BaseProvider):
@@ -17,16 +18,35 @@ def __init__(self, model, trust_remote_code=False):
self.hf_model = AutoModelForCausalLM.from_pretrained(
model, trust_remote_code=trust_remote_code
).cuda()
self.stop_sequencer = StopSequencer(
model,
model_type="causal", # or seq2seq
tokenizer=self.tokenizer,
)
self.stop_seq = []
if self.tokenizer.chat_template:
self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer))

@torch.inference_mode()
def generate_reply(
-self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
+self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
) -> List[str]:
assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"

prompt_tokens = self.tokenizer.apply_chat_template(
-construct_message_list(question, system_msg), return_tensors="pt"
+construct_message_list(question, system_msg),
+return_tensors="pt",
+add_generation_prompt=True,
).cuda()
output_text = self.hf_model.generate(

model = self.hf_model
if self.stop_seq:
model = self.stop_sequencer.register_stop_texts(
stop_texts=self.stop_seq,
input_length=prompt_tokens.size(-1),
)

output_text = model.generate(
input_ids=prompt_tokens,
max_new_tokens=max_tokens,
num_return_sequences=n,
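The hf.py change above implements the commit's "end of string optimization" for local HuggingFace models: `stop-sequencer` wraps the model so decoding halts as soon as a registered stop text (the chat template's assistant end-of-turn marker) appears, instead of always generating `max_tokens` tokens. For illustration only, here is a rough equivalent built on `transformers`' own `StoppingCriteria` hook — an alternative sketch, not the code added by this commit:

```python
# Illustrative alternative to stop-sequencer using transformers' StoppingCriteria.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnTexts(StoppingCriteria):
    """Stop generation once any stop text shows up in the newly generated tokens."""

    def __init__(self, stop_texts, tokenizer, prompt_len):
        self.stop_texts = stop_texts
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len  # skip the prompt when decoding

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        generated = self.tokenizer.decode(input_ids[0, self.prompt_len :])
        return any(stop in generated for stop in self.stop_texts)


# Usage sketch (assumes `model`, `tokenizer`, `prompt_tokens`, and `stop_seq` exist):
# output = model.generate(
#     input_ids=prompt_tokens,
#     max_new_tokens=1024,
#     stopping_criteria=StoppingCriteriaList(
#         [StopOnTexts(stop_seq, tokenizer, prompt_tokens.size(-1))]
#     ),
# )
```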
13 changes: 12 additions & 1 deletion repoqa/provider/openai.py
@@ -6,8 +6,10 @@
from typing import List

from openai import Client
from transformers import AutoTokenizer

from repoqa.provider.base import BaseProvider
from repoqa.provider.request import hacky_assistant_stop_seq
from repoqa.provider.request.openai import make_auto_request


@@ -17,9 +19,17 @@ def __init__(self, model, base_url: str = None):
self.client = Client(
api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=base_url
)
self.stop_seq = []
try:
tokenizer = AutoTokenizer.from_pretrained(model)
if tokenizer.chat_template:
self.stop_seq.append(hacky_assistant_stop_seq(tokenizer))
print("Using stop sequence: ", self.stop_seq)
except:
print("Failed to automatically fetch stop tokens from HuggingFace.")

def generate_reply(
-self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
+self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
) -> List[str]:
assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
replies = make_auto_request(
@@ -30,6 +40,7 @@ def generate_reply(
n=n,
max_tokens=max_tokens,
system_msg=system_msg,
stop=self.stop_seq,
)

return [reply.message.content for reply in replies.choices]
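For OpenAI-compatible servers, the stop sequence derived from the model's HuggingFace tokenizer is forwarded as the standard `stop` parameter of the chat completion request. Roughly what the call amounts to, sketched without the repo's `make_auto_request` wrapper; the model name, base URL, and stop string below are placeholders:

```python
# Sketch of the underlying OpenAI-compatible request with a stop sequence.
import os

from openai import Client

client = Client(
    api_key=os.getenv("OPENAI_API_KEY", "none"),
    base_url="http://localhost:8000/v1",  # placeholder OpenAI-compatible server
)
resp = client.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "..."}],
    n=1,
    max_tokens=1024,
    temperature=0.0,
    stop=["<|im_end|>"],  # e.g. the assistant end-of-turn text from the chat template
)
print(resp.choices[0].message.content)
```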
11 changes: 11 additions & 0 deletions repoqa/provider/request/__init__.py
@@ -8,3 +8,14 @@ def construct_message_list(message, system_message=None):
if system_message:
msglist.insert(0, {"role": "system", "content": system_message})
return msglist


def hacky_assistant_stop_seq(tokenizer) -> str:
_magic_string_ = "&==NowOrNever==&Accelerate!!!==&"
return tokenizer.apply_chat_template(
[
{"role": "user", "content": ""},
{"role": "assistant", "content": _magic_string_},
],
tokenize=False,
).split(_magic_string_)[-1]
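The new `hacky_assistant_stop_seq` helper renders a dummy user/assistant exchange through the tokenizer's chat template and keeps whatever the template emits after the assistant message — the assistant end-of-turn text — so the providers can pass it as a stop sequence. A self-contained illustration with a mock ChatML-style tokenizer (the mock class is an assumption for demonstration; the real code receives a HuggingFace tokenizer):

```python
# Demonstration of what hacky_assistant_stop_seq extracts (mock tokenizer only).
_magic_string_ = "&==NowOrNever==&Accelerate!!!==&"


class MockChatMLTokenizer:
    chat_template = "..."  # truthy, mimicking a tokenizer with a chat template

    def apply_chat_template(self, messages, tokenize=False):
        # Simplified ChatML-style rendering.
        return "".join(
            f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
        )


def hacky_assistant_stop_seq(tokenizer) -> str:
    # Same logic as the helper added above.
    return tokenizer.apply_chat_template(
        [
            {"role": "user", "content": ""},
            {"role": "assistant", "content": _magic_string_},
        ],
        tokenize=False,
    ).split(_magic_string_)[-1]


print(repr(hacky_assistant_stop_seq(MockChatMLTokenizer())))
# -> '<|im_end|>\n' — the text the template appends after an assistant turn.
```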
17 changes: 12 additions & 5 deletions repoqa/provider/vllm.py
@@ -8,12 +8,12 @@
from vllm import LLM, SamplingParams

from repoqa.provider.base import BaseProvider
-from repoqa.provider.request import construct_message_list
+from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq


class VllmProvider(BaseProvider):
def __init__(
-self, model, tensor_parallel_size, max_model_len, trust_remote_code=False
+self, model, tensor_parallel_size, max_model_len=None, trust_remote_code=False
):
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.llm = LLM(
@@ -22,22 +22,29 @@ def __init__(self,
max_model_len=max_model_len,
trust_remote_code=trust_remote_code,
)
self.stop_seq = []
if self.tokenizer.chat_template:
self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer))

def generate_reply(
-self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
+self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
) -> List[str]:
assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"

prompt = self.tokenizer.apply_chat_template(
-construct_message_list(question, system_msg), tokenize=False
+construct_message_list(question, system_msg),
+tokenize=False,
+add_generation_prompt=True,
)
vllm_outputs = self.llm.generate(
[prompt],
SamplingParams(
temperature=temperature,
max_tokens=max_tokens,
stop=self.stop_seq,
),
use_tqdm=False,
)

-gen_strs = [x.outputs[0].text.replace("\t", " ") for x in vllm_outputs]
+gen_strs = [x.outputs[0].text for x in vllm_outputs]
return gen_strs
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,3 +10,4 @@ openai
anthropic
google-generativeai
vllm
stop-sequencer
1 change: 1 addition & 0 deletions setup.cfg
@@ -28,6 +28,7 @@ install_requires =
openai>=1.23.2
anthropic>=0.25.6
google-generativeai>=0.5.2
stop-sequencer>=1.2.3

[options.entry_points]
console_scripts =
