diff --git a/src/tasknet/utils.py b/src/tasknet/utils.py index f6a8214..060e708 100755 --- a/src/tasknet/utils.py +++ b/src/tasknet/utils.py @@ -214,6 +214,9 @@ def load_pipeline( ] += adapter.Z[task_index] pipe = TextClassificationPipeline( - model=model, tokenizer=tokenizer, device=device, return_all_scores=True + model=model, + tokenizer=tokenizer, + device=device, + return_all_scores=return_all_scores, ) return pipe