From a468a6783396b824b2652fa7a6b7f259f1fc616a Mon Sep 17 00:00:00 2001 From: Drake Wong <40375132+drakejwong@users.noreply.github.com> Date: Fri, 10 Mar 2023 02:04:13 -0800 Subject: [PATCH] https://github.com/shawwn/llama-dl/issues/1#issuecomment-1458870564 --- download.sh | 31 ++++++------ example.py | 116 ++++++++++++++++++++++++++++++-------------- llama/generation.py | 69 +++++++++++++++++++------- 3 files changed, 147 insertions(+), 69 deletions(-) diff --git a/download.sh b/download.sh index db520dcfe..f383a6920 100644 --- a/download.sh +++ b/download.sh @@ -1,9 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the GNU General Public License version 3. -PRESIGNED_URL="" # replace with presigned url from email -MODEL_SIZE="7B,13B,30B,65B" # edit this list with the model sizes you wish to download -TARGET_FOLDER="" # where all files should end up +PRESIGNED_URL="https://agi.gpt4.org/llama/LLaMA/*" + +MODEL_SIZE="7B,13B" # edit this list with the model sizes you wish to download +TARGET_FOLDER="$HOME/llama-model/" # where all files should end up declare -A N_SHARD_DICT @@ -18,16 +19,14 @@ wget ${PRESIGNED_URL/'*'/"tokenizer_checklist.chk"} -O ${TARGET_FOLDER}"/tokeniz (cd ${TARGET_FOLDER} && md5sum -c tokenizer_checklist.chk) -for i in ${MODEL_SIZE//,/ } -do - echo "Downloading ${i}" - mkdir -p ${TARGET_FOLDER}"/${i}" - for s in $(seq -f "0%g" 0 ${N_SHARD_DICT[$i]}) - do - wget ${PRESIGNED_URL/'*'/"${i}/consolidated.${s}.pth"} -O ${TARGET_FOLDER}"/${i}/consolidated.${s}.pth" - done - wget ${PRESIGNED_URL/'*'/"${i}/params.json"} -O ${TARGET_FOLDER}"/${i}/params.json" - wget ${PRESIGNED_URL/'*'/"${i}/checklist.chk"} -O ${TARGET_FOLDER}"/${i}/checklist.chk" - echo "Checking checksums" - (cd ${TARGET_FOLDER}"/${i}" && md5sum -c checklist.chk) -done \ No newline at end of file +for i in ${MODEL_SIZE//,/ }; do + echo "Downloading ${i}" + mkdir -p ${TARGET_FOLDER}"/${i}" + for s in $(seq -f "0%g" 0 ${N_SHARD_DICT[$i]}); do + wget ${PRESIGNED_URL/'*'/"${i}/consolidated.${s}.pth"} -O ${TARGET_FOLDER}"/${i}/consolidated.${s}.pth" + done + wget ${PRESIGNED_URL/'*'/"${i}/params.json"} -O ${TARGET_FOLDER}"/${i}/params.json" + wget ${PRESIGNED_URL/'*'/"${i}/checklist.chk"} -O ${TARGET_FOLDER}"/${i}/checklist.chk" + echo "Checking checksums" + (cd ${TARGET_FOLDER}"/${i}" && md5sum -c checklist.chk) +done diff --git a/example.py b/example.py index fba9a54a5..55c4f6beb 100755 --- a/example.py +++ b/example.py @@ -16,7 +16,7 @@ from llama import ModelArgs, Transformer, Tokenizer, LLaMA -def setup_model_parallel() -> Tuple[int, int]: +def setup_model_parallel(seed: int) -> Tuple[int, int]: local_rank = int(os.environ.get("LOCAL_RANK", -1)) world_size = int(os.environ.get("WORLD_SIZE", -1)) @@ -25,7 +25,7 @@ def setup_model_parallel() -> Tuple[int, int]: torch.cuda.set_device(local_rank) # seed must be the same in all processes - torch.manual_seed(1) + torch.manual_seed(seed) return local_rank, world_size @@ -66,53 +66,97 @@ def load( def main( ckpt_dir: str, tokenizer_path: str, - temperature: float = 0.8, - top_p: float = 0.95, + temperature: float = 0.7, + # top_p: float = 0.95, + top_p: float = 0.0, + top_k: int = 40, + repetition_penalty: float = (1 / 0.85), max_seq_len: int = 512, + max_gen_len: int = 256, max_batch_size: int = 32, + seed: int = 1, + count: int = 5, ): - local_rank, world_size = setup_model_parallel() + local_rank, world_size = setup_model_parallel(seed) if local_rank > 0: sys.stdout = open(os.devnull, "w") + print("\n") + print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + print(json.dumps(dict( + seed=seed, + temp=temperature, + top_p=top_p, + top_k=top_k, + repetition_penalty=repetition_penalty, + max_seq_len=max_seq_len, + max_gen_len=max_gen_len, + ))) + print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + + generator = load( ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size ) prompts = [ # For these prompts, the expected answer is the natural continuation of the prompt - "I believe the meaning of life is", - "Simply put, the theory of relativity states that ", - "Building a website can be done in 10 simple steps:\n", - # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api - """Tweet: "I hate it when my phone battery dies." -Sentiment: Negative -### -Tweet: "My day has been 👍" -Sentiment: Positive -### -Tweet: "This is the link to the article" -Sentiment: Neutral -### -Tweet: "This new music video was incredibile" -Sentiment:""", - """Translate English to French: - -sea otter => loutre de mer - -peppermint => menthe poivrée - -plush girafe => girafe peluche - -cheese =>""", - ] - results = generator.generate( - prompts, max_gen_len=256, temperature=temperature, top_p=top_p - ) - for result in results: - print(result) - print("\n==================================\n") + # "I believe the meaning of life is", + # "Simply put, the theory of relativity states that", + # "Building a website can be done in a few simple steps:\n1.", + # "Here's how to build it in a few simple steps:\n1.", + + "This is Captain Jean-Luc Picard", + "I am Lieutenant Commander Data", + "The Klingons are attacking", + +# # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api +# """Tweet: "I hate it when my phone battery dies." +# Sentiment: Negative +# ### +# Tweet: "My day has been 👍" +# Sentiment: Positive +# ### +# Tweet: "This is the link to the article" +# Sentiment: Neutral +# ### +# Tweet: "This new music video was incredibile" +# Sentiment:""", +# """Translate English to French: +# +# sea otter => loutre de mer +# +# peppermint => menthe poivrée +# +# plush girafe => girafe peluche +# +# cheese =>""", + ] + i = 0 + while i < count or count <= 0: + i += 1 + for prompt in prompts: + print(f"\n============== sample {i} =================\n") + width = 0 + def callback(text): + nonlocal width + text = text.replace('\n', '\n\n') + chars = [] + for i, c in enumerate(text): + if c == ' ' and width >= 60: + chars.append('\n') + width = 0 + else: + width += 1 + chars.append(c) + if c == '\n': + width = 0 + text = ''.join(chars) + print(text, end='', flush=True) + text, = generator.generate( + [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, token_callback=callback, + ) if __name__ == "__main__": diff --git a/llama/generation.py b/llama/generation.py index 3abd3edb1..6d64bdb8e 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -18,8 +18,11 @@ def generate( self, prompts: List[str], max_gen_len: int, - temperature: float = 0.8, - top_p: float = 0.95, + temperature: float = 0.7, + top_k: int = 40, + top_p: float = 0.0, #0.95, + repetition_penalty: float = (1.0 / 0.85), + token_callback=None, ) -> List[str]: bsz = len(prompts) params = self.model.params @@ -38,11 +41,26 @@ def generate( input_text_mask = tokens != self.tokenizer.pad_id start_pos = min_prompt_size prev_pos = 0 + prev_text = '' for cur_pos in range(start_pos, total_len): logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + logits_new = logits.clone() + batch_size = len(tokens) + for i in range(batch_size): + for token in set(tokens[i].tolist()): + # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if logits[i, token] < 0: + logits_new[i, token] = logits[i, token] * repetition_penalty + else: + logits_new[i, token] = logits[i, token] / repetition_penalty + logits = logits_new + if temperature > 0: probs = torch.softmax(logits / temperature, dim=-1) - next_token = sample_top_p(probs, top_p) + next_token = sample(probs, top_p=top_p, top_k=top_k) else: next_token = torch.argmax(logits, dim=-1) next_token = next_token.reshape(-1) @@ -50,28 +68,45 @@ def generate( next_token = torch.where( input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token ) + if next_token == self.tokenizer.eos_id: + break tokens[:, cur_pos] = next_token + if token_callback is not None: + assert len(prompts) == 1 + text, = self.decode(tokens) + #assert text.startswith(prev_text) + if not text.startswith(prev_text): + # Some kind of bogus token generation; abort early. + break + next_word = text[len(prev_text):] + prev_text = text + token_callback(next_word) prev_pos = cur_pos + return self.decode(tokens) + + def decode(self, tokens): decoded = [] for i, t in enumerate(tokens.tolist()): - # cut to max gen len - t = t[: len(prompt_tokens[i]) + max_gen_len] - # cut to eos tok if any - try: - t = t[: t.index(self.tokenizer.eos_id)] - except ValueError: - pass + t = [token for token in t if token != -1] + # # cut to max gen len + # t = t[: len(prompt_tokens[i]) + max_gen_len] + while self.tokenizer.eos_id in t: + pos = t.index(self.tokenizer.eos_id) + t[pos:pos+1] = self.tokenizer.encode('\n<|endoftext|>\n', bos=False, eos=False) decoded.append(self.tokenizer.decode(t)) return decoded - -def sample_top_p(probs, p): - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) +def sample(probs, top_p=0.0, top_k=40): + if top_k > 0: + probs_sort, probs_idx = torch.topk(probs, top_k) + else: + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + if top_p > 0.0: + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) next_token = torch.multinomial(probs_sort, num_samples=1) next_token = torch.gather(probs_idx, -1, next_token) return next_token