diff --git a/run_span.py b/run_span.py index c0bc929..3f501c5 100644 --- a/run_span.py +++ b/run_span.py @@ -817,7 +817,7 @@ def _convert_example_to_feature(self, example): inputs["input_len"] = inputs["attention_mask"].sum(dim=1) # for special tokens input_len = inputs["input_len"].item() inputs["spans"], inputs["span_mask"] = self._encode_span( - input_len, input_len, sent_start, sent_end) # dynamic batch + input_len, input_len, sent_start + 1, sent_end + 1) # dynamic batch inputs["sent_start"] = torch.tensor([sent_start]) inputs["sent_end"] = torch.tensor([sent_end])