Skip to content

Commit

Permalink
[ML] Fix for Deberta tokenizer when input sequence exceeds 512 tokens (
Browse files Browse the repository at this point in the history
…elastic#117595) (elastic#117601)

* Add test and fix

* Update docs/changelog/117595.yaml

* Remove test which wasn't working
  • Loading branch information
maxhniebergall authored Nov 27, 2024
1 parent 038c688 commit bdebe39
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/117595.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117595
summary: Fix for Deberta tokenizer when input sequence exceeds 512 tokens
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,29 @@ public List<TokenizationResult.Tokens> tokenize(String seq1, String seq2, Tokeni
tokenIdsSeq2 = tokenIdsSeq2.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq1.size());
tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq1.size());
}
case BALANCED -> {
isTruncated = true;
int firstSequenceLength = 0;

if (tokenIdsSeq2.size() > (maxSequenceLength() - getNumExtraTokensForSeqPair()) / 2) {
firstSequenceLength = min(tokenIdsSeq1.size(), (maxSequenceLength() - getNumExtraTokensForSeqPair()) / 2);
} else {
firstSequenceLength = min(
tokenIdsSeq1.size(),
maxSequenceLength() - tokenIdsSeq2.size() - getNumExtraTokensForSeqPair()
);
}
int secondSequenceLength = min(
tokenIdsSeq2.size(),
maxSequenceLength() - firstSequenceLength - getNumExtraTokensForSeqPair()
);

tokenIdsSeq1 = tokenIdsSeq1.subList(0, firstSequenceLength);
tokenPositionMapSeq1 = tokenPositionMapSeq1.subList(0, firstSequenceLength);

tokenIdsSeq2 = tokenIdsSeq2.subList(0, secondSequenceLength);
tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, secondSequenceLength);
}
case NONE -> throw ExceptionsHelper.badRequestException(
"Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
numTokens,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,22 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2Tokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_SCORES;
import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_VOCAB;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
Expand Down Expand Up @@ -62,6 +66,33 @@ public void testProcessor() throws IOException {
assertThat(result.predictedValue(), closeTo(42, 1e-6));
}

public void testBalancedTruncationWithLongInput() throws IOException {
String question = "Is Elasticsearch scalable?";
StringBuilder longInputBuilder = new StringBuilder();
for (int i = 0; i < 1000; i++) {
longInputBuilder.append(TEST_CASE_VOCAB.get(randomIntBetween(0, TEST_CASE_VOCAB.size() - 1))).append(i).append(" ");
}
String longInput = longInputBuilder.toString().trim();

DebertaV2Tokenization tokenization = new DebertaV2Tokenization(false, true, null, Tokenization.Truncate.BALANCED, -1);
DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder(TEST_CASE_VOCAB, TEST_CASE_SCORES, tokenization).build();
TextSimilarityConfig textSimilarityConfig = new TextSimilarityConfig(
question,
new VocabularyConfig(""),
tokenization,
"result",
TextSimilarityConfig.SpanScoreFunction.MAX
);
TextSimilarityProcessor processor = new TextSimilarityProcessor(tokenizer);
TokenizationResult tokenizationResult = processor.getRequestBuilder(textSimilarityConfig)
.buildRequest(List.of(longInput), "1", Tokenization.Truncate.BALANCED, -1, null)
.tokenization();

// Assert that the tokenization result is as expected
assertThat(tokenizationResult.anyTruncated(), is(true));
assertThat(tokenizationResult.getTokenization(0).tokenIds().length, equalTo(512));
}

public void testResultFunctions() {
BertTokenization tokenization = new BertTokenization(false, true, 384, Tokenization.Truncate.NONE, 128);
BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

public class DebertaV2TokenizerTests extends ESTestCase {

private static final List<String> TEST_CASE_VOCAB = List.of(
public static final List<String> TEST_CASE_VOCAB = List.of(
DebertaV2Tokenizer.CLASS_TOKEN,
DebertaV2Tokenizer.PAD_TOKEN,
DebertaV2Tokenizer.SEPARATOR_TOKEN,
Expand All @@ -48,7 +48,7 @@ public class DebertaV2TokenizerTests extends ESTestCase {
"<0xAD>",
"▁"
);
private static final List<Double> TEST_CASE_SCORES = List.of(
public static final List<Double> TEST_CASE_SCORES = List.of(
0.0,
0.0,
0.0,
Expand Down

0 comments on commit bdebe39

Please sign in to comment.