diff --git a/run_span.py b/run_span.py index 119a675..4970a1c 100644 --- a/run_span.py +++ b/run_span.py @@ -1107,7 +1107,7 @@ def predict(args, model, processor, tokenizer, prefix=""): MODEL_CLASSES = { "bert_span": (BertConfigSpanV2, BertSpanV2ForNer, BertTokenizer), - # "nezha_span": (BertConfigSpanV2, NeZhaSpanV2ForNer, BertTokenizer), + "nezha_span": (BertConfigSpanV2, NeZhaSpanV2ForNer, BertTokenizer), } def load_dataset(args, processor, tokenizer, data_type='train'):