Sourcery Starbot ⭐ refactored rubby33/CodeBERT #1
base: master
File 1 of 6: BLEU scoring utilities (normalize, cook_test, splitPuncts, computeMaps, bleuFromMaps)
@@ -55,7 +55,7 @@ def normalize(s):
         s = re.sub(pattern, replace, s)
     s = xml.sax.saxutils.unescape(s, {'&quot;': '"'})
     # language-dependent part (assuming Western languages):
-    s = " %s " % s
+    s = f" {s} "
     if not preserve_case:
         s = s.lower()  # this might not be identical to the original
     for (pattern, replace) in normalize2:

Sourcery review comment: Function normalize refactored with the following changes: replace-interpolation-with-fstring.
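As a quick sanity check of the replace-interpolation-with-fstring change (this snippet is not part of the PR), both interpolation styles produce the same padded string:

```python
# "%"-interpolation and the f-string yield identical results.
s = "a b c"
assert " %s " % s == f" {s} " == " a b c "
```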
@@ -88,14 +88,10 @@ def cook_test(test, item, n=4):
     encapsulates everything that BLEU needs to know about it.'''
     (reflens, refmaxcounts)=item
     test = normalize(test)
-    result = {}
-    result["testlen"] = len(test)
+    result = {"testlen": len(test)}

     # Calculate effective reference sentence length.
-    if eff_ref_len == "shortest":
-        result["reflen"] = min(reflens)
-    elif eff_ref_len == "average":
+    if eff_ref_len == "average":
         result["reflen"] = float(sum(reflens))/len(reflens)
     elif eff_ref_len == "closest":
         min_diff = None

Sourcery review comment on lines -91 to +94: Function cook_test refactored […]

@@ -104,6 +100,8 @@ def cook_test(test, item, n=4):
                 min_diff = abs(reflen-len(test))
                 result['reflen'] = reflen
+    elif eff_ref_len == "shortest":
+        result["reflen"] = min(reflens)

     result["guess"] = [max(len(test)-k+1,0) for k in range(1,n+1)]
     result['correct'] = [0]*n
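To make the reordered branches concrete, here is a small illustration (not from the PR) of what the three eff_ref_len policies compute for a 7-token candidate; the "closest" value uses min() with a key function in place of the diff's loop, which picks the same reference because the loop also keeps the first strict improvement:

```python
# Effective reference length under each policy, for references of
# length 5, 8 and 12 and a candidate of length 7.
reflens, testlen = [5, 8, 12], 7

shortest = min(reflens)                                 # 5
average = float(sum(reflens)) / len(reflens)            # 8.333...
closest = min(reflens, key=lambda r: abs(r - testlen))  # 8

print(shortest, average, closest)
```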
@@ -154,47 +152,42 @@ def splitPuncts(line):
     return ' '.join(re.findall(r"[\w]+|[^\s\w]", line))

 def computeMaps(predictions, goldfile):
     predictionMap = {}
     goldMap = {}
     gf = open(goldfile, 'r')

     for row in predictions:
         cols = row.strip().split('\t')
-        if len(cols) == 1:
-            (rid, pred) = (cols[0], '')
-        else:
-            (rid, pred) = (cols[0], cols[1])
+        (rid, pred) = (cols[0], '') if len(cols) == 1 else (cols[0], cols[1])
         predictionMap[rid] = [splitPuncts(pred.strip().lower())]

     for row in gf:
         (rid, pred) = row.split('\t')
         if rid in predictionMap:  # Only insert if the id exists for the method
             if rid not in goldMap:
                 goldMap[rid] = []
             goldMap[rid].append(splitPuncts(pred.strip().lower()))

-    sys.stderr.write('Total: ' + str(len(goldMap)) + '\n')
+    sys.stderr.write(f'Total: {len(goldMap)}' + '\n')
     return (goldMap, predictionMap)

Sourcery review comment on lines -157 to +172: Function computeMaps refactored […]
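A small check (not from the PR) of the collapsed if/else: rows with a missing prediction column still map to an empty string, because strip() removes the trailing tab before the split:

```python
def parse(row):
    cols = row.strip().split('\t')
    return (cols[0], '') if len(cols) == 1 else (cols[0], cols[1])

print(parse('42\tsome prediction'))  # ('42', 'some prediction')
print(parse('42\t'))                 # ('42', '') -- strip() removed the tab
```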
 #m1 is the reference map
 #m2 is the prediction map
 def bleuFromMaps(m1, m2):
     score = [0] * 5
     num = 0.0

     for key in m1:
         if key in m2:
             bl = bleu(m1[key], m2[key][0])
-            score = [ score[i] + bl[i] for i in range(0, len(bl))]
+            score = [score[i] + bl[i] for i in range(len(bl))]
             num += 1
     return [s * 100.0 / num for s in score]

Sourcery review comment on lines -183 to +186: Function bleuFromMaps refactored […]
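For orientation, a hedged usage sketch of bleuFromMaps with toy ids and sentences (it assumes bleuFromMaps and its bleu() helper from this file are in scope; the exact score for imperfect matches depends on the smoothing inside bleu()):

```python
gold = {'1': ['return the sum of two numbers'],
        '2': ['open a file for reading']}
pred = {'1': ['return the sum of two numbers'],
        '2': ['opens the file']}

# Element 0 is the per-id BLEU averaged over ids, scaled to 0-100.
print(bleuFromMaps(gold, pred)[0])
```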
 if __name__ == '__main__':
     reference_file = sys.argv[1]
-    predictions = []
-    for row in sys.stdin:
-        predictions.append(row)
+    predictions = list(sys.stdin)
     (goldMap, predictionMap) = computeMaps(predictions, reference_file)
     print (bleuFromMaps(goldMap, predictionMap)[0])

Sourcery review comment on lines -194 to +192: Lines refactored […]
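The list(sys.stdin) rewrite works because a file object iterates over its lines; a sketch of the equivalence (only one of the two forms can consume stdin in a given run):

```python
import sys

# Before: explicit accumulation.
# predictions = []
# for row in sys.stdin:
#     predictions.append(row)

# After: collect the line iterator directly.
predictions = list(sys.stdin)
```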
File 2 of 6: seq2seq model and beam search (forward, getCurrentState, getFinal)
@@ -73,8 +73,8 @@ def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None):
             return outputs
         else:
             #Predict
             preds=[]
             zero=torch.cuda.LongTensor(1).fill_(0)

Sourcery review comment on lines -76 to +77: Function forward refactored […]

             for i in range(source_ids.shape[0]):
                 context=encoder_output[:,i:i+1]
                 context_mask=source_mask[i:i+1,:]
@@ -98,9 +98,8 @@ def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None):
                 pred=beam.buildTargetTokens(hyp)[:self.beam_size]
                 pred=[torch.cat([x.view(-1) for x in p]+[zero]*(self.max_length-len(p))).view(1,-1) for p in pred]
                 preds.append(torch.cat(pred,0).unsqueeze(0))

-            preds=torch.cat(preds,0)
-            return preds
+            return torch.cat(preds,0)
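To unpack the dense padding expression kept in this hunk, a CPU-only sketch (tensor values invented for illustration; the PR's version uses torch.cuda.LongTensor):

```python
import torch

max_length = 5
zero = torch.LongTensor(1).fill_(0)  # CPU stand-in for the CUDA zero tensor
hyps = [[torch.tensor([7]), torch.tensor([3])],  # beam 1: two tokens
        [torch.tensor([9])]]                     # beam 2: one token

# Pad every hypothesis to max_length, then stack into (beam_size, max_length).
pred = [torch.cat([x.view(-1) for x in p] + [zero] * (max_length - len(p))).view(1, -1)
        for p in hyps]
print(torch.cat(pred, 0))
# tensor([[7, 3, 0, 0, 0],
#         [9, 0, 0, 0, 0]])
```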
@@ -124,8 +123,7 @@ def __init__(self, size, sos, eos):

     def getCurrentState(self):
         "Get the outputs for the current timestep."
-        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
-        return batch
+        return self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)

Sourcery review comment on lines -127 to +126: Function getCurrentState refactored […]

     def getCurrentOrigin(self):
         "Get the backpointers for the current timestep."
@@ -184,11 +182,12 @@ def getFinal(self):
             self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
         self.finished.sort(key=lambda a: -a[0])
         if len(self.finished) != self.size:
-            unfinished=[]
-            for i in range(self.nextYs[-1].size(0)):
-                if self.nextYs[-1][i] != self._eos:
-                    s = self.scores[i]
-                    unfinished.append((s, len(self.nextYs) - 1, i))
+            unfinished = [
+                (self.scores[i], len(self.nextYs) - 1, i)
+                for i in range(self.nextYs[-1].size(0))
+                if self.nextYs[-1][i] != self._eos
+            ]

Sourcery review comment on lines -187 to +190: Function getFinal refactored […]

             unfinished.sort(key=lambda a: -a[0])
             self.finished+=unfinished[:self.size-len(self.finished)]
         return self.finished[:self.size]
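The loop-to-comprehension rewrite in getFinal can be checked in isolation (toy scores and token ids; the EOS id of 2 is assumed for illustration):

```python
scores = [0.9, 0.5, 0.8]
last_tokens = [2, 7, 2]  # token emitted by each beam at the last step
eos, step = 2, 4

# Collect (score, step, beam index) triples for beams that have not emitted EOS.
unfinished = [
    (scores[i], step, i)
    for i in range(len(last_tokens))
    if last_tokens[i] != eos
]
print(unfinished)  # [(0.5, 4, 1)] -- only beam 1 is still unfinished
```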
File 3 of 6: code-to-text feature conversion (convert_examples_to_features)
@@ -104,37 +104,36 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
         #source
         source_tokens = tokenizer.tokenize(example.source)[:args.max_source_length-2]
         source_tokens =[tokenizer.cls_token]+source_tokens+[tokenizer.sep_token]
         source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
         source_mask = [1] * (len(source_tokens))
         padding_length = args.max_source_length - len(source_ids)
         source_ids+=[tokenizer.pad_token_id]*padding_length
         source_mask+=[0]*padding_length

         #target
         if stage=="test":
             target_tokens = tokenizer.tokenize("None")
         else:
             target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2]
         target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token]
         target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
         target_mask = [1] *len(target_ids)
         padding_length = args.max_target_length - len(target_ids)
         target_ids+=[tokenizer.pad_token_id]*padding_length
         target_mask+=[0]*padding_length

-        if example_index < 5:
-            if stage=='train':
-                logger.info("*** Example ***")
-                logger.info("idx: {}".format(example.idx))
-
-                logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens]))
-                logger.info("source_ids: {}".format(' '.join(map(str, source_ids))))
-                logger.info("source_mask: {}".format(' '.join(map(str, source_mask))))
-
-                logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens]))
-                logger.info("target_ids: {}".format(' '.join(map(str, target_ids))))
-                logger.info("target_mask: {}".format(' '.join(map(str, target_mask))))
+        if example_index < 5 and stage == 'train':
+            logger.info("*** Example ***")
+            logger.info(f"idx: {example.idx}")
+
+            logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens]))
+            logger.info(f"source_ids: {' '.join(map(str, source_ids))}")
+            logger.info(f"source_mask: {' '.join(map(str, source_mask))}")
+
+            logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens]))
+            logger.info(f"target_ids: {' '.join(map(str, target_ids))}")
+            logger.info(f"target_mask: {' '.join(map(str, target_mask))}")

Sourcery review comment on lines -107 to +136: Function convert_examples_to_features refactored […]

         features.append(
             InputFeatures(
                 example_index,
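The source-side encoding above follows a standard truncate/pad recipe; a self-contained sketch with made-up token ids (the real ids come from the tokenizer):

```python
max_source_length = 8
cls_id, sep_id, pad_id = 0, 2, 1  # assumed special-token ids

source_ids = [cls_id, 713, 16, 10, sep_id]  # "[CLS] ... [SEP]"
source_mask = [1] * len(source_ids)
padding_length = max_source_length - len(source_ids)
source_ids += [pad_id] * padding_length     # right-pad the ids
source_mask += [0] * padding_length         # mask out the padding

print(source_ids)   # [0, 713, 16, 10, 2, 1, 1, 1]
print(source_mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```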
File 4 of 6: MRR evaluation (main)
@@ -15,7 +15,7 @@ def main():
     languages = ['ruby', 'go', 'php', 'python', 'java', 'javascript']
     MRR_dict = {}
     for language in languages:
-        file_dir = './results/{}'.format(language)
+        file_dir = f'./results/{language}'

Sourcery review comment: Function main refactored […]

         ranks = []
         num_batch = 0
         for file in sorted(os.listdir(file_dir)):

@@ -30,10 +30,10 @@ def main():
             ranks.append(rank)

         mean_mrr = np.mean(1.0 / np.array(ranks))
-        print("{} mrr: {}".format(language, mean_mrr))
+        print(f"{language} mrr: {mean_mrr}")
         MRR_dict[language] = mean_mrr
     for key, val in MRR_dict.items():
-        print("{} mrr: {}".format(key, val))
+        print(f"{key} mrr: {val}")


 if __name__ == "__main__":
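The mean_mrr line implements mean reciprocal rank; a worked example with invented ranks:

```python
import numpy as np

ranks = [1, 2, 4]                          # rank of the correct result per query
mean_mrr = np.mean(1.0 / np.array(ranks))  # (1/1 + 1/2 + 1/4) / 3
print(mean_mrr)                            # 0.5833...
```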
File 5 of 6: code search test-data preprocessing (preprocess_test_data)
@@ -17,7 +17,7 @@ def format_str(string):

 def preprocess_test_data(language, test_batch_size=1000):
-    path = os.path.join(DATA_DIR, '{}_test_0.jsonl.gz'.format(language))
+    path = os.path.join(DATA_DIR, f'{language}_test_0.jsonl.gz')

Sourcery review comment: Function preprocess_test_data refactored […]

     print(path)
     with gzip.open(path, 'r') as pf:
         data = pf.readlines()

@@ -35,7 +35,7 @@ def preprocess_test_data(language, test_batch_size=1000):
         if len(batch_data) < test_batch_size:
             break  # the last batch is smaller than the others, exclude.
         examples = []
-        for d_idx, d in enumerate(batch_data):
+        for d in batch_data:
             line_a = json.loads(str(d, encoding='utf-8'))
             doc_token = ' '.join(line_a['docstring_tokens'])
             for dd in batch_data:

@@ -46,10 +46,10 @@ def preprocess_test_data(language, test_batch_size=1000):
             example = '<CODESPLIT>'.join(example)
             examples.append(example)

-        data_path = os.path.join(DATA_DIR, 'test/{}'.format(language))
+        data_path = os.path.join(DATA_DIR, f'test/{language}')
         if not os.path.exists(data_path):
             os.makedirs(data_path)
-        file_path = os.path.join(data_path, 'batch_{}.txt'.format(batch_idx))
+        file_path = os.path.join(data_path, f'batch_{batch_idx}.txt')
         print(file_path)
         with open(file_path, 'w', encoding='utf-8') as f:
             f.writelines('\n'.join(examples))
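The read path above is the usual gzipped-JSONL pattern; a minimal sketch (the file name is hypothetical, the field name comes from the diff):

```python
import gzip
import json

# Read all lines of a gzipped JSON-lines file, then decode one record.
with gzip.open('python_test_0.jsonl.gz', 'r') as pf:
    data = pf.readlines()

line_a = json.loads(str(data[0], encoding='utf-8'))
print(' '.join(line_a['docstring_tokens']))
```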
File 6 of 6: classifier fine-tuning and evaluation (train, evaluate, load_and_cache_examples)
@@ -134,7 +134,7 @@ def train(args, train_dataset, model, tokenizer, optimizer):
                 if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                     results = evaluate(args, model, tokenizer, checkpoint=str(global_step))
                     for key, value in results.items():
-                        tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
+                        tb_writer.add_scalar(f'eval_{key}', value, global_step)

Sourcery review comment: Function train refactored […]

                 logger.info('loss %s', str(tr_loss - logging_loss))
                 tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                 tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)

@@ -173,7 +173,7 @@ def train(args, train_dataset, model, tokenizer, optimizer):
                     model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
-                    torch.save(args, os.path.join(output_dir, 'training_{}.bin'.format(idx)))
+                    torch.save(args, os.path.join(output_dir, f'training_{idx}.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)

                     torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -216,7 +216,7 @@ def evaluate(args, model, tokenizer, checkpoint=None, prefix="", mode='dev'):
     eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

     # Eval!
-    logger.info("***** Running evaluation {} *****".format(prefix))
+    logger.info(f"***** Running evaluation {prefix} *****")

Sourcery review comment: Function evaluate refactored […]

     logger.info("  Num examples = %d", len(eval_dataset))
     logger.info("  Batch size = %d", args.eval_batch_size)
     eval_loss = 0.0

@@ -252,11 +252,11 @@ def evaluate(args, model, tokenizer, checkpoint=None, prefix="", mode='dev'):
         if args.output_mode == "classification":
             preds_label = np.argmax(preds, axis=1)
             result = compute_metrics(eval_task, preds_label, out_label_ids)
-            results.update(result)
+            results |= result
         if (mode == 'dev'):
             output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
             with open(output_eval_file, "a+") as writer:
-                logger.info("***** Eval results {} *****".format(prefix))
+                logger.info(f"***** Eval results {prefix} *****")
                 writer.write('evaluate %s\n' % checkpoint)
                 for key in sorted(result.keys()):
                     logger.info("  %s = %s", key, str(result[key]))

@@ -273,9 +273,14 @@ def evaluate(args, model, tokenizer, checkpoint=None, prefix="", mode='dev'):
                     instance_rep = '<CODESPLIT>'.join(
                         [item.encode('ascii', 'ignore').decode('ascii') for item in instances[i]])

-                    writer.write(instance_rep + '<CODESPLIT>' + '<CODESPLIT>'.join([str(l) for l in logit]) + '\n')
+                    writer.write(
+                        f'{instance_rep}<CODESPLIT>'
+                        + '<CODESPLIT>'.join([str(l) for l in logit])
+                        + '\n'
+                    )

             for key in sorted(result.keys()):
-                print("%s = %s" % (key, str(result[key])))
+                print(f"{key} = {str(result[key])}")

     return results
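One caveat worth noting on the results |= result rewrite: the in-place dict union operator comes from PEP 584 and requires Python 3.9 or newer, while the original results.update(result) runs on any Python 3. A minimal demonstration:

```python
results = {'acc': 0.91}
result = {'f1': 0.88}

results |= result  # Python >= 3.9; equivalent to results.update(result)
print(results)     # {'acc': 0.91, 'f1': 0.88}
```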
@@ -290,12 +295,11 @@ def load_and_cache_examples(args, task, tokenizer, ttype='train'):
         file_name = args.dev_file.split('.')[0]
     elif ttype == 'test':
         file_name = args.test_file.split('.')[0]
-    cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_{}'.format(
-        ttype,
-        file_name,
-        list(filter(None, args.model_name_or_path.split('/'))).pop(),
-        str(args.max_seq_length),
-        str(task)))
+    cached_features_file = os.path.join(
+        args.data_dir,
+        f"cached_{ttype}_{file_name}_{list(filter(None, args.model_name_or_path.split('/'))).pop()}_{str(args.max_seq_length)}_{str(task)}",
+    )

Sourcery review comment on lines -293 to +302: Function load_and_cache_examples refactored. This removes the following comments (why?): […]

     # if os.path.exists(cached_features_file):
     try:
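The expression embedded in the cache-file f-string just takes the last non-empty path component of the model name; in isolation, with an example value assumed for illustration:

```python
model_name_or_path = 'microsoft/codebert-base'
name = list(filter(None, model_name_or_path.split('/'))).pop()
print(name)  # codebert-base -- empty segments from trailing slashes are filtered out
```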
@@ -313,15 +317,20 @@ def load_and_cache_examples(args, task, tokenizer, ttype='train'):
     elif ttype == 'test':
         examples, instances = processor.get_test_examples(args.data_dir, args.test_file)

-    features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
-                                            cls_token_at_end=bool(args.model_type in ['xlnet']),
-                                            # xlnet has a cls token at the end
-                                            cls_token=tokenizer.cls_token,
-                                            sep_token=tokenizer.sep_token,
-                                            cls_token_segment_id=2 if args.model_type in ['xlnet'] else 1,
-                                            pad_on_left=bool(args.model_type in ['xlnet']),
-                                            # pad on the left for xlnet
-                                            pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
+    features = convert_examples_to_features(
+        examples,
+        label_list,
+        args.max_seq_length,
+        tokenizer,
+        output_mode,
+        cls_token_at_end=args.model_type in ['xlnet'],
+        cls_token=tokenizer.cls_token,
+        sep_token=tokenizer.sep_token,
+        cls_token_segment_id=2 if args.model_type in ['xlnet'] else 1,
+        pad_on_left=args.model_type in ['xlnet'],
+        pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
+    )

     if args.local_rank in [-1, 0]:
         logger.info("Saving features into cached file %s", cached_features_file)
         torch.save(features, cached_features_file)
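The dropped bool() wrappers were redundant because a membership test already returns a bool:

```python
model_type = 'xlnet'
print(type(model_type in ['xlnet']))  # <class 'bool'>
print('roberta' in ['xlnet'])         # False
```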
@@ -333,10 +342,7 @@ def load_and_cache_examples(args, task, tokenizer, ttype='train'):
     all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

     dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
-    if (ttype == 'test'):
-        return dataset, instances
-    else:
-        return dataset
+    return (dataset, instances) if (ttype == 'test') else dataset


 def main():