-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate.py
42 lines (36 loc) · 1.6 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from fairseq.models.bart import BARTModel
import os
import time
import numpy as np
# Expose only the GPU with index 1 to this process. This is set before any
# CUDA call is made (bart.cuda() below) so that it takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

# Load the fine-tuned checkpoint from the local 'checkpoint-simile/' directory.
bart = BARTModel.from_pretrained(
    'checkpoint-simile/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='simile',
)
# To use the pretrained BART model instead, load it like this:
# bart = BARTModel.from_pretrained('bart.large', checkpoint_file='model.pt',task='translation',data_name_or_path='your data')
# See https://github.com/pytorch/fairseq/issues/1944 for the changes that
# have to be made in hub_interface.py for this to work.
bart.cuda()
bart.eval()

# Fixed seeds so that the sampling below is repeatable from run to run.
np.random.seed(4)
torch.manual_seed(4)

bsz = 1   # number of source lines sent to the model per sample() call
t = 0.7   # sampling temperature


def _write_hypotheses(model, sentences, fout, topk, temperature):
    """Sample from ``model`` for ``sentences`` and write the results to ``fout``.

    Args:
        model: object exposing a fairseq-style ``sample(sentences, **kwargs)``
            method that returns an iterable of strings.
        sentences: list of already-stripped source lines forming one batch.
        fout: writable text file; each hypothesis is written on its own line.
        topk: value passed as ``sampling_topk``.
        temperature: value passed as ``temperature``.
    """
    # Generation never needs gradients; disabling them for every call
    # (including the final partial batch) avoids autograd bookkeeping.
    with torch.no_grad():
        hypotheses_batch = model.sample(
            sentences,
            sampling=True,
            sampling_topk=topk,
            temperature=temperature,
            lenpen=2.0,
            max_len_b=30,
            min_len=7,
            no_repeat_ngram_size=3,
        )
    for hypothesis in hypotheses_batch:
        # Strip embedded newlines so each hypothesis occupies exactly one line.
        fout.write(hypothesis.replace('\n', '') + '\n')
    # Flush per batch so partial results are on disk if the run is interrupted.
    fout.flush()


# NOTE: the output file is opened in 'w' mode inside the loop, so if more
# top-k values are added here each one overwrites the previous output.
for val in [5]:
    with open('literal.txt') as source, open('simile.hypo', 'w') as fout:
        pending = []
        for sline in source:
            # Blank lines are kept (as empty strings) so that line i of the
            # output always corresponds to line i of the input.
            pending.append(sline.strip())
            if len(pending) >= bsz:
                _write_hypotheses(bart, pending, fout, val, t)
                pending = []
        # Leftover lines when the input length is not a multiple of bsz.
        # An empty input file leaves nothing pending, so nothing is generated.
        if pending:
            _write_hypotheses(bart, pending, fout, val, t)