From 62be96a543d1074e08686a2ea2be9370201f1d6e Mon Sep 17 00:00:00 2001 From: Sourcery AI Date: Thu, 10 Nov 2022 06:44:17 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- CodeBERT/code2nl/bleu.py | 73 +++---- CodeBERT/code2nl/model.py | 23 +- CodeBERT/code2nl/run.py | 33 ++- CodeBERT/codesearch/mrr.py | 6 +- CodeBERT/codesearch/process_data.py | 8 +- CodeBERT/codesearch/run_classifier.py | 58 ++--- CodeBERT/codesearch/utils.py | 34 ++- .../clonedetection/evaluator/evaluator.py | 5 +- GraphCodeBERT/clonedetection/model.py | 2 +- GraphCodeBERT/clonedetection/parser/DFG.py | 66 +++--- GraphCodeBERT/clonedetection/parser/utils.py | 26 +-- GraphCodeBERT/clonedetection/run.py | 204 +++++++++--------- GraphCodeBERT/codesearch/model.py | 19 +- GraphCodeBERT/codesearch/parser/DFG.py | 66 +++--- GraphCodeBERT/codesearch/parser/utils.py | 26 +-- GraphCodeBERT/codesearch/run.py | 183 ++++++++-------- GraphCodeBERT/refinement/bleu.py | 54 ++--- GraphCodeBERT/refinement/model.py | 27 ++- GraphCodeBERT/refinement/parser/DFG.py | 66 +++--- GraphCodeBERT/refinement/parser/utils.py | 26 +-- GraphCodeBERT/refinement/run.py | 91 ++++---- GraphCodeBERT/translation/bleu.py | 54 ++--- GraphCodeBERT/translation/model.py | 27 ++- GraphCodeBERT/translation/parser/DFG.py | 66 +++--- GraphCodeBERT/translation/parser/utils.py | 26 +-- GraphCodeBERT/translation/run.py | 98 +++++---- 26 files changed, 655 insertions(+), 712 deletions(-) diff --git a/CodeBERT/code2nl/bleu.py b/CodeBERT/code2nl/bleu.py index f8f2b0a..3d11132 100644 --- a/CodeBERT/code2nl/bleu.py +++ b/CodeBERT/code2nl/bleu.py @@ -55,7 +55,7 @@ def normalize(s): s = re.sub(pattern, replace, s) s = xml.sax.saxutils.unescape(s, {'"':'"'}) # language-dependent part (assuming Western languages): - s = " %s " % s + s = f" {s} " if not preserve_case: s = s.lower() # this might not be identical to the original for (pattern, replace) in normalize2: @@ -88,14 +88,10 @@ def cook_test(test, item, n=4): encapsulates everything that BLEU needs to know about it.''' (reflens, refmaxcounts)=item test = normalize(test) - result = {} - result["testlen"] = len(test) - + result = {"testlen": len(test)} # Calculate effective reference sentence length. 
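An illustrative aside (not part of the patch): the hunk above only reorders the eff_ref_len branches of cook_test, so behaviour is unchanged. A minimal standalone sketch of the three reference-length policies, using a hypothetical helper name effective_ref_length:

def effective_ref_length(reflens, testlen, eff_ref_len="closest"):
    # "shortest": length of the shortest reference
    if eff_ref_len == "shortest":
        return min(reflens)
    # "average": mean reference length
    if eff_ref_len == "average":
        return float(sum(reflens)) / len(reflens)
    # "closest": reference length nearest to the candidate length
    # (first match wins on ties, as in the loop above)
    return min(reflens, key=lambda reflen: abs(reflen - testlen))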
- - if eff_ref_len == "shortest": - result["reflen"] = min(reflens) - elif eff_ref_len == "average": + + if eff_ref_len == "average": result["reflen"] = float(sum(reflens))/len(reflens) elif eff_ref_len == "closest": min_diff = None @@ -104,6 +100,8 @@ def cook_test(test, item, n=4): min_diff = abs(reflen-len(test)) result['reflen'] = reflen + elif eff_ref_len == "shortest": + result["reflen"] = min(reflens) result["guess"] = [max(len(test)-k+1,0) for k in range(1,n+1)] result['correct'] = [0]*n @@ -154,47 +152,42 @@ def splitPuncts(line): return ' '.join(re.findall(r"[\w]+|[^\s\w]", line)) def computeMaps(predictions, goldfile): - predictionMap = {} - goldMap = {} - gf = open(goldfile, 'r') + predictionMap = {} + goldMap = {} + gf = open(goldfile, 'r') - for row in predictions: - cols = row.strip().split('\t') - if len(cols) == 1: - (rid, pred) = (cols[0], '') - else: - (rid, pred) = (cols[0], cols[1]) - predictionMap[rid] = [splitPuncts(pred.strip().lower())] + for row in predictions: + cols = row.strip().split('\t') + (rid, pred) = (cols[0], '') if len(cols) == 1 else (cols[0], cols[1]) + predictionMap[rid] = [splitPuncts(pred.strip().lower())] - for row in gf: - (rid, pred) = row.split('\t') - if rid in predictionMap: # Only insert if the id exists for the method - if rid not in goldMap: - goldMap[rid] = [] - goldMap[rid].append(splitPuncts(pred.strip().lower())) + for row in gf: + (rid, pred) = row.split('\t') + if rid in predictionMap: # Only insert if the id exists for the method + if rid not in goldMap: + goldMap[rid] = [] + goldMap[rid].append(splitPuncts(pred.strip().lower())) - sys.stderr.write('Total: ' + str(len(goldMap)) + '\n') - return (goldMap, predictionMap) + sys.stderr.write(f'Total: {len(goldMap)}' + '\n') + return (goldMap, predictionMap) #m1 is the reference map #m2 is the prediction map def bleuFromMaps(m1, m2): - score = [0] * 5 - num = 0.0 + score = [0] * 5 + num = 0.0 - for key in m1: - if key in m2: - bl = bleu(m1[key], m2[key][0]) - score = [ score[i] + bl[i] for i in range(0, len(bl))] - num += 1 - return [s * 100.0 / num for s in score] + for key in m1: + if key in m2: + bl = bleu(m1[key], m2[key][0]) + score = [score[i] + bl[i] for i in range(len(bl))] + num += 1 + return [s * 100.0 / num for s in score] if __name__ == '__main__': - reference_file = sys.argv[1] - predictions = [] - for row in sys.stdin: - predictions.append(row) - (goldMap, predictionMap) = computeMaps(predictions, reference_file) - print (bleuFromMaps(goldMap, predictionMap)[0]) + reference_file = sys.argv[1] + predictions = list(sys.stdin) + (goldMap, predictionMap) = computeMaps(predictions, reference_file) + print (bleuFromMaps(goldMap, predictionMap)[0]) diff --git a/CodeBERT/code2nl/model.py b/CodeBERT/code2nl/model.py index a49e819..8f5a3b2 100644 --- a/CodeBERT/code2nl/model.py +++ b/CodeBERT/code2nl/model.py @@ -73,8 +73,8 @@ def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=N return outputs else: #Predict - preds=[] - zero=torch.cuda.LongTensor(1).fill_(0) + preds=[] + zero=torch.cuda.LongTensor(1).fill_(0) for i in range(source_ids.shape[0]): context=encoder_output[:,i:i+1] context_mask=source_mask[i:i+1,:] @@ -98,9 +98,8 @@ def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=N pred=beam.buildTargetTokens(hyp)[:self.beam_size] pred=[torch.cat([x.view(-1) for x in p]+[zero]*(self.max_length-len(p))).view(1,-1) for p in pred] preds.append(torch.cat(pred,0).unsqueeze(0)) - - preds=torch.cat(preds,0) - return preds + + return 
torch.cat(preds,0) @@ -124,8 +123,7 @@ def __init__(self, size,sos,eos): def getCurrentState(self): "Get the outputs for the current timestep." - batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) - return batch + return self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) def getCurrentOrigin(self): "Get the backpointers for the current timestep." @@ -184,11 +182,12 @@ def getFinal(self): self.finished.append((self.scores[0], len(self.nextYs) - 1, 0)) self.finished.sort(key=lambda a: -a[0]) if len(self.finished) != self.size: - unfinished=[] - for i in range(self.nextYs[-1].size(0)): - if self.nextYs[-1][i] != self._eos: - s = self.scores[i] - unfinished.append((s, len(self.nextYs) - 1, i)) + unfinished = [ + (self.scores[i], len(self.nextYs) - 1, i) + for i in range(self.nextYs[-1].size(0)) + if self.nextYs[-1][i] != self._eos + ] + unfinished.sort(key=lambda a: -a[0]) self.finished+=unfinished[:self.size-len(self.finished)] return self.finished[:self.size] diff --git a/CodeBERT/code2nl/run.py b/CodeBERT/code2nl/run.py index 4265a3b..79e582f 100644 --- a/CodeBERT/code2nl/run.py +++ b/CodeBERT/code2nl/run.py @@ -104,37 +104,36 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None): #source source_tokens = tokenizer.tokenize(example.source)[:args.max_source_length-2] source_tokens =[tokenizer.cls_token]+source_tokens+[tokenizer.sep_token] - source_ids = tokenizer.convert_tokens_to_ids(source_tokens) + source_ids = tokenizer.convert_tokens_to_ids(source_tokens) source_mask = [1] * (len(source_tokens)) padding_length = args.max_source_length - len(source_ids) source_ids+=[tokenizer.pad_token_id]*padding_length source_mask+=[0]*padding_length - + #target if stage=="test": target_tokens = tokenizer.tokenize("None") else: target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2] - target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token] + target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token] target_ids = tokenizer.convert_tokens_to_ids(target_tokens) target_mask = [1] *len(target_ids) padding_length = args.max_target_length - len(target_ids) target_ids+=[tokenizer.pad_token_id]*padding_length target_mask+=[0]*padding_length - - if example_index < 5: - if stage=='train': - logger.info("*** Example ***") - logger.info("idx: {}".format(example.idx)) - - logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens])) - logger.info("source_ids: {}".format(' '.join(map(str, source_ids)))) - logger.info("source_mask: {}".format(' '.join(map(str, source_mask)))) - - logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens])) - logger.info("target_ids: {}".format(' '.join(map(str, target_ids)))) - logger.info("target_mask: {}".format(' '.join(map(str, target_mask)))) - + + if example_index < 5 and stage == 'train': + logger.info("*** Example ***") + logger.info(f"idx: {example.idx}") + + logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens])) + logger.info(f"source_ids: {' '.join(map(str, source_ids))}") + logger.info(f"source_mask: {' '.join(map(str, source_mask))}") + + logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens])) + logger.info(f"target_ids: {' '.join(map(str, target_ids))}") + logger.info(f"target_mask: {' '.join(map(str, target_mask))}") + features.append( InputFeatures( example_index, diff --git a/CodeBERT/codesearch/mrr.py b/CodeBERT/codesearch/mrr.py index c5903a5..e8931bb 100644 --- 
a/CodeBERT/codesearch/mrr.py +++ b/CodeBERT/codesearch/mrr.py @@ -15,7 +15,7 @@ def main(): languages = ['ruby', 'go', 'php', 'python', 'java', 'javascript'] MRR_dict = {} for language in languages: - file_dir = './results/{}'.format(language) + file_dir = f'./results/{language}' ranks = [] num_batch = 0 for file in sorted(os.listdir(file_dir)): @@ -30,10 +30,10 @@ def main(): ranks.append(rank) mean_mrr = np.mean(1.0 / np.array(ranks)) - print("{} mrr: {}".format(language, mean_mrr)) + print(f"{language} mrr: {mean_mrr}") MRR_dict[language] = mean_mrr for key, val in MRR_dict.items(): - print("{} mrr: {}".format(key, val)) + print(f"{key} mrr: {val}") if __name__ == "__main__": diff --git a/CodeBERT/codesearch/process_data.py b/CodeBERT/codesearch/process_data.py index 8f06c6b..b232547 100644 --- a/CodeBERT/codesearch/process_data.py +++ b/CodeBERT/codesearch/process_data.py @@ -17,7 +17,7 @@ def format_str(string): def preprocess_test_data(language, test_batch_size=1000): - path = os.path.join(DATA_DIR, '{}_test_0.jsonl.gz'.format(language)) + path = os.path.join(DATA_DIR, f'{language}_test_0.jsonl.gz') print(path) with gzip.open(path, 'r') as pf: data = pf.readlines() @@ -35,7 +35,7 @@ def preprocess_test_data(language, test_batch_size=1000): if len(batch_data) < test_batch_size: break # the last batch is smaller than the others, exclude. examples = [] - for d_idx, d in enumerate(batch_data): + for d in batch_data: line_a = json.loads(str(d, encoding='utf-8')) doc_token = ' '.join(line_a['docstring_tokens']) for dd in batch_data: @@ -46,10 +46,10 @@ def preprocess_test_data(language, test_batch_size=1000): example = ''.join(example) examples.append(example) - data_path = os.path.join(DATA_DIR, 'test/{}'.format(language)) + data_path = os.path.join(DATA_DIR, f'test/{language}') if not os.path.exists(data_path): os.makedirs(data_path) - file_path = os.path.join(data_path, 'batch_{}.txt'.format(batch_idx)) + file_path = os.path.join(data_path, f'batch_{batch_idx}.txt') print(file_path) with open(file_path, 'w', encoding='utf-8') as f: f.writelines('\n'.join(examples)) diff --git a/CodeBERT/codesearch/run_classifier.py b/CodeBERT/codesearch/run_classifier.py index 4364bde..9cceaef 100644 --- a/CodeBERT/codesearch/run_classifier.py +++ b/CodeBERT/codesearch/run_classifier.py @@ -134,7 +134,7 @@ def train(args, train_dataset, model, tokenizer, optimizer): if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well results = evaluate(args, model, tokenizer, checkpoint=str(global_step)) for key, value in results.items(): - tb_writer.add_scalar('eval_{}'.format(key), value, global_step) + tb_writer.add_scalar(f'eval_{key}', value, global_step) logger.info('loss %s', str(tr_loss - logging_loss)) tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step) @@ -173,7 +173,7 @@ def train(args, train_dataset, model, tokenizer, optimizer): model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, 'training_{}.bin'.format(idx))) + torch.save(args, os.path.join(output_dir, f'training_{idx}.bin')) logger.info("Saving model checkpoint to %s", output_dir) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) @@ -216,7 +216,7 @@ def evaluate(args, model, tokenizer, checkpoint=None, 
prefix="", mode='dev'): eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) # Eval! - logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(f"***** Running evaluation {prefix} *****") logger.info(" Num examples = %d", len(eval_dataset)) logger.info(" Batch size = %d", args.eval_batch_size) eval_loss = 0.0 @@ -252,11 +252,11 @@ def evaluate(args, model, tokenizer, checkpoint=None, prefix="", mode='dev'): if args.output_mode == "classification": preds_label = np.argmax(preds, axis=1) result = compute_metrics(eval_task, preds_label, out_label_ids) - results.update(result) + results |= result if (mode == 'dev'): output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") with open(output_eval_file, "a+") as writer: - logger.info("***** Eval results {} *****".format(prefix)) + logger.info(f"***** Eval results {prefix} *****") writer.write('evaluate %s\n' % checkpoint) for key in sorted(result.keys()): logger.info(" %s = %s", key, str(result[key])) @@ -273,9 +273,14 @@ def evaluate(args, model, tokenizer, checkpoint=None, prefix="", mode='dev'): instance_rep = ''.join( [item.encode('ascii', 'ignore').decode('ascii') for item in instances[i]]) - writer.write(instance_rep + '' + ''.join([str(l) for l in logit]) + '\n') + writer.write( + f'{instance_rep}' + + ''.join([str(l) for l in logit]) + + '\n' + ) + for key in sorted(result.keys()): - print("%s = %s" % (key, str(result[key]))) + print(f"{key} = {str(result[key])}") return results @@ -290,12 +295,11 @@ def load_and_cache_examples(args, task, tokenizer, ttype='train'): file_name = args.dev_file.split('.')[0] elif ttype == 'test': file_name = args.test_file.split('.')[0] - cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_{}'.format( - ttype, - file_name, - list(filter(None, args.model_name_or_path.split('/'))).pop(), - str(args.max_seq_length), - str(task))) + cached_features_file = os.path.join( + args.data_dir, + f"cached_{ttype}_{file_name}_{list(filter(None, args.model_name_or_path.split('/'))).pop()}_{str(args.max_seq_length)}_{str(task)}", + ) + # if os.path.exists(cached_features_file): try: @@ -313,15 +317,20 @@ def load_and_cache_examples(args, task, tokenizer, ttype='train'): elif ttype == 'test': examples, instances = processor.get_test_examples(args.data_dir, args.test_file) - features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode, - cls_token_at_end=bool(args.model_type in ['xlnet']), - # xlnet has a cls token at the end - cls_token=tokenizer.cls_token, - sep_token=tokenizer.sep_token, - cls_token_segment_id=2 if args.model_type in ['xlnet'] else 1, - pad_on_left=bool(args.model_type in ['xlnet']), - # pad on the left for xlnet - pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0) + features = convert_examples_to_features( + examples, + label_list, + args.max_seq_length, + tokenizer, + output_mode, + cls_token_at_end=args.model_type in ['xlnet'], + cls_token=tokenizer.cls_token, + sep_token=tokenizer.sep_token, + cls_token_segment_id=2 if args.model_type in ['xlnet'] else 1, + pad_on_left=args.model_type in ['xlnet'], + pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, + ) + if args.local_rank in [-1, 0]: logger.info("Saving features into cached file %s", cached_features_file) torch.save(features, cached_features_file) @@ -333,10 +342,7 @@ def load_and_cache_examples(args, task, tokenizer, ttype='train'): all_label_ids = torch.tensor([f.label_id for f in 
features], dtype=torch.long) dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) - if (ttype == 'test'): - return dataset, instances - else: - return dataset + return (dataset, instances) if (ttype == 'test') else dataset def main(): diff --git a/CodeBERT/codesearch/utils.py b/CodeBERT/codesearch/utils.py index 41d3c60..0472f6f 100644 --- a/CodeBERT/codesearch/utils.py +++ b/CodeBERT/codesearch/utils.py @@ -79,7 +79,7 @@ def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" with open(input_file, "r", encoding='utf-8') as f: lines = [] - for line in f.readlines(): + for line in f: line = line.strip().split('') if len(line) != 5: continue @@ -92,19 +92,19 @@ class CodesearchProcessor(DataProcessor): def get_train_examples(self, data_dir, train_file): """See base class.""" - logger.info("LOOKING AT {}".format(os.path.join(data_dir, train_file))) + logger.info(f"LOOKING AT {os.path.join(data_dir, train_file)}") return self._create_examples( self._read_tsv(os.path.join(data_dir, train_file)), "train") def get_dev_examples(self, data_dir, dev_file): """See base class.""" - logger.info("LOOKING AT {}".format(os.path.join(data_dir, dev_file))) + logger.info(f"LOOKING AT {os.path.join(data_dir, dev_file)}") return self._create_examples( self._read_tsv(os.path.join(data_dir, dev_file)), "dev") def get_test_examples(self, data_dir, test_file): """See base class.""" - logger.info("LOOKING AT {}".format(os.path.join(data_dir, test_file))) + logger.info(f"LOOKING AT {os.path.join(data_dir, test_file)}") return self._create_examples( self._read_tsv(os.path.join(data_dir, test_file)), "test") @@ -116,19 +116,13 @@ def _create_examples(self, lines, set_type): """Creates examples for the training and dev sets.""" examples = [] for (i, line) in enumerate(lines): - guid = "%s-%s" % (set_type, i) + guid = f"{set_type}-{i}" text_a = line[3] text_b = line[4] - if (set_type == 'test'): - label = self.get_labels()[0] - else: - label = line[0] + label = self.get_labels()[0] if (set_type == 'test') else line[0] examples.append( InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) - if (set_type == 'test'): - return examples, lines - else: - return examples + return (examples, lines) if (set_type == 'test') else examples def convert_examples_to_features(examples, label_list, max_seq_length, @@ -161,10 +155,8 @@ def convert_examples_to_features(examples, label_list, max_seq_length, # length is less than the specified length. 
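An illustrative aside (not part of the patch): the truncation comments here refer to _truncate_seq_pair, whose body lies outside this hunk. In the standard BERT example code it trims the longer of the two token lists one token at a time until the pair fits the budget (max_seq_length minus the three special tokens), roughly:

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    # Pop from whichever sequence is currently longer, so truncation is
    # spread across both inputs instead of cutting only one of them.
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()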
# Account for [CLS], [SEP], [SEP] with "- 3" _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) - else: - # Account for [CLS] and [SEP] with "- 2" - if len(tokens_a) > max_seq_length - 2: - tokens_a = tokens_a[:(max_seq_length - 2)] + elif len(tokens_a) > max_seq_length - 2: + tokens_a = tokens_a[:(max_seq_length - 2)] # The convention in BERT is: # (a) For sequence pairs: @@ -228,12 +220,12 @@ def convert_examples_to_features(examples, label_list, max_seq_length, if ex_index < 5: logger.info("*** Example ***") - logger.info("guid: %s" % (example.guid)) + logger.info(f"guid: {example.guid}") logger.info("tokens: %s" % " ".join( [str(x) for x in tokens])) - logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) - logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) - logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) + logger.info(f'input_ids: {" ".join([str(x) for x in input_ids])}') + logger.info(f'input_mask: {" ".join([str(x) for x in input_mask])}') + logger.info(f'segment_ids: {" ".join([str(x) for x in segment_ids])}') logger.info("label: %s (id = %d)" % (example.label, label_id)) features.append( diff --git a/GraphCodeBERT/clonedetection/evaluator/evaluator.py b/GraphCodeBERT/clonedetection/evaluator/evaluator.py index db8417b..eddf07e 100644 --- a/GraphCodeBERT/clonedetection/evaluator/evaluator.py +++ b/GraphCodeBERT/clonedetection/evaluator/evaluator.py @@ -29,12 +29,11 @@ def calculate_scores(answers,predictions): y_trues,y_preds=[],[] for key in answers: if key not in predictions: - logging.error("Missing prediction for ({},{}) pair.".format(key[0],key[1])) + logging.error(f"Missing prediction for ({key[0]},{key[1]}) pair.") sys.exit() y_trues.append(answers[key]) y_preds.append(predictions[key]) - scores={} - scores['Recall']=recall_score(y_trues, y_preds, average='macro') + scores = {'Recall': recall_score(y_trues, y_preds, average='macro')} scores['Prediction']=precision_score(y_trues, y_preds, average='macro') scores['F1']=f1_score(y_trues, y_preds, average='macro') return scores diff --git a/GraphCodeBERT/clonedetection/model.py b/GraphCodeBERT/clonedetection/model.py index 01edb09..47fb887 100644 --- a/GraphCodeBERT/clonedetection/model.py +++ b/GraphCodeBERT/clonedetection/model.py @@ -49,7 +49,7 @@ def forward(self, inputs_ids_1,position_idx_1,attn_mask_1,inputs_ids_2,position_ nodes_to_token_mask=nodes_to_token_mask/(nodes_to_token_mask.sum(-1)+1e-10)[:,:,None] avg_embeddings=torch.einsum("abc,acd->abd",nodes_to_token_mask,inputs_embeddings) inputs_embeddings=inputs_embeddings*(~nodes_mask)[:,:,None]+avg_embeddings*nodes_mask[:,:,None] - + outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings,attention_mask=attn_mask,position_ids=position_idx)[0] logits=self.classifier(outputs) prob=F.softmax(logits) diff --git a/GraphCodeBERT/clonedetection/parser/DFG.py b/GraphCodeBERT/clonedetection/parser/DFG.py index 61e0179..70d393b 100644 --- a/GraphCodeBERT/clonedetection/parser/DFG.py +++ b/GraphCodeBERT/clonedetection/parser/DFG.py @@ -13,9 +13,8 @@ def DFG_python(root_node,index_to_code,states): if_statement=['if_statement'] for_statement=['for_statement'] while_statement=['while_statement'] - do_first_statement=['for_in_clause'] def_statement=['default_parameter'] - states=states.copy() + states=states.copy() if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': idx,code=index_to_code[(root_node.start_point,root_node.end_point)] if root_node.type==code: @@ -36,19 
+35,18 @@ def DFG_python(root_node,index_to_code,states): idx,code=index_to_code[index] DFG.append((code,idx,'comesFrom',[],[])) states[code]=[idx] - return sorted(DFG,key=lambda x:x[1]),states else: name_indexs=tree_to_variable_index(name,index_to_code) value_indexs=tree_to_variable_index(value,index_to_code) temp,states=DFG_python(value,index_to_code,states) - DFG+=temp + DFG+=temp for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) - states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + states[code1]=[idx1] + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in assignment: if root_node.type=='for_in_clause': right_nodes=[root_node.children[-1]] @@ -61,15 +59,15 @@ def DFG_python(root_node,index_to_code,states): if len(right_nodes)!=len(left_nodes): left_nodes=[root_node.child_by_field_name('left')] right_nodes=[root_node.child_by_field_name('right')] - if len(left_nodes)==0: + if not left_nodes: left_nodes=[root_node.child_by_field_name('left')] - if len(right_nodes)==0: + if not right_nodes: right_nodes=[root_node.child_by_field_name('right')] DFG=[] for node in right_nodes: temp,states=DFG_python(node,index_to_code,states) DFG+=temp - + for left_node,right_node in zip(left_nodes,right_nodes): left_tokens_index=tree_to_variable_index(left_node,index_to_code) right_tokens_index=tree_to_variable_index(right_node,index_to_code) @@ -79,7 +77,7 @@ def DFG_python(root_node,index_to_code,states): temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], [index_to_code[x][0] for x in right_tokens_index])) states[code1]=[idx1] - DFG+=temp + DFG+=temp return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in if_statement: DFG=[] @@ -113,15 +111,15 @@ def DFG_python(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),new_states elif root_node.type in for_statement: DFG=[] - for i in range(2): + for _ in range(2): right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=','] if len(right_nodes)!=len(left_nodes): left_nodes=[root_node.child_by_field_name('left')] right_nodes=[root_node.child_by_field_name('right')] - if len(left_nodes)==0: + if not left_nodes: left_nodes=[root_node.child_by_field_name('left')] - if len(right_nodes)==0: + if not right_nodes: right_nodes=[root_node.child_by_field_name('right')] for node in right_nodes: temp,states=DFG_python(node,index_to_code,states) @@ -135,10 +133,10 @@ def DFG_python(root_node,index_to_code,states): temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], [index_to_code[x][0] for x in right_tokens_index])) states[code1]=[idx1] - DFG+=temp + DFG+=temp if root_node.children[-1].type=="block": temp,states=DFG_python(root_node.children[-1],index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -150,10 +148,10 @@ def DFG_python(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in while_statement: DFG=[] - for i in range(2): + for _ in range(2): for child in root_node.children: temp,states=DFG_python(child,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -162,9 +160,10 @@ def DFG_python(root_node,index_to_code,states): 
dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states else: DFG=[] + do_first_statement=['for_in_clause'] for child in root_node.children: if child.type in do_first_statement: temp,states=DFG_python(child,index_to_code,states) @@ -173,7 +172,7 @@ def DFG_python(root_node,index_to_code,states): if child.type not in do_first_statement: temp,states=DFG_python(child,index_to_code,states) DFG+=temp - + return sorted(DFG,key=lambda x:x[1]),states @@ -185,7 +184,6 @@ def DFG_java(root_node,index_to_code,states): for_statement=['for_statement'] enhanced_for_statement=['enhanced_for_statement'] while_statement=['while_statement'] - do_first_statement=[] states=states.copy() if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': idx,code=index_to_code[(root_node.start_point,root_node.end_point)] @@ -207,19 +205,18 @@ def DFG_java(root_node,index_to_code,states): idx,code=index_to_code[index] DFG.append((code,idx,'comesFrom',[],[])) states[code]=[idx] - return sorted(DFG,key=lambda x:x[1]),states else: name_indexs=tree_to_variable_index(name,index_to_code) value_indexs=tree_to_variable_index(value,index_to_code) temp,states=DFG_java(value,index_to_code,states) - DFG+=temp + DFG+=temp for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) - states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + states[code1]=[idx1] + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in assignment: left_nodes=root_node.child_by_field_name('left') right_nodes=root_node.child_by_field_name('right') @@ -244,7 +241,7 @@ def DFG_java(root_node,index_to_code,states): idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in if_statement: DFG=[] current_states=states.copy() @@ -303,19 +300,19 @@ def DFG_java(root_node,index_to_code,states): value=root_node.child_by_field_name('value') body=root_node.child_by_field_name('body') DFG=[] - for i in range(2): + for _ in range(2): temp,states=DFG_java(value,index_to_code,states) - DFG+=temp + DFG+=temp name_indexs=tree_to_variable_index(name,index_to_code) - value_indexs=tree_to_variable_index(value,index_to_code) + value_indexs=tree_to_variable_index(value,index_to_code) for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) - states[code1]=[idx1] + states[code1]=[idx1] temp,states=DFG_java(body,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -327,10 +324,10 @@ def DFG_java(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in while_statement: DFG=[] - for i in range(2): + for _ in range(2): for child in root_node.children: temp,states=DFG_java(child,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -339,9 +336,10 @@ def DFG_java(root_node,index_to_code,states): 
dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states else: DFG=[] + do_first_statement=[] for child in root_node.children: if child.type in do_first_statement: temp,states=DFG_java(child,index_to_code,states) @@ -350,7 +348,7 @@ def DFG_java(root_node,index_to_code,states): if child.type not in do_first_statement: temp,states=DFG_java(child,index_to_code,states) DFG+=temp - + return sorted(DFG,key=lambda x:x[1]),states def DFG_csharp(root_node,index_to_code,states): diff --git a/GraphCodeBERT/clonedetection/parser/utils.py b/GraphCodeBERT/clonedetection/parser/utils.py index 270fba2..61888ea 100644 --- a/GraphCodeBERT/clonedetection/parser/utils.py +++ b/GraphCodeBERT/clonedetection/parser/utils.py @@ -24,13 +24,12 @@ def remove_comments_and_docstrings(source,lang): # Remove comments: if token_type == tokenize.COMMENT: pass - # This series of conditionals removes docstrings: elif token_type == tokenize.STRING: - if prev_toktype != tokenize.INDENT: - # This is likely a docstring; double-check we're not inside an operator: - if prev_toktype != tokenize.NEWLINE: - if start_col > 0: - out += token_string + if ( + prev_toktype not in [tokenize.INDENT, tokenize.NEWLINE] + and start_col > 0 + ): + out += token_string else: out += token_string prev_toktype = token_type @@ -46,10 +45,8 @@ def remove_comments_and_docstrings(source,lang): else: def replacer(match): s = match.group(0) - if s.startswith('/'): - return " " # note: a space and not an empty string - else: - return s + return " " if s.startswith('/') else s + pattern = re.compile( r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE @@ -63,11 +60,10 @@ def replacer(match): def tree_to_token_index(root_node): if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': return [(root_node.start_point,root_node.end_point)] - else: - code_tokens=[] - for child in root_node.children: - code_tokens+=tree_to_token_index(child) - return code_tokens + code_tokens=[] + for child in root_node.children: + code_tokens+=tree_to_token_index(child) + return code_tokens def tree_to_variable_index(root_node,index_to_code): if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': diff --git a/GraphCodeBERT/clonedetection/run.py b/GraphCodeBERT/clonedetection/run.py index ae90b4b..2ca37de 100644 --- a/GraphCodeBERT/clonedetection/run.py +++ b/GraphCodeBERT/clonedetection/run.py @@ -74,19 +74,21 @@ def extract_dataflow(code, parser,lang): try: code=remove_comments_and_docstrings(code,lang) except: - pass + pass #obtain dataflow if lang=="php": - code="" + code = f"" try: - tree = parser[0].parse(bytes(code,'utf8')) - root_node = tree.root_node - tokens_index=tree_to_token_index(root_node) + tree = parser[0].parse(bytes(code,'utf8')) + root_node = tree.root_node + tokens_index=tree_to_token_index(root_node) code=code.split('\n') - code_tokens=[index_to_code_token(x,code) for x in tokens_index] - index_to_code={} - for idx,(index,code) in enumerate(zip(tokens_index,code_tokens)): - index_to_code[index]=(idx,code) + code_tokens=[index_to_code_token(x,code) for x in tokens_index] + index_to_code = { + index: (idx, code) + for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)) + } + try: 
DFG,_=parser[1](root_node,index_to_code,{}) except: @@ -98,10 +100,7 @@ def extract_dataflow(code, parser,lang): indexs.add(d[1]) for x in d[-1]: indexs.add(x) - new_DFG=[] - for d in DFG: - if d[1] in indexs: - new_DFG.append(d) + new_DFG = [d for d in DFG if d[1] in indexs] dfg=new_DFG except: dfg=[] @@ -149,20 +148,25 @@ def convert_examples_to_features(item): #source url1,url2,label,tokenizer, args,cache,url_to_code=item parser=parsers['java'] - + for url in [url1,url2]: if url not in cache: func=url_to_code[url] - + #extract data flow code_tokens,dfg=extract_dataflow(func,parser,'java') - code_tokens=[tokenizer.tokenize('@ '+x)[1:] if idx!=0 else tokenizer.tokenize(x) for idx,x in enumerate(code_tokens)] - ori2cur_pos={} - ori2cur_pos[-1]=(0,0) + code_tokens = [ + tokenizer.tokenize(f'@ {x}')[1:] + if idx != 0 + else tokenizer.tokenize(x) + for idx, x in enumerate(code_tokens) + ] + + ori2cur_pos = {-1: (0, 0)} for i in range(len(code_tokens)): - ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i])) + ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i])) code_tokens=[y for x in code_tokens for y in x] - + #truncating code_tokens=code_tokens[:args.code_length+args.data_flow_length-3-min(len(dfg),args.data_flow_length)][:512-3] source_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token] @@ -170,27 +174,25 @@ def convert_examples_to_features(item): position_idx = [i+tokenizer.pad_token_id + 1 for i in range(len(source_tokens))] dfg=dfg[:args.code_length+args.data_flow_length-len(source_tokens)] source_tokens+=[x[0] for x in dfg] - position_idx+=[0 for x in dfg] - source_ids+=[tokenizer.unk_token_id for x in dfg] + position_idx += [0 for _ in dfg] + source_ids += [tokenizer.unk_token_id for _ in dfg] padding_length=args.code_length+args.data_flow_length-len(source_ids) position_idx+=[tokenizer.pad_token_id]*padding_length source_ids+=[tokenizer.pad_token_id]*padding_length - + #reindex - reverse_index={} - for idx,x in enumerate(dfg): - reverse_index[x[1]]=idx + reverse_index = {x[1]: idx for idx, x in enumerate(dfg)} for idx,x in enumerate(dfg): - dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],) + dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],) dfg_to_dfg=[x[-1] for x in dfg] dfg_to_code=[ori2cur_pos[x[1]] for x in dfg] length=len([tokenizer.cls_token]) - dfg_to_code=[(x[0]+length,x[1]+length) for x in dfg_to_code] + dfg_to_code=[(x[0]+length,x[1]+length) for x in dfg_to_code] cache[url]=source_tokens,source_ids,position_idx,dfg_to_code,dfg_to_dfg - - source_tokens_1,source_ids_1,position_idx_1,dfg_to_code_1,dfg_to_dfg_1=cache[url1] - source_tokens_2,source_ids_2,position_idx_2,dfg_to_code_2,dfg_to_dfg_2=cache[url2] + + source_tokens_1,source_ids_1,position_idx_1,dfg_to_code_1,dfg_to_dfg_1=cache[url1] + source_tokens_2,source_ids_2,position_idx_2,dfg_to_code_2,dfg_to_dfg_2=cache[url2] return InputFeatures(source_tokens_1,source_ids_1,position_idx_1,dfg_to_code_1,dfg_to_dfg_1, source_tokens_2,source_ids_2,position_idx_2,dfg_to_code_2,dfg_to_dfg_2, label,url1,url2) @@ -200,7 +202,7 @@ def __init__(self, tokenizer, args, file_path='train'): self.examples = [] self.args=args index_filename=file_path - + #load index logger.info("Creating features from index file at %s ", index_filename) url_to_code={} @@ -209,7 +211,7 @@ def __init__(self, tokenizer, args, file_path='train'): line=line.strip() js=json.loads(line) url_to_code[js['idx']]=js['func'] - + #load code function according to 
index data=[] cache={} @@ -220,35 +222,32 @@ def __init__(self, tokenizer, args, file_path='train'): url1,url2,label=line.split('\t') if url1 not in url_to_code or url2 not in url_to_code: continue - if label=='0': - label=0 - else: - label=1 + label = 0 if label=='0' else 1 data.append((url1,url2,label,tokenizer, args,cache,url_to_code)) - + #only use 10% valid data to keep best model if 'valid' in file_path: data=random.sample(data,int(len(data)*0.1)) - + #convert example to input features self.examples=[convert_examples_to_features(x) for x in tqdm(data,total=len(data))] - + if 'train' in file_path: for idx, example in enumerate(self.examples[:3]): logger.info("*** Example ***") - logger.info("idx: {}".format(idx)) - logger.info("label: {}".format(example.label)) + logger.info(f"idx: {idx}") + logger.info(f"label: {example.label}") logger.info("input_tokens_1: {}".format([x.replace('\u0120','_') for x in example.input_tokens_1])) - logger.info("input_ids_1: {}".format(' '.join(map(str, example.input_ids_1)))) - logger.info("position_idx_1: {}".format(example.position_idx_1)) - logger.info("dfg_to_code_1: {}".format(' '.join(map(str, example.dfg_to_code_1)))) - logger.info("dfg_to_dfg_1: {}".format(' '.join(map(str, example.dfg_to_dfg_1)))) - + logger.info(f"input_ids_1: {' '.join(map(str, example.input_ids_1))}") + logger.info(f"position_idx_1: {example.position_idx_1}") + logger.info(f"dfg_to_code_1: {' '.join(map(str, example.dfg_to_code_1))}") + logger.info(f"dfg_to_dfg_1: {' '.join(map(str, example.dfg_to_dfg_1))}") + logger.info("input_tokens_2: {}".format([x.replace('\u0120','_') for x in example.input_tokens_2])) - logger.info("input_ids_2: {}".format(' '.join(map(str, example.input_ids_2)))) - logger.info("position_idx_2: {}".format(example.position_idx_2)) - logger.info("dfg_to_code_2: {}".format(' '.join(map(str, example.dfg_to_code_2)))) - logger.info("dfg_to_dfg_2: {}".format(' '.join(map(str, example.dfg_to_dfg_2)))) + logger.info(f"input_ids_2: {' '.join(map(str, example.input_ids_2))}") + logger.info(f"position_idx_2: {example.position_idx_2}") + logger.info(f"dfg_to_code_2: {' '.join(map(str, example.dfg_to_code_2))}") + logger.info(f"dfg_to_dfg_2: {' '.join(map(str, example.dfg_to_dfg_2))}") def __len__(self): @@ -259,8 +258,8 @@ def __getitem__(self, item): attn_mask_1= np.zeros((self.args.code_length+self.args.data_flow_length, self.args.code_length+self.args.data_flow_length),dtype=np.bool) #calculate begin index of node and max length of input - node_index=sum([i>1 for i in self.examples[item].position_idx_1]) - max_length=sum([i!=1 for i in self.examples[item].position_idx_1]) + node_index = sum(i>1 for i in self.examples[item].position_idx_1) + max_length = sum(i!=1 for i in self.examples[item].position_idx_1) #sequence can attend to sequence attn_mask_1[:node_index,:node_index]=True #special tokens attend to all tokens @@ -277,13 +276,13 @@ def __getitem__(self, item): for a in nodes: if a+node_index1 for i in self.examples[item].position_idx_2]) - max_length=sum([i!=1 for i in self.examples[item].position_idx_2]) + node_index = sum(i>1 for i in self.examples[item].position_idx_2) + max_length = sum(i!=1 for i in self.examples[item].position_idx_2) #sequence can attend to sequence attn_mask_2[:node_index,:node_index]=True #special tokens attend to all tokens @@ -300,7 +299,7 @@ def __getitem__(self, item): for a in nodes: if a+node_index 1: loss = loss.mean() - + if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps @@ -385,34 
+398,34 @@ def train(args, train_dataset, model, tokenizer): train_loss+=loss.item() if avg_loss==0: avg_loss=tr_loss - + avg_loss=round(train_loss/tr_num,5) - bar.set_description("epoch {} loss {}".format(idx,avg_loss)) - + bar.set_description(f"epoch {idx} loss {avg_loss}") + if (step + 1) % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() - scheduler.step() + scheduler.step() global_step += 1 output_flag=True avg_loss=round(np.exp((tr_loss - logging_loss) /(global_step- tr_nb)),4) if global_step % args.save_steps == 0: results = evaluate(args, model, tokenizer, eval_when_training=True) - + # Save model checkpoint if results['eval_f1']>best_f1: best_f1=results['eval_f1'] - logger.info(" "+"*"*20) + logger.info(" "+"*"*20) logger.info(" Best f1:%s",round(best_f1,4)) logger.info(" "+"*"*20) - + checkpoint_prefix = 'checkpoint-best-f1' - output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) + output_dir = os.path.join(args.output_dir, f'{checkpoint_prefix}') if not os.path.exists(output_dir): - os.makedirs(output_dir) + os.makedirs(output_dir) model_to_save = model.module if hasattr(model,'module') else model - output_dir = os.path.join(output_dir, '{}'.format('model.bin')) + output_dir = os.path.join(output_dir, 'model.bin') torch.save(model_to_save.state_dict(), output_dir) logger.info("Saving model checkpoint to %s", output_dir) @@ -430,11 +443,10 @@ def evaluate(args, model, tokenizer, eval_when_training=False): logger.info("***** Running evaluation *****") logger.info(" Num examples = %d", len(eval_dataset)) logger.info(" Batch size = %d", args.eval_batch_size) - + eval_loss = 0.0 - nb_eval_steps = 0 model.eval() - logits=[] + logits=[] y_trues=[] for batch in eval_dataloader: (inputs_ids_1,position_idx_1,attn_mask_1, @@ -445,8 +457,6 @@ def evaluate(args, model, tokenizer, eval_when_training=False): eval_loss += lm_loss.mean().item() logits.append(logit.cpu().numpy()) y_trues.append(labels.cpu().numpy()) - nb_eval_steps += 1 - #calculate scores logits=np.concatenate(logits,0) y_trues=np.concatenate(y_trues,0) @@ -457,15 +467,15 @@ def evaluate(args, model, tokenizer, eval_when_training=False): from sklearn.metrics import recall_score recall=recall_score(y_trues, y_preds, average='macro') from sklearn.metrics import precision_score - precision=precision_score(y_trues, y_preds, average='macro') + precision=precision_score(y_trues, y_preds, average='macro') from sklearn.metrics import f1_score - f1=f1_score(y_trues, y_preds, average='macro') + f1=f1_score(y_trues, y_preds, average='macro') result = { "eval_recall": float(recall), "eval_precision": float(precision), "eval_f1": float(f1), "eval_threshold":best_threshold, - + } logger.info("***** Eval results *****") @@ -489,9 +499,8 @@ def test(args, model, tokenizer, best_threshold=0): logger.info(" Num examples = %d", len(eval_dataset)) logger.info(" Batch size = %d", args.eval_batch_size) eval_loss = 0.0 - nb_eval_steps = 0 model.eval() - logits=[] + logits=[] y_trues=[] for batch in eval_dataloader: (inputs_ids_1,position_idx_1,attn_mask_1, @@ -502,8 +511,6 @@ def test(args, model, tokenizer, best_threshold=0): eval_loss += lm_loss.mean().item() logits.append(logit.cpu().numpy()) y_trues.append(labels.cpu().numpy()) - nb_eval_steps += 1 - #output result logits=np.concatenate(logits,0) y_preds=logits[:,1]>best_threshold @@ -528,7 +535,7 @@ def main(): help="An optional input evaluation data file to evaluate the perplexity on (a text file).") parser.add_argument("--test_data_file", 
default=None, type=str, help="An optional input evaluation data file to evaluate the perplexity on (a text file).") - + parser.add_argument("--model_name_or_path", default=None, type=str, help="The model checkpoint for weights initialization.") @@ -538,15 +545,15 @@ def main(): help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") parser.add_argument("--code_length", default=256, type=int, - help="Optional Code input sequence length after tokenization.") + help="Optional Code input sequence length after tokenization.") parser.add_argument("--data_flow_length", default=64, type=int, - help="Optional Data Flow input sequence length after tokenization.") + help="Optional Data Flow input sequence length after tokenization.") parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument("--do_test", action='store_true', - help="Whether to run eval on the dev set.") + help="Whether to run eval on the dev set.") parser.add_argument("--evaluate_during_training", action='store_true', help="Run evaluation during training at each logging step.") @@ -589,7 +596,10 @@ def main(): # Set seed set_seed(args) - config = RobertaConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) + config = RobertaConfig.from_pretrained( + args.config_name or args.model_name_or_path + ) + config.num_labels=1 tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) model = RobertaForSequenceClassification.from_pretrained(args.model_name_or_path,config=config) @@ -601,23 +611,21 @@ def main(): train_dataset = TextDataset(tokenizer, args, file_path=args.train_data_file) train(args, train_dataset, model, tokenizer) - # Evaluation - results = {} if args.do_eval: checkpoint_prefix = 'checkpoint-best-f1/model.bin' - output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) + output_dir = os.path.join(args.output_dir, f'{checkpoint_prefix}') model.load_state_dict(torch.load(output_dir)) model.to(args.device) result=evaluate(args, model, tokenizer) - + if args.do_test: checkpoint_prefix = 'checkpoint-best-f1/model.bin' - output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) + output_dir = os.path.join(args.output_dir, f'{checkpoint_prefix}') model.load_state_dict(torch.load(output_dir)) model.to(args.device) test(args, model, tokenizer,best_threshold=0.5) - return results + return {} if __name__ == "__main__": diff --git a/GraphCodeBERT/codesearch/model.py b/GraphCodeBERT/codesearch/model.py index ecf71ae..3de5c8c 100644 --- a/GraphCodeBERT/codesearch/model.py +++ b/GraphCodeBERT/codesearch/model.py @@ -8,17 +8,16 @@ def __init__(self, encoder): self.encoder = encoder def forward(self, code_inputs=None, attn_mask=None,position_idx=None, nl_inputs=None): - if code_inputs is not None: - nodes_mask=position_idx.eq(0) - token_mask=position_idx.ge(2) - inputs_embeddings=self.encoder.embeddings.word_embeddings(code_inputs) - nodes_to_token_mask=nodes_mask[:,:,None]&token_mask[:,None,:]&attn_mask - nodes_to_token_mask=nodes_to_token_mask/(nodes_to_token_mask.sum(-1)+1e-10)[:,:,None] - avg_embeddings=torch.einsum("abc,acd->abd",nodes_to_token_mask,inputs_embeddings) - inputs_embeddings=inputs_embeddings*(~nodes_mask)[:,:,None]+avg_embeddings*nodes_mask[:,:,None] - return self.encoder(inputs_embeds=inputs_embeddings,attention_mask=attn_mask,position_ids=position_idx)[1] - else: + if 
code_inputs is None: return self.encoder(nl_inputs,attention_mask=nl_inputs.ne(1))[1] + nodes_mask=position_idx.eq(0) + token_mask=position_idx.ge(2) + inputs_embeddings=self.encoder.embeddings.word_embeddings(code_inputs) + nodes_to_token_mask=nodes_mask[:,:,None]&token_mask[:,None,:]&attn_mask + nodes_to_token_mask=nodes_to_token_mask/(nodes_to_token_mask.sum(-1)+1e-10)[:,:,None] + avg_embeddings=torch.einsum("abc,acd->abd",nodes_to_token_mask,inputs_embeddings) + inputs_embeddings=inputs_embeddings*(~nodes_mask)[:,:,None]+avg_embeddings*nodes_mask[:,:,None] + return self.encoder(inputs_embeds=inputs_embeddings,attention_mask=attn_mask,position_ids=position_idx)[1] diff --git a/GraphCodeBERT/codesearch/parser/DFG.py b/GraphCodeBERT/codesearch/parser/DFG.py index 61e0179..70d393b 100644 --- a/GraphCodeBERT/codesearch/parser/DFG.py +++ b/GraphCodeBERT/codesearch/parser/DFG.py @@ -13,9 +13,8 @@ def DFG_python(root_node,index_to_code,states): if_statement=['if_statement'] for_statement=['for_statement'] while_statement=['while_statement'] - do_first_statement=['for_in_clause'] def_statement=['default_parameter'] - states=states.copy() + states=states.copy() if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': idx,code=index_to_code[(root_node.start_point,root_node.end_point)] if root_node.type==code: @@ -36,19 +35,18 @@ def DFG_python(root_node,index_to_code,states): idx,code=index_to_code[index] DFG.append((code,idx,'comesFrom',[],[])) states[code]=[idx] - return sorted(DFG,key=lambda x:x[1]),states else: name_indexs=tree_to_variable_index(name,index_to_code) value_indexs=tree_to_variable_index(value,index_to_code) temp,states=DFG_python(value,index_to_code,states) - DFG+=temp + DFG+=temp for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) - states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + states[code1]=[idx1] + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in assignment: if root_node.type=='for_in_clause': right_nodes=[root_node.children[-1]] @@ -61,15 +59,15 @@ def DFG_python(root_node,index_to_code,states): if len(right_nodes)!=len(left_nodes): left_nodes=[root_node.child_by_field_name('left')] right_nodes=[root_node.child_by_field_name('right')] - if len(left_nodes)==0: + if not left_nodes: left_nodes=[root_node.child_by_field_name('left')] - if len(right_nodes)==0: + if not right_nodes: right_nodes=[root_node.child_by_field_name('right')] DFG=[] for node in right_nodes: temp,states=DFG_python(node,index_to_code,states) DFG+=temp - + for left_node,right_node in zip(left_nodes,right_nodes): left_tokens_index=tree_to_variable_index(left_node,index_to_code) right_tokens_index=tree_to_variable_index(right_node,index_to_code) @@ -79,7 +77,7 @@ def DFG_python(root_node,index_to_code,states): temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], [index_to_code[x][0] for x in right_tokens_index])) states[code1]=[idx1] - DFG+=temp + DFG+=temp return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in if_statement: DFG=[] @@ -113,15 +111,15 @@ def DFG_python(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),new_states elif root_node.type in for_statement: DFG=[] - for i in range(2): + for _ in range(2): right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] left_nodes=[x for x in 
root_node.child_by_field_name('left').children if x.type!=','] if len(right_nodes)!=len(left_nodes): left_nodes=[root_node.child_by_field_name('left')] right_nodes=[root_node.child_by_field_name('right')] - if len(left_nodes)==0: + if not left_nodes: left_nodes=[root_node.child_by_field_name('left')] - if len(right_nodes)==0: + if not right_nodes: right_nodes=[root_node.child_by_field_name('right')] for node in right_nodes: temp,states=DFG_python(node,index_to_code,states) @@ -135,10 +133,10 @@ def DFG_python(root_node,index_to_code,states): temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], [index_to_code[x][0] for x in right_tokens_index])) states[code1]=[idx1] - DFG+=temp + DFG+=temp if root_node.children[-1].type=="block": temp,states=DFG_python(root_node.children[-1],index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -150,10 +148,10 @@ def DFG_python(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in while_statement: DFG=[] - for i in range(2): + for _ in range(2): for child in root_node.children: temp,states=DFG_python(child,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -162,9 +160,10 @@ def DFG_python(root_node,index_to_code,states): dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states else: DFG=[] + do_first_statement=['for_in_clause'] for child in root_node.children: if child.type in do_first_statement: temp,states=DFG_python(child,index_to_code,states) @@ -173,7 +172,7 @@ def DFG_python(root_node,index_to_code,states): if child.type not in do_first_statement: temp,states=DFG_python(child,index_to_code,states) DFG+=temp - + return sorted(DFG,key=lambda x:x[1]),states @@ -185,7 +184,6 @@ def DFG_java(root_node,index_to_code,states): for_statement=['for_statement'] enhanced_for_statement=['enhanced_for_statement'] while_statement=['while_statement'] - do_first_statement=[] states=states.copy() if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': idx,code=index_to_code[(root_node.start_point,root_node.end_point)] @@ -207,19 +205,18 @@ def DFG_java(root_node,index_to_code,states): idx,code=index_to_code[index] DFG.append((code,idx,'comesFrom',[],[])) states[code]=[idx] - return sorted(DFG,key=lambda x:x[1]),states else: name_indexs=tree_to_variable_index(name,index_to_code) value_indexs=tree_to_variable_index(value,index_to_code) temp,states=DFG_java(value,index_to_code,states) - DFG+=temp + DFG+=temp for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) - states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + states[code1]=[idx1] + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in assignment: left_nodes=root_node.child_by_field_name('left') right_nodes=root_node.child_by_field_name('right') @@ -244,7 +241,7 @@ def DFG_java(root_node,index_to_code,states): idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda 
x:x[1]),states elif root_node.type in if_statement: DFG=[] current_states=states.copy() @@ -303,19 +300,19 @@ def DFG_java(root_node,index_to_code,states): value=root_node.child_by_field_name('value') body=root_node.child_by_field_name('body') DFG=[] - for i in range(2): + for _ in range(2): temp,states=DFG_java(value,index_to_code,states) - DFG+=temp + DFG+=temp name_indexs=tree_to_variable_index(name,index_to_code) - value_indexs=tree_to_variable_index(value,index_to_code) + value_indexs=tree_to_variable_index(value,index_to_code) for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) - states[code1]=[idx1] + states[code1]=[idx1] temp,states=DFG_java(body,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -327,10 +324,10 @@ def DFG_java(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in while_statement: DFG=[] - for i in range(2): + for _ in range(2): for child in root_node.children: temp,states=DFG_java(child,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -339,9 +336,10 @@ def DFG_java(root_node,index_to_code,states): dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states else: DFG=[] + do_first_statement=[] for child in root_node.children: if child.type in do_first_statement: temp,states=DFG_java(child,index_to_code,states) @@ -350,7 +348,7 @@ def DFG_java(root_node,index_to_code,states): if child.type not in do_first_statement: temp,states=DFG_java(child,index_to_code,states) DFG+=temp - + return sorted(DFG,key=lambda x:x[1]),states def DFG_csharp(root_node,index_to_code,states): diff --git a/GraphCodeBERT/codesearch/parser/utils.py b/GraphCodeBERT/codesearch/parser/utils.py index 270fba2..61888ea 100644 --- a/GraphCodeBERT/codesearch/parser/utils.py +++ b/GraphCodeBERT/codesearch/parser/utils.py @@ -24,13 +24,12 @@ def remove_comments_and_docstrings(source,lang): # Remove comments: if token_type == tokenize.COMMENT: pass - # This series of conditionals removes docstrings: elif token_type == tokenize.STRING: - if prev_toktype != tokenize.INDENT: - # This is likely a docstring; double-check we're not inside an operator: - if prev_toktype != tokenize.NEWLINE: - if start_col > 0: - out += token_string + if ( + prev_toktype not in [tokenize.INDENT, tokenize.NEWLINE] + and start_col > 0 + ): + out += token_string else: out += token_string prev_toktype = token_type @@ -46,10 +45,8 @@ def remove_comments_and_docstrings(source,lang): else: def replacer(match): s = match.group(0) - if s.startswith('/'): - return " " # note: a space and not an empty string - else: - return s + return " " if s.startswith('/') else s + pattern = re.compile( r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE @@ -63,11 +60,10 @@ def replacer(match): def tree_to_token_index(root_node): if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': return [(root_node.start_point,root_node.end_point)] - else: - code_tokens=[] - for child in root_node.children: - code_tokens+=tree_to_token_index(child) - return code_tokens + 
code_tokens=[] + for child in root_node.children: + code_tokens+=tree_to_token_index(child) + return code_tokens def tree_to_variable_index(root_node,index_to_code): if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': diff --git a/GraphCodeBERT/codesearch/run.py b/GraphCodeBERT/codesearch/run.py index cca171a..ff26287 100644 --- a/GraphCodeBERT/codesearch/run.py +++ b/GraphCodeBERT/codesearch/run.py @@ -70,19 +70,21 @@ def extract_dataflow(code, parser,lang): try: code=remove_comments_and_docstrings(code,lang) except: - pass + pass #obtain dataflow if lang=="php": - code="" + code = f"" try: - tree = parser[0].parse(bytes(code,'utf8')) - root_node = tree.root_node - tokens_index=tree_to_token_index(root_node) + tree = parser[0].parse(bytes(code,'utf8')) + root_node = tree.root_node + tokens_index=tree_to_token_index(root_node) code=code.split('\n') - code_tokens=[index_to_code_token(x,code) for x in tokens_index] - index_to_code={} - for idx,(index,code) in enumerate(zip(tokens_index,code_tokens)): - index_to_code[index]=(idx,code) + code_tokens=[index_to_code_token(x,code) for x in tokens_index] + index_to_code = { + index: (idx, code) + for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)) + } + try: DFG,_=parser[1](root_node,index_to_code,{}) except: @@ -94,10 +96,7 @@ def extract_dataflow(code, parser,lang): indexs.add(d[1]) for x in d[-1]: indexs.add(x) - new_DFG=[] - for d in DFG: - if d[1] in indexs: - new_DFG.append(d) + new_DFG = [d for d in DFG if d[1] in indexs] dfg=new_DFG except: dfg=[] @@ -132,12 +131,15 @@ def convert_examples_to_features(item): parser=parsers[args.lang] #extract data flow code_tokens,dfg=extract_dataflow(js['original_string'],parser,args.lang) - code_tokens=[tokenizer.tokenize('@ '+x)[1:] if idx!=0 else tokenizer.tokenize(x) for idx,x in enumerate(code_tokens)] - ori2cur_pos={} - ori2cur_pos[-1]=(0,0) + code_tokens = [ + tokenizer.tokenize(f'@ {x}')[1:] if idx != 0 else tokenizer.tokenize(x) + for idx, x in enumerate(code_tokens) + ] + + ori2cur_pos = {-1: (0, 0)} for i in range(len(code_tokens)): - ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i])) - code_tokens=[y for x in code_tokens for y in x] + ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i])) + code_tokens=[y for x in code_tokens for y in x] #truncating code_tokens=code_tokens[:args.code_length+args.data_flow_length-2-min(len(dfg),args.data_flow_length)] code_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token] @@ -145,21 +147,19 @@ def convert_examples_to_features(item): position_idx = [i+tokenizer.pad_token_id + 1 for i in range(len(code_tokens))] dfg=dfg[:args.code_length+args.data_flow_length-len(code_tokens)] code_tokens+=[x[0] for x in dfg] - position_idx+=[0 for x in dfg] - code_ids+=[tokenizer.unk_token_id for x in dfg] + position_idx += [0 for _ in dfg] + code_ids += [tokenizer.unk_token_id for _ in dfg] padding_length=args.code_length+args.data_flow_length-len(code_ids) position_idx+=[tokenizer.pad_token_id]*padding_length - code_ids+=[tokenizer.pad_token_id]*padding_length + code_ids+=[tokenizer.pad_token_id]*padding_length #reindex - reverse_index={} - for idx,x in enumerate(dfg): - reverse_index[x[1]]=idx + reverse_index = {x[1]: idx for idx, x in enumerate(dfg)} for idx,x in enumerate(dfg): - dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],) + dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],) dfg_to_dfg=[x[-1] for x in 
dfg] dfg_to_code=[ori2cur_pos[x[1]] for x in dfg] length=len([tokenizer.cls_token]) - dfg_to_code=[(x[0]+length,x[1]+length) for x in dfg_to_code] + dfg_to_code=[(x[0]+length,x[1]+length) for x in dfg_to_code] #nl nl=' '.join(js['docstring_tokens']) nl_tokens=tokenizer.tokenize(nl)[:args.nl_length-2] @@ -167,14 +167,14 @@ def convert_examples_to_features(item): nl_ids = tokenizer.convert_tokens_to_ids(nl_tokens) padding_length = args.nl_length - len(nl_ids) nl_ids+=[tokenizer.pad_token_id]*padding_length - + return InputFeatures(code_tokens,code_ids,position_idx,dfg_to_code,dfg_to_dfg,nl_tokens,nl_ids,js['url']) class TextDataset(Dataset): def __init__(self, tokenizer, args, file_path=None,pool=None): self.args=args prefix=file_path.split('/')[-1][:-6] - cache_file=args.output_dir+'/'+prefix+'.pkl' + cache_file = f'{args.output_dir}/{prefix}.pkl' if os.path.exists(cache_file): self.examples=pickle.load(open(cache_file,'rb')) else: @@ -187,18 +187,18 @@ def __init__(self, tokenizer, args, file_path=None,pool=None): data.append((js,tokenizer,args)) self.examples=pool.map(convert_examples_to_features, tqdm(data,total=len(data))) pickle.dump(self.examples,open(cache_file,'wb')) - + if 'train' in file_path: for idx, example in enumerate(self.examples[:3]): logger.info("*** Example ***") - logger.info("idx: {}".format(idx)) + logger.info(f"idx: {idx}") logger.info("code_tokens: {}".format([x.replace('\u0120','_') for x in example.code_tokens])) - logger.info("code_ids: {}".format(' '.join(map(str, example.code_ids)))) - logger.info("position_idx: {}".format(example.position_idx)) - logger.info("dfg_to_code: {}".format(' '.join(map(str, example.dfg_to_code)))) - logger.info("dfg_to_dfg: {}".format(' '.join(map(str, example.dfg_to_dfg)))) + logger.info(f"code_ids: {' '.join(map(str, example.code_ids))}") + logger.info(f"position_idx: {example.position_idx}") + logger.info(f"dfg_to_code: {' '.join(map(str, example.dfg_to_code))}") + logger.info(f"dfg_to_dfg: {' '.join(map(str, example.dfg_to_dfg))}") logger.info("nl_tokens: {}".format([x.replace('\u0120','_') for x in example.nl_tokens])) - logger.info("nl_ids: {}".format(' '.join(map(str, example.nl_ids)))) + logger.info(f"nl_ids: {' '.join(map(str, example.nl_ids))}") def __len__(self): return len(self.examples) @@ -208,8 +208,8 @@ def __getitem__(self, item): attn_mask=np.zeros((self.args.code_length+self.args.data_flow_length, self.args.code_length+self.args.data_flow_length),dtype=np.bool) #calculate begin index of node and max length of input - node_index=sum([i>1 for i in self.examples[item].position_idx]) - max_length=sum([i!=1 for i in self.examples[item].position_idx]) + node_index = sum(i>1 for i in self.examples[item].position_idx) + max_length = sum(i!=1 for i in self.examples[item].position_idx) #sequence can attend to sequence attn_mask[:node_index,:node_index]=True #special tokens attend to all tokens @@ -226,7 +226,7 @@ def __getitem__(self, item): for a in nodes: if a+node_index 1: model = torch.nn.DataParallel(model) @@ -264,61 +264,61 @@ def train(args, model, tokenizer,pool): logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size//args.n_gpu) logger.info(" Total train batch size = %d", args.train_batch_size) logger.info(" Total optimization steps = %d", len(train_dataloader)*args.num_train_epochs) - + # model.resize_token_embeddings(len(tokenizer)) model.zero_grad() - + model.train() - tr_num,tr_loss,best_mrr=0,0,0 + tr_num,tr_loss,best_mrr=0,0,0 for idx in range(args.num_train_epochs): for step,batch in 
enumerate(train_dataloader): #get inputs - code_inputs = batch[0].to(args.device) + code_inputs = batch[0].to(args.device) attn_mask = batch[1].to(args.device) position_idx = batch[2].to(args.device) nl_inputs = batch[3].to(args.device) #get code and nl vectors code_vec = model(code_inputs=code_inputs,attn_mask=attn_mask,position_idx=position_idx) nl_vec = model(nl_inputs=nl_inputs) - + #calculate scores and loss scores=torch.einsum("ab,cb->ac",nl_vec,code_vec) loss_fct = CrossEntropyLoss() loss = loss_fct(scores, torch.arange(code_inputs.size(0), device=scores.device)) - + #report loss tr_loss += loss.item() tr_num+=1 if (step+1)% 100==0: - logger.info("epoch {} step {} loss {}".format(idx,step+1,round(tr_loss/tr_num,5))) + logger.info(f"epoch {idx} step {step + 1} loss {round(tr_loss / tr_num, 5)}") tr_loss=0 tr_num=0 - + #backward loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() optimizer.zero_grad() scheduler.step() - + #evaluate results = evaluate(args, model, tokenizer,args.eval_data_file, pool, eval_when_training=True) for key, value in results.items(): logger.info(" %s = %s", key, round(value,4)) - + #save best model if results['eval_mrr']>best_mrr: best_mrr=results['eval_mrr'] - logger.info(" "+"*"*20) + logger.info(" "+"*"*20) logger.info(" Best mrr:%s",round(best_mrr,4)) logger.info(" "+"*"*20) checkpoint_prefix = 'checkpoint-best-mrr' - output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) + output_dir = os.path.join(args.output_dir, f'{checkpoint_prefix}') if not os.path.exists(output_dir): - os.makedirs(output_dir) + os.makedirs(output_dir) model_to_save = model.module if hasattr(model,'module') else model - output_dir = os.path.join(output_dir, '{}'.format('model.bin')) + output_dir = os.path.join(output_dir, 'model.bin') torch.save(model_to_save.state_dict(), output_dir) logger.info("Saving model checkpoint to %s", output_dir) @@ -327,7 +327,7 @@ def evaluate(args, model, tokenizer,file_name,pool, eval_when_training=False): query_dataset = TextDataset(tokenizer, args, file_name, pool) query_sampler = SequentialSampler(query_dataset) query_dataloader = DataLoader(query_dataset, sampler=query_sampler, batch_size=args.eval_batch_size,num_workers=4) - + code_dataset = TextDataset(tokenizer, args, args.codebase_file, pool) code_sampler = SequentialSampler(code_dataset) code_dataloader = DataLoader(code_dataset, sampler=code_sampler, batch_size=args.eval_batch_size,num_workers=4) @@ -342,9 +342,9 @@ def evaluate(args, model, tokenizer,file_name,pool, eval_when_training=False): logger.info(" Num codes = %d", len(code_dataset)) logger.info(" Batch size = %d", args.eval_batch_size) - + model.eval() - code_vecs=[] + code_vecs=[] nl_vecs=[] for batch in query_dataloader: nl_inputs = batch[3].to(args.device) @@ -358,23 +358,17 @@ def evaluate(args, model, tokenizer,file_name,pool, eval_when_training=False): position_idx =batch[2].to(args.device) with torch.no_grad(): code_vec= model(code_inputs=code_inputs, attn_mask=attn_mask,position_idx=position_idx) - code_vecs.append(code_vec.cpu().numpy()) - model.train() + code_vecs.append(code_vec.cpu().numpy()) + model.train() code_vecs=np.concatenate(code_vecs,0) nl_vecs=np.concatenate(nl_vecs,0) scores=np.matmul(nl_vecs,code_vecs.T) - + sort_ids=np.argsort(scores, axis=-1, kind='quicksort', order=None)[:,::-1] - - nl_urls=[] - code_urls=[] - for example in query_dataset.examples: - nl_urls.append(example.url) - - for example in code_dataset.examples: - 
code_urls.append(example.url) - + + nl_urls = [example.url for example in query_dataset.examples] + code_urls = [example.url for example in code_dataset.examples] ranks=[] for url, sort_id in zip(nl_urls,sort_ids): rank=0 @@ -388,12 +382,8 @@ def evaluate(args, model, tokenizer,file_name,pool, eval_when_training=False): ranks.append(1/rank) else: ranks.append(0) - - result = { - "eval_mrr":float(np.mean(ranks)) - } - return result + return {"eval_mrr": float(np.mean(ranks))} @@ -411,31 +401,31 @@ def main(): help="An optional input test data file to test the MRR(a josnl file).") parser.add_argument("--codebase_file", default=None, type=str, help="An optional input test data file to codebase (a jsonl file).") - + parser.add_argument("--lang", default=None, type=str, help="language.") - + parser.add_argument("--model_name_or_path", default=None, type=str, help="The model checkpoint for weights initialization.") parser.add_argument("--config_name", default="", type=str, help="Optional pretrained config name or path if not the same as model_name_or_path") parser.add_argument("--tokenizer_name", default="", type=str, help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") - + parser.add_argument("--nl_length", default=128, type=int, - help="Optional NL input sequence length after tokenization.") + help="Optional NL input sequence length after tokenization.") parser.add_argument("--code_length", default=256, type=int, - help="Optional Code input sequence length after tokenization.") + help="Optional Code input sequence length after tokenization.") parser.add_argument("--data_flow_length", default=64, type=int, help="Optional Data Flow input sequence length after tokenization.") - + parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument("--do_test", action='store_true', help="Whether to run eval on the test set.") - + parser.add_argument("--train_batch_size", default=4, type=int, help="Batch size for training.") @@ -450,12 +440,12 @@ def main(): parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") - + pool = multiprocessing.Pool(cpu_cont) - + #print arguments args = parser.parse_args() - + #set log logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',level=logging.INFO ) @@ -464,45 +454,46 @@ def main(): args.n_gpu = torch.cuda.device_count() args.device = device logger.info("device: %s, n_gpu: %s",device, args.n_gpu) - + # Set seed set_seed(args.seed) #build model - config = RobertaConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) + config = RobertaConfig.from_pretrained( + args.config_name or args.model_name_or_path + ) + tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) - model = RobertaModel.from_pretrained(args.model_name_or_path) + model = RobertaModel.from_pretrained(args.model_name_or_path) model=Model(model) logger.info("Training/evaluation parameters %s", args) model.to(args.device) - + # Training if args.do_train: train(args, model, tokenizer, pool) - # Evaluation - results = {} if args.do_eval: checkpoint_prefix = 'checkpoint-best-mrr/model.bin' - output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) - model.load_state_dict(torch.load(output_dir),strict=False) + output_dir = os.path.join(args.output_dir, f'{checkpoint_prefix}') + 
model.load_state_dict(torch.load(output_dir),strict=False) model.to(args.device) result=evaluate(args, model, tokenizer,args.eval_data_file, pool) logger.info("***** Eval results *****") for key in sorted(result.keys()): logger.info(" %s = %s", key, str(round(result[key],4))) - + if args.do_test: checkpoint_prefix = 'checkpoint-best-mrr/model.bin' - output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) - model.load_state_dict(torch.load(output_dir),strict=False) + output_dir = os.path.join(args.output_dir, f'{checkpoint_prefix}') + model.load_state_dict(torch.load(output_dir),strict=False) model.to(args.device) result=evaluate(args, model, tokenizer,args.test_data_file, pool) logger.info("***** Eval results *****") for key in sorted(result.keys()): logger.info(" %s = %s", key, str(round(result[key],4))) - return results + return {} if __name__ == "__main__": diff --git a/GraphCodeBERT/refinement/bleu.py b/GraphCodeBERT/refinement/bleu.py index 47e1335..5b57803 100644 --- a/GraphCodeBERT/refinement/bleu.py +++ b/GraphCodeBERT/refinement/bleu.py @@ -39,7 +39,7 @@ def _get_ngrams(segment, max_order): """ ngram_counts = collections.Counter() for order in range(1, max_order + 1): - for i in range(0, len(segment) - order + 1): + for i in range(len(segment) - order + 1): ngram = tuple(segment[i:i+order]) ngram_counts[ngram] += 1 return ngram_counts @@ -83,16 +83,15 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4, possible_matches_by_order[order-1] += possible_matches precisions = [0] * max_order - for i in range(0, max_order): + for i in range(max_order): if smooth: precisions[i] = ((matches_by_order[i] + 1.) / (possible_matches_by_order[i] + 1.)) + elif possible_matches_by_order[i] > 0: + precisions[i] = (float(matches_by_order[i]) / + possible_matches_by_order[i]) else: - if possible_matches_by_order[i] > 0: - precisions[i] = (float(matches_by_order[i]) / - possible_matches_by_order[i]) - else: - precisions[i] = 0.0 + precisions[i] = 0.0 if min(precisions) > 0: p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) @@ -102,33 +101,26 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4, ratio = float(translation_length) / reference_length - if ratio > 1.0: - bp = 1. - else: - bp = math.exp(1 - 1. / ratio) - + bp = 1. if ratio > 1.0 else math.exp(1 - 1. 
/ ratio) bleu = geo_mean * bp return (bleu, precisions, bp, ratio, translation_length, reference_length) def _bleu(ref_file, trans_file, subword_option=None): - max_order = 4 - smooth = True - ref_files = [ref_file] - reference_text = [] - for reference_filename in ref_files: - with open(reference_filename) as fh: - reference_text.append(fh.readlines()) - per_segment_references = [] - for references in zip(*reference_text): - reference_list = [] - for reference in references: - reference_list.append(reference.strip().split()) - per_segment_references.append(reference_list) - translations = [] - with open(trans_file) as fh: - for line in fh: - translations.append(line.strip().split()) - bleu_score, _, _, _, _, _ = compute_bleu(per_segment_references, translations, max_order, smooth) - return round(100 * bleu_score,2) \ No newline at end of file + max_order = 4 + smooth = True + ref_files = [ref_file] + reference_text = [] + for reference_filename in ref_files: + with open(reference_filename) as fh: + reference_text.append(fh.readlines()) + per_segment_references = [] + for references in zip(*reference_text): + reference_list = [reference.strip().split() for reference in references] + per_segment_references.append(reference_list) + translations = [] + with open(trans_file) as fh: + translations.extend(line.strip().split() for line in fh) + bleu_score, _, _, _, _, _ = compute_bleu(per_segment_references, translations, max_order, smooth) + return round(100 * bleu_score,2) \ No newline at end of file diff --git a/GraphCodeBERT/refinement/model.py b/GraphCodeBERT/refinement/model.py index 433c54f..de3028f 100644 --- a/GraphCodeBERT/refinement/model.py +++ b/GraphCodeBERT/refinement/model.py @@ -54,13 +54,13 @@ def tie_weights(self): def forward(self, source_ids,source_mask,position_idx,attn_mask,target_ids=None,target_mask=None,args=None): #embedding nodes_mask=position_idx.eq(0) - token_mask=position_idx.ge(2) + token_mask=position_idx.ge(2) inputs_embeddings=self.encoder.embeddings.word_embeddings(source_ids) nodes_to_token_mask=nodes_mask[:,:,None]&token_mask[:,None,:]&attn_mask nodes_to_token_mask=nodes_to_token_mask/(nodes_to_token_mask.sum(-1)+1e-10)[:,:,None] avg_embeddings=torch.einsum("abc,acd->abd",nodes_to_token_mask,inputs_embeddings) inputs_embeddings=inputs_embeddings*(~nodes_mask)[:,:,None]+avg_embeddings*nodes_mask[:,:,None] - + outputs = self.encoder(inputs_embeds=inputs_embeddings,attention_mask=attn_mask,position_ids=position_idx) encoder_output = outputs[0].permute([1,0,2]).contiguous() #source_mask=token_mask.float() @@ -83,8 +83,8 @@ def forward(self, source_ids,source_mask,position_idx,attn_mask,target_ids=None, return outputs else: #Predict - preds=[] - zero=torch.cuda.LongTensor(1).fill_(0) + preds=[] + zero=torch.cuda.LongTensor(1).fill_(0) for i in range(source_ids.shape[0]): context=encoder_output[:,i:i+1] context_mask=source_mask[i:i+1,:] @@ -108,9 +108,8 @@ def forward(self, source_ids,source_mask,position_idx,attn_mask,target_ids=None, pred=beam.buildTargetTokens(hyp)[:self.beam_size] pred=[torch.cat([x.view(-1) for x in p]+[zero]*(self.max_length-len(p))).view(1,-1) for p in pred] preds.append(torch.cat(pred,0).unsqueeze(0)) - - preds=torch.cat(preds,0) - return preds + + return torch.cat(preds,0) @@ -134,8 +133,7 @@ def __init__(self, size,sos,eos): def getCurrentState(self): "Get the outputs for the current timestep." 
- batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) - return batch + return self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) def getCurrentOrigin(self): "Get the backpointers for the current timestep." @@ -194,11 +192,12 @@ def getFinal(self): self.finished.append((self.scores[0], len(self.nextYs) - 1, 0)) self.finished.sort(key=lambda a: -a[0]) if len(self.finished) != self.size: - unfinished=[] - for i in range(self.nextYs[-1].size(0)): - if self.nextYs[-1][i] != self._eos: - s = self.scores[i] - unfinished.append((s, len(self.nextYs) - 1, i)) + unfinished = [ + (self.scores[i], len(self.nextYs) - 1, i) + for i in range(self.nextYs[-1].size(0)) + if self.nextYs[-1][i] != self._eos + ] + unfinished.sort(key=lambda a: -a[0]) self.finished+=unfinished[:self.size-len(self.finished)] return self.finished[:self.size] diff --git a/GraphCodeBERT/refinement/parser/DFG.py b/GraphCodeBERT/refinement/parser/DFG.py index 61e0179..70d393b 100644 --- a/GraphCodeBERT/refinement/parser/DFG.py +++ b/GraphCodeBERT/refinement/parser/DFG.py @@ -13,9 +13,8 @@ def DFG_python(root_node,index_to_code,states): if_statement=['if_statement'] for_statement=['for_statement'] while_statement=['while_statement'] - do_first_statement=['for_in_clause'] def_statement=['default_parameter'] - states=states.copy() + states=states.copy() if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': idx,code=index_to_code[(root_node.start_point,root_node.end_point)] if root_node.type==code: @@ -36,19 +35,18 @@ def DFG_python(root_node,index_to_code,states): idx,code=index_to_code[index] DFG.append((code,idx,'comesFrom',[],[])) states[code]=[idx] - return sorted(DFG,key=lambda x:x[1]),states else: name_indexs=tree_to_variable_index(name,index_to_code) value_indexs=tree_to_variable_index(value,index_to_code) temp,states=DFG_python(value,index_to_code,states) - DFG+=temp + DFG+=temp for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) - states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + states[code1]=[idx1] + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in assignment: if root_node.type=='for_in_clause': right_nodes=[root_node.children[-1]] @@ -61,15 +59,15 @@ def DFG_python(root_node,index_to_code,states): if len(right_nodes)!=len(left_nodes): left_nodes=[root_node.child_by_field_name('left')] right_nodes=[root_node.child_by_field_name('right')] - if len(left_nodes)==0: + if not left_nodes: left_nodes=[root_node.child_by_field_name('left')] - if len(right_nodes)==0: + if not right_nodes: right_nodes=[root_node.child_by_field_name('right')] DFG=[] for node in right_nodes: temp,states=DFG_python(node,index_to_code,states) DFG+=temp - + for left_node,right_node in zip(left_nodes,right_nodes): left_tokens_index=tree_to_variable_index(left_node,index_to_code) right_tokens_index=tree_to_variable_index(right_node,index_to_code) @@ -79,7 +77,7 @@ def DFG_python(root_node,index_to_code,states): temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], [index_to_code[x][0] for x in right_tokens_index])) states[code1]=[idx1] - DFG+=temp + DFG+=temp return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in if_statement: DFG=[] @@ -113,15 +111,15 @@ def DFG_python(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),new_states elif root_node.type in for_statement: DFG=[] - for i in range(2): + for 
_ in range(2): right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=','] if len(right_nodes)!=len(left_nodes): left_nodes=[root_node.child_by_field_name('left')] right_nodes=[root_node.child_by_field_name('right')] - if len(left_nodes)==0: + if not left_nodes: left_nodes=[root_node.child_by_field_name('left')] - if len(right_nodes)==0: + if not right_nodes: right_nodes=[root_node.child_by_field_name('right')] for node in right_nodes: temp,states=DFG_python(node,index_to_code,states) @@ -135,10 +133,10 @@ def DFG_python(root_node,index_to_code,states): temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], [index_to_code[x][0] for x in right_tokens_index])) states[code1]=[idx1] - DFG+=temp + DFG+=temp if root_node.children[-1].type=="block": temp,states=DFG_python(root_node.children[-1],index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -150,10 +148,10 @@ def DFG_python(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in while_statement: DFG=[] - for i in range(2): + for _ in range(2): for child in root_node.children: temp,states=DFG_python(child,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -162,9 +160,10 @@ def DFG_python(root_node,index_to_code,states): dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states else: DFG=[] + do_first_statement=['for_in_clause'] for child in root_node.children: if child.type in do_first_statement: temp,states=DFG_python(child,index_to_code,states) @@ -173,7 +172,7 @@ def DFG_python(root_node,index_to_code,states): if child.type not in do_first_statement: temp,states=DFG_python(child,index_to_code,states) DFG+=temp - + return sorted(DFG,key=lambda x:x[1]),states @@ -185,7 +184,6 @@ def DFG_java(root_node,index_to_code,states): for_statement=['for_statement'] enhanced_for_statement=['enhanced_for_statement'] while_statement=['while_statement'] - do_first_statement=[] states=states.copy() if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': idx,code=index_to_code[(root_node.start_point,root_node.end_point)] @@ -207,19 +205,18 @@ def DFG_java(root_node,index_to_code,states): idx,code=index_to_code[index] DFG.append((code,idx,'comesFrom',[],[])) states[code]=[idx] - return sorted(DFG,key=lambda x:x[1]),states else: name_indexs=tree_to_variable_index(name,index_to_code) value_indexs=tree_to_variable_index(value,index_to_code) temp,states=DFG_java(value,index_to_code,states) - DFG+=temp + DFG+=temp for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) - states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + states[code1]=[idx1] + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in assignment: left_nodes=root_node.child_by_field_name('left') right_nodes=root_node.child_by_field_name('right') @@ -244,7 +241,7 @@ def DFG_java(root_node,index_to_code,states): idx2,code2=index_to_code[index2] 
DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in if_statement: DFG=[] current_states=states.copy() @@ -303,19 +300,19 @@ def DFG_java(root_node,index_to_code,states): value=root_node.child_by_field_name('value') body=root_node.child_by_field_name('body') DFG=[] - for i in range(2): + for _ in range(2): temp,states=DFG_java(value,index_to_code,states) - DFG+=temp + DFG+=temp name_indexs=tree_to_variable_index(name,index_to_code) - value_indexs=tree_to_variable_index(value,index_to_code) + value_indexs=tree_to_variable_index(value,index_to_code) for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) - states[code1]=[idx1] + states[code1]=[idx1] temp,states=DFG_java(body,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -327,10 +324,10 @@ def DFG_java(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in while_statement: DFG=[] - for i in range(2): + for _ in range(2): for child in root_node.children: temp,states=DFG_java(child,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -339,9 +336,10 @@ def DFG_java(root_node,index_to_code,states): dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states else: DFG=[] + do_first_statement=[] for child in root_node.children: if child.type in do_first_statement: temp,states=DFG_java(child,index_to_code,states) @@ -350,7 +348,7 @@ def DFG_java(root_node,index_to_code,states): if child.type not in do_first_statement: temp,states=DFG_java(child,index_to_code,states) DFG+=temp - + return sorted(DFG,key=lambda x:x[1]),states def DFG_csharp(root_node,index_to_code,states): diff --git a/GraphCodeBERT/refinement/parser/utils.py b/GraphCodeBERT/refinement/parser/utils.py index 270fba2..61888ea 100644 --- a/GraphCodeBERT/refinement/parser/utils.py +++ b/GraphCodeBERT/refinement/parser/utils.py @@ -24,13 +24,12 @@ def remove_comments_and_docstrings(source,lang): # Remove comments: if token_type == tokenize.COMMENT: pass - # This series of conditionals removes docstrings: elif token_type == tokenize.STRING: - if prev_toktype != tokenize.INDENT: - # This is likely a docstring; double-check we're not inside an operator: - if prev_toktype != tokenize.NEWLINE: - if start_col > 0: - out += token_string + if ( + prev_toktype not in [tokenize.INDENT, tokenize.NEWLINE] + and start_col > 0 + ): + out += token_string else: out += token_string prev_toktype = token_type @@ -46,10 +45,8 @@ def remove_comments_and_docstrings(source,lang): else: def replacer(match): s = match.group(0) - if s.startswith('/'): - return " " # note: a space and not an empty string - else: - return s + return " " if s.startswith('/') else s + pattern = re.compile( r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE @@ -63,11 +60,10 @@ def replacer(match): def tree_to_token_index(root_node): if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': return 
[(root_node.start_point,root_node.end_point)] - else: - code_tokens=[] - for child in root_node.children: - code_tokens+=tree_to_token_index(child) - return code_tokens + code_tokens=[] + for child in root_node.children: + code_tokens+=tree_to_token_index(child) + return code_tokens def tree_to_variable_index(root_node,index_to_code): if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': diff --git a/GraphCodeBERT/refinement/run.py b/GraphCodeBERT/refinement/run.py index e6135bb..8e163e9 100644 --- a/GraphCodeBERT/refinement/run.py +++ b/GraphCodeBERT/refinement/run.py @@ -75,19 +75,21 @@ def extract_dataflow(code, parser,lang): try: code=remove_comments_and_docstrings(code,lang) except: - pass + pass #obtain dataflow if lang=="php": - code="" + code = f"" try: - tree = parser[0].parse(bytes(code,'utf8')) - root_node = tree.root_node - tokens_index=tree_to_token_index(root_node) + tree = parser[0].parse(bytes(code,'utf8')) + root_node = tree.root_node + tokens_index=tree_to_token_index(root_node) code=code.split('\n') - code_tokens=[index_to_code_token(x,code) for x in tokens_index] - index_to_code={} - for idx,(index,code) in enumerate(zip(tokens_index,code_tokens)): - index_to_code[index]=(idx,code) + code_tokens=[index_to_code_token(x,code) for x in tokens_index] + index_to_code = { + index: (idx, code) + for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)) + } + try: DFG,_=parser[1](root_node,index_to_code,{}) except: @@ -99,10 +101,7 @@ def extract_dataflow(code, parser,lang): indexs.add(d[1]) for x in d[-1]: indexs.add(x) - new_DFG=[] - for d in DFG: - if d[1] in indexs: - new_DFG.append(d) + new_DFG = [d for d in DFG if d[1] in indexs] dfg=new_DFG except: dfg=[] @@ -164,13 +163,18 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None): for example_index, example in enumerate(tqdm(examples,total=len(examples))): ##extract data flow code_tokens,dfg=extract_dataflow(example.source,parsers['java'],'java') - code_tokens=[tokenizer.tokenize('@ '+x)[1:] if idx!=0 else tokenizer.tokenize(x) for idx,x in enumerate(code_tokens)] - ori2cur_pos={} - ori2cur_pos[-1]=(0,0) + code_tokens = [ + tokenizer.tokenize(f'@ {x}')[1:] + if idx != 0 + else tokenizer.tokenize(x) + for idx, x in enumerate(code_tokens) + ] + + ori2cur_pos = {-1: (0, 0)} for i in range(len(code_tokens)): - ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i])) + ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i])) code_tokens=[y for x in code_tokens for y in x] - + #truncating code_tokens=code_tokens[:args.max_source_length-3] source_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token] @@ -178,20 +182,18 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None): position_idx = [i+tokenizer.pad_token_id + 1 for i in range(len(source_tokens))] dfg=dfg[:args.max_source_length-len(source_tokens)] source_tokens+=[x[0] for x in dfg] - position_idx+=[0 for x in dfg] - source_ids+=[tokenizer.unk_token_id for x in dfg] + position_idx += [0 for _ in dfg] + source_ids += [tokenizer.unk_token_id for _ in dfg] padding_length=args.max_source_length-len(source_ids) position_idx+=[tokenizer.pad_token_id]*padding_length - source_ids+=[tokenizer.pad_token_id]*padding_length + source_ids+=[tokenizer.pad_token_id]*padding_length source_mask = [1] * (len(source_tokens)) source_mask+=[0]*padding_length - + #reindex - reverse_index={} - for idx,x in enumerate(dfg): - reverse_index[x[1]]=idx + reverse_index = 
{x[1]: idx for idx, x in enumerate(dfg)} for idx,x in enumerate(dfg): - dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],) + dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],) dfg_to_dfg=[x[-1] for x in dfg] dfg_to_code=[ori2cur_pos[x[1]] for x in dfg] length=len([tokenizer.cls_token]) @@ -202,27 +204,26 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None): target_tokens = tokenizer.tokenize("None") else: target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2] - target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token] + target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token] target_ids = tokenizer.convert_tokens_to_ids(target_tokens) target_mask = [1] *len(target_ids) padding_length = args.max_target_length - len(target_ids) target_ids+=[tokenizer.pad_token_id]*padding_length target_mask+=[0]*padding_length - - if example_index < 5: - if stage=='train': - logger.info("*** Example ***") - logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens])) - logger.info("source_ids: {}".format(' '.join(map(str, source_ids)))) - logger.info("source_mask: {}".format(' '.join(map(str, source_mask)))) - logger.info("position_idx: {}".format(position_idx)) - logger.info("dfg_to_code: {}".format(' '.join(map(str, dfg_to_code)))) - logger.info("dfg_to_dfg: {}".format(' '.join(map(str, dfg_to_dfg)))) - - logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens])) - logger.info("target_ids: {}".format(' '.join(map(str, target_ids)))) - logger.info("target_mask: {}".format(' '.join(map(str, target_mask)))) - + + if example_index < 5 and stage == 'train': + logger.info("*** Example ***") + logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens])) + logger.info(f"source_ids: {' '.join(map(str, source_ids))}") + logger.info(f"source_mask: {' '.join(map(str, source_mask))}") + logger.info(f"position_idx: {position_idx}") + logger.info(f"dfg_to_code: {' '.join(map(str, dfg_to_code))}") + logger.info(f"dfg_to_dfg: {' '.join(map(str, dfg_to_dfg))}") + + logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens])) + logger.info(f"target_ids: {' '.join(map(str, target_ids))}") + logger.info(f"target_mask: {' '.join(map(str, target_mask))}") + features.append( InputFeatures( example_index, @@ -249,8 +250,8 @@ def __getitem__(self, item): #calculate graph-guided masked function attn_mask=np.zeros((self.args.max_source_length,self.args.max_source_length),dtype=np.bool) #calculate begin index of node and max length of input - node_index=sum([i>1 for i in self.examples[item].position_idx]) - max_length=sum([i!=1 for i in self.examples[item].position_idx]) + node_index = sum(i>1 for i in self.examples[item].position_idx) + max_length = sum(i!=1 for i in self.examples[item].position_idx) #sequence can attend to sequence attn_mask[:node_index,:node_index]=True #special tokens attend to all tokens @@ -267,7 +268,7 @@ def __getitem__(self, item): for a in nodes: if a+node_index 0: + precisions[i] = (float(matches_by_order[i]) / + possible_matches_by_order[i]) else: - if possible_matches_by_order[i] > 0: - precisions[i] = (float(matches_by_order[i]) / - possible_matches_by_order[i]) - else: - precisions[i] = 0.0 + precisions[i] = 0.0 if min(precisions) > 0: p_log_sum = sum((1. 
/ max_order) * math.log(p) for p in precisions) @@ -102,33 +101,26 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4, ratio = float(translation_length) / reference_length - if ratio > 1.0: - bp = 1. - else: - bp = math.exp(1 - 1. / ratio) - + bp = 1. if ratio > 1.0 else math.exp(1 - 1. / ratio) bleu = geo_mean * bp return (bleu, precisions, bp, ratio, translation_length, reference_length) def _bleu(ref_file, trans_file, subword_option=None): - max_order = 4 - smooth = True - ref_files = [ref_file] - reference_text = [] - for reference_filename in ref_files: - with open(reference_filename) as fh: - reference_text.append(fh.readlines()) - per_segment_references = [] - for references in zip(*reference_text): - reference_list = [] - for reference in references: - reference_list.append(reference.strip().split()) - per_segment_references.append(reference_list) - translations = [] - with open(trans_file) as fh: - for line in fh: - translations.append(line.strip().split()) - bleu_score, _, _, _, _, _ = compute_bleu(per_segment_references, translations, max_order, smooth) - return round(100 * bleu_score,2) \ No newline at end of file + max_order = 4 + smooth = True + ref_files = [ref_file] + reference_text = [] + for reference_filename in ref_files: + with open(reference_filename) as fh: + reference_text.append(fh.readlines()) + per_segment_references = [] + for references in zip(*reference_text): + reference_list = [reference.strip().split() for reference in references] + per_segment_references.append(reference_list) + translations = [] + with open(trans_file) as fh: + translations.extend(line.strip().split() for line in fh) + bleu_score, _, _, _, _, _ = compute_bleu(per_segment_references, translations, max_order, smooth) + return round(100 * bleu_score,2) \ No newline at end of file diff --git a/GraphCodeBERT/translation/model.py b/GraphCodeBERT/translation/model.py index 433c54f..de3028f 100644 --- a/GraphCodeBERT/translation/model.py +++ b/GraphCodeBERT/translation/model.py @@ -54,13 +54,13 @@ def tie_weights(self): def forward(self, source_ids,source_mask,position_idx,attn_mask,target_ids=None,target_mask=None,args=None): #embedding nodes_mask=position_idx.eq(0) - token_mask=position_idx.ge(2) + token_mask=position_idx.ge(2) inputs_embeddings=self.encoder.embeddings.word_embeddings(source_ids) nodes_to_token_mask=nodes_mask[:,:,None]&token_mask[:,None,:]&attn_mask nodes_to_token_mask=nodes_to_token_mask/(nodes_to_token_mask.sum(-1)+1e-10)[:,:,None] avg_embeddings=torch.einsum("abc,acd->abd",nodes_to_token_mask,inputs_embeddings) inputs_embeddings=inputs_embeddings*(~nodes_mask)[:,:,None]+avg_embeddings*nodes_mask[:,:,None] - + outputs = self.encoder(inputs_embeds=inputs_embeddings,attention_mask=attn_mask,position_ids=position_idx) encoder_output = outputs[0].permute([1,0,2]).contiguous() #source_mask=token_mask.float() @@ -83,8 +83,8 @@ def forward(self, source_ids,source_mask,position_idx,attn_mask,target_ids=None, return outputs else: #Predict - preds=[] - zero=torch.cuda.LongTensor(1).fill_(0) + preds=[] + zero=torch.cuda.LongTensor(1).fill_(0) for i in range(source_ids.shape[0]): context=encoder_output[:,i:i+1] context_mask=source_mask[i:i+1,:] @@ -108,9 +108,8 @@ def forward(self, source_ids,source_mask,position_idx,attn_mask,target_ids=None, pred=beam.buildTargetTokens(hyp)[:self.beam_size] pred=[torch.cat([x.view(-1) for x in p]+[zero]*(self.max_length-len(p))).view(1,-1) for p in pred] preds.append(torch.cat(pred,0).unsqueeze(0)) - - preds=torch.cat(preds,0) - 
return preds + + return torch.cat(preds,0) @@ -134,8 +133,7 @@ def __init__(self, size,sos,eos): def getCurrentState(self): "Get the outputs for the current timestep." - batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) - return batch + return self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) def getCurrentOrigin(self): "Get the backpointers for the current timestep." @@ -194,11 +192,12 @@ def getFinal(self): self.finished.append((self.scores[0], len(self.nextYs) - 1, 0)) self.finished.sort(key=lambda a: -a[0]) if len(self.finished) != self.size: - unfinished=[] - for i in range(self.nextYs[-1].size(0)): - if self.nextYs[-1][i] != self._eos: - s = self.scores[i] - unfinished.append((s, len(self.nextYs) - 1, i)) + unfinished = [ + (self.scores[i], len(self.nextYs) - 1, i) + for i in range(self.nextYs[-1].size(0)) + if self.nextYs[-1][i] != self._eos + ] + unfinished.sort(key=lambda a: -a[0]) self.finished+=unfinished[:self.size-len(self.finished)] return self.finished[:self.size] diff --git a/GraphCodeBERT/translation/parser/DFG.py b/GraphCodeBERT/translation/parser/DFG.py index 61e0179..70d393b 100644 --- a/GraphCodeBERT/translation/parser/DFG.py +++ b/GraphCodeBERT/translation/parser/DFG.py @@ -13,9 +13,8 @@ def DFG_python(root_node,index_to_code,states): if_statement=['if_statement'] for_statement=['for_statement'] while_statement=['while_statement'] - do_first_statement=['for_in_clause'] def_statement=['default_parameter'] - states=states.copy() + states=states.copy() if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': idx,code=index_to_code[(root_node.start_point,root_node.end_point)] if root_node.type==code: @@ -36,19 +35,18 @@ def DFG_python(root_node,index_to_code,states): idx,code=index_to_code[index] DFG.append((code,idx,'comesFrom',[],[])) states[code]=[idx] - return sorted(DFG,key=lambda x:x[1]),states else: name_indexs=tree_to_variable_index(name,index_to_code) value_indexs=tree_to_variable_index(value,index_to_code) temp,states=DFG_python(value,index_to_code,states) - DFG+=temp + DFG+=temp for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) - states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + states[code1]=[idx1] + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in assignment: if root_node.type=='for_in_clause': right_nodes=[root_node.children[-1]] @@ -61,15 +59,15 @@ def DFG_python(root_node,index_to_code,states): if len(right_nodes)!=len(left_nodes): left_nodes=[root_node.child_by_field_name('left')] right_nodes=[root_node.child_by_field_name('right')] - if len(left_nodes)==0: + if not left_nodes: left_nodes=[root_node.child_by_field_name('left')] - if len(right_nodes)==0: + if not right_nodes: right_nodes=[root_node.child_by_field_name('right')] DFG=[] for node in right_nodes: temp,states=DFG_python(node,index_to_code,states) DFG+=temp - + for left_node,right_node in zip(left_nodes,right_nodes): left_tokens_index=tree_to_variable_index(left_node,index_to_code) right_tokens_index=tree_to_variable_index(right_node,index_to_code) @@ -79,7 +77,7 @@ def DFG_python(root_node,index_to_code,states): temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], [index_to_code[x][0] for x in right_tokens_index])) states[code1]=[idx1] - DFG+=temp + DFG+=temp return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in if_statement: DFG=[] @@ -113,15 +111,15 
@@ def DFG_python(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),new_states elif root_node.type in for_statement: DFG=[] - for i in range(2): + for _ in range(2): right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=','] if len(right_nodes)!=len(left_nodes): left_nodes=[root_node.child_by_field_name('left')] right_nodes=[root_node.child_by_field_name('right')] - if len(left_nodes)==0: + if not left_nodes: left_nodes=[root_node.child_by_field_name('left')] - if len(right_nodes)==0: + if not right_nodes: right_nodes=[root_node.child_by_field_name('right')] for node in right_nodes: temp,states=DFG_python(node,index_to_code,states) @@ -135,10 +133,10 @@ def DFG_python(root_node,index_to_code,states): temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], [index_to_code[x][0] for x in right_tokens_index])) states[code1]=[idx1] - DFG+=temp + DFG+=temp if root_node.children[-1].type=="block": temp,states=DFG_python(root_node.children[-1],index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -150,10 +148,10 @@ def DFG_python(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in while_statement: DFG=[] - for i in range(2): + for _ in range(2): for child in root_node.children: temp,states=DFG_python(child,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -162,9 +160,10 @@ def DFG_python(root_node,index_to_code,states): dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states else: DFG=[] + do_first_statement=['for_in_clause'] for child in root_node.children: if child.type in do_first_statement: temp,states=DFG_python(child,index_to_code,states) @@ -173,7 +172,7 @@ def DFG_python(root_node,index_to_code,states): if child.type not in do_first_statement: temp,states=DFG_python(child,index_to_code,states) DFG+=temp - + return sorted(DFG,key=lambda x:x[1]),states @@ -185,7 +184,6 @@ def DFG_java(root_node,index_to_code,states): for_statement=['for_statement'] enhanced_for_statement=['enhanced_for_statement'] while_statement=['while_statement'] - do_first_statement=[] states=states.copy() if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': idx,code=index_to_code[(root_node.start_point,root_node.end_point)] @@ -207,19 +205,18 @@ def DFG_java(root_node,index_to_code,states): idx,code=index_to_code[index] DFG.append((code,idx,'comesFrom',[],[])) states[code]=[idx] - return sorted(DFG,key=lambda x:x[1]),states else: name_indexs=tree_to_variable_index(name,index_to_code) value_indexs=tree_to_variable_index(value,index_to_code) temp,states=DFG_java(value,index_to_code,states) - DFG+=temp + DFG+=temp for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) - states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + states[code1]=[idx1] + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in assignment: left_nodes=root_node.child_by_field_name('left') 
right_nodes=root_node.child_by_field_name('right') @@ -244,7 +241,7 @@ def DFG_java(root_node,index_to_code,states): idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) states[code1]=[idx1] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in if_statement: DFG=[] current_states=states.copy() @@ -303,19 +300,19 @@ def DFG_java(root_node,index_to_code,states): value=root_node.child_by_field_name('value') body=root_node.child_by_field_name('body') DFG=[] - for i in range(2): + for _ in range(2): temp,states=DFG_java(value,index_to_code,states) - DFG+=temp + DFG+=temp name_indexs=tree_to_variable_index(name,index_to_code) - value_indexs=tree_to_variable_index(value,index_to_code) + value_indexs=tree_to_variable_index(value,index_to_code) for index1 in name_indexs: idx1,code1=index_to_code[index1] for index2 in value_indexs: idx2,code2=index_to_code[index2] DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) - states[code1]=[idx1] + states[code1]=[idx1] temp,states=DFG_java(body,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -327,10 +324,10 @@ def DFG_java(root_node,index_to_code,states): return sorted(DFG,key=lambda x:x[1]),states elif root_node.type in while_statement: DFG=[] - for i in range(2): + for _ in range(2): for child in root_node.children: temp,states=DFG_java(child,index_to_code,states) - DFG+=temp + DFG+=temp dic={} for x in DFG: if (x[0],x[1],x[2]) not in dic: @@ -339,9 +336,10 @@ def DFG_java(root_node,index_to_code,states): dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] - return sorted(DFG,key=lambda x:x[1]),states + return sorted(DFG,key=lambda x:x[1]),states else: DFG=[] + do_first_statement=[] for child in root_node.children: if child.type in do_first_statement: temp,states=DFG_java(child,index_to_code,states) @@ -350,7 +348,7 @@ def DFG_java(root_node,index_to_code,states): if child.type not in do_first_statement: temp,states=DFG_java(child,index_to_code,states) DFG+=temp - + return sorted(DFG,key=lambda x:x[1]),states def DFG_csharp(root_node,index_to_code,states): diff --git a/GraphCodeBERT/translation/parser/utils.py b/GraphCodeBERT/translation/parser/utils.py index 270fba2..61888ea 100644 --- a/GraphCodeBERT/translation/parser/utils.py +++ b/GraphCodeBERT/translation/parser/utils.py @@ -24,13 +24,12 @@ def remove_comments_and_docstrings(source,lang): # Remove comments: if token_type == tokenize.COMMENT: pass - # This series of conditionals removes docstrings: elif token_type == tokenize.STRING: - if prev_toktype != tokenize.INDENT: - # This is likely a docstring; double-check we're not inside an operator: - if prev_toktype != tokenize.NEWLINE: - if start_col > 0: - out += token_string + if ( + prev_toktype not in [tokenize.INDENT, tokenize.NEWLINE] + and start_col > 0 + ): + out += token_string else: out += token_string prev_toktype = token_type @@ -46,10 +45,8 @@ def remove_comments_and_docstrings(source,lang): else: def replacer(match): s = match.group(0) - if s.startswith('/'): - return " " # note: a space and not an empty string - else: - return s + return " " if s.startswith('/') else s + pattern = re.compile( r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE @@ -63,11 +60,10 @@ def replacer(match): def 
tree_to_token_index(root_node): if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': return [(root_node.start_point,root_node.end_point)] - else: - code_tokens=[] - for child in root_node.children: - code_tokens+=tree_to_token_index(child) - return code_tokens + code_tokens=[] + for child in root_node.children: + code_tokens+=tree_to_token_index(child) + return code_tokens def tree_to_variable_index(root_node,index_to_code): if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': diff --git a/GraphCodeBERT/translation/run.py b/GraphCodeBERT/translation/run.py index 48dd06d..9b5179b 100644 --- a/GraphCodeBERT/translation/run.py +++ b/GraphCodeBERT/translation/run.py @@ -76,19 +76,21 @@ def extract_dataflow(code, parser,lang): try: code=remove_comments_and_docstrings(code,lang) except: - pass + pass #obtain dataflow if lang=="php": - code="" + code = f"" try: - tree = parser[0].parse(bytes(code,'utf8')) - root_node = tree.root_node - tokens_index=tree_to_token_index(root_node) + tree = parser[0].parse(bytes(code,'utf8')) + root_node = tree.root_node + tokens_index=tree_to_token_index(root_node) code=code.split('\n') - code_tokens=[index_to_code_token(x,code) for x in tokens_index] - index_to_code={} - for idx,(index,code) in enumerate(zip(tokens_index,code_tokens)): - index_to_code[index]=(idx,code) + code_tokens=[index_to_code_token(x,code) for x in tokens_index] + index_to_code = { + index: (idx, code) + for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)) + } + try: DFG,_=parser[1](root_node,index_to_code,{}) except: @@ -100,10 +102,7 @@ def extract_dataflow(code, parser,lang): indexs.add(d[1]) for x in d[-1]: indexs.add(x) - new_DFG=[] - for d in DFG: - if d[1] in indexs: - new_DFG.append(d) + new_DFG = [d for d in DFG if d[1] in indexs] dfg=new_DFG except: dfg=[] @@ -125,10 +124,7 @@ def read_examples(filename): """Read examples from filename.""" examples=[] source,target=filename.split(',') - lang='java' - if source[-1]=='s': - lang='c_sharp' - + lang = 'c_sharp' if source[-1]=='s' else 'java' with open(source,encoding="utf-8") as f1,open(target,encoding="utf-8") as f2: for line1,line2 in zip(f1,f2): line1=line1.strip() @@ -180,13 +176,18 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None): for example_index, example in enumerate(tqdm(examples,total=len(examples))): ##extract data flow code_tokens,dfg=extract_dataflow(example.source,parsers['java'],'java') - code_tokens=[tokenizer.tokenize('@ '+x)[1:] if idx!=0 else tokenizer.tokenize(x) for idx,x in enumerate(code_tokens)] - ori2cur_pos={} - ori2cur_pos[-1]=(0,0) + code_tokens = [ + tokenizer.tokenize(f'@ {x}')[1:] + if idx != 0 + else tokenizer.tokenize(x) + for idx, x in enumerate(code_tokens) + ] + + ori2cur_pos = {-1: (0, 0)} for i in range(len(code_tokens)): - ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i])) + ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i])) code_tokens=[y for x in code_tokens for y in x] - + #truncating code_tokens=code_tokens[:args.max_source_length-3][:512-3] source_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token] @@ -194,52 +195,49 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None): position_idx = [i+tokenizer.pad_token_id + 1 for i in range(len(source_tokens))] dfg=dfg[:args.max_source_length-len(source_tokens)] source_tokens+=[x[0] for x in dfg] - position_idx+=[0 for x in dfg] - 
source_ids+=[tokenizer.unk_token_id for x in dfg] + position_idx += [0 for _ in dfg] + source_ids += [tokenizer.unk_token_id for _ in dfg] padding_length=args.max_source_length-len(source_ids) position_idx+=[tokenizer.pad_token_id]*padding_length - source_ids+=[tokenizer.pad_token_id]*padding_length + source_ids+=[tokenizer.pad_token_id]*padding_length source_mask = [1] * (len(source_tokens)) source_mask+=[0]*padding_length - + #reindex - reverse_index={} - for idx,x in enumerate(dfg): - reverse_index[x[1]]=idx + reverse_index = {x[1]: idx for idx, x in enumerate(dfg)} for idx,x in enumerate(dfg): - dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],) + dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],) dfg_to_dfg=[x[-1] for x in dfg] dfg_to_code=[ori2cur_pos[x[1]] for x in dfg] length=len([tokenizer.cls_token]) dfg_to_code=[(x[0]+length,x[1]+length) for x in dfg_to_code] - + #target if stage=="test": target_tokens = tokenizer.tokenize("None") else: target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2] - target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token] + target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token] target_ids = tokenizer.convert_tokens_to_ids(target_tokens) target_mask = [1] *len(target_ids) padding_length = args.max_target_length - len(target_ids) target_ids+=[tokenizer.pad_token_id]*padding_length target_mask+=[0]*padding_length - - if example_index < 5: - if stage=='train': - logger.info("*** Example ***") - logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens])) - logger.info("source_ids: {}".format(' '.join(map(str, source_ids)))) - logger.info("source_mask: {}".format(' '.join(map(str, source_mask)))) - logger.info("position_idx: {}".format(position_idx)) - logger.info("dfg_to_code: {}".format(' '.join(map(str, dfg_to_code)))) - logger.info("dfg_to_dfg: {}".format(' '.join(map(str, dfg_to_dfg)))) - - logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens])) - logger.info("target_ids: {}".format(' '.join(map(str, target_ids)))) - logger.info("target_mask: {}".format(' '.join(map(str, target_mask)))) - + + if example_index < 5 and stage == 'train': + logger.info("*** Example ***") + logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens])) + logger.info(f"source_ids: {' '.join(map(str, source_ids))}") + logger.info(f"source_mask: {' '.join(map(str, source_mask))}") + logger.info(f"position_idx: {position_idx}") + logger.info(f"dfg_to_code: {' '.join(map(str, dfg_to_code))}") + logger.info(f"dfg_to_dfg: {' '.join(map(str, dfg_to_dfg))}") + + logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens])) + logger.info(f"target_ids: {' '.join(map(str, target_ids))}") + logger.info(f"target_mask: {' '.join(map(str, target_mask))}") + features.append( InputFeatures( example_index, @@ -266,8 +264,8 @@ def __getitem__(self, item): #calculate graph-guided masked function attn_mask=np.zeros((self.args.max_source_length,self.args.max_source_length),dtype=np.bool) #calculate begin index of node and max length of input - node_index=sum([i>1 for i in self.examples[item].position_idx]) - max_length=sum([i!=1 for i in self.examples[item].position_idx]) + node_index = sum(i>1 for i in self.examples[item].position_idx) + max_length = sum(i!=1 for i in self.examples[item].position_idx) #sequence can attend to sequence attn_mask[:node_index,:node_index]=True 
#special tokens attend to all tokens @@ -284,7 +282,7 @@ def __getitem__(self, item): for a in nodes: if a+node_index