Skip to content

Commit bbb8e1c

Browse files
authored
Merge pull request #33364 from vespa-engine/bratseth/tensor-type-generalize
Correct argument order in tensor type assignableTo
2 parents b185476 + 8502685 commit bbb8e1c

File tree

5 files changed

+35
-24
lines changed

5 files changed

+35
-24
lines changed

document/src/main/java/com/yahoo/document/TensorDataType.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public boolean isValueCompatible(FieldValue value) {
4848
if (tensorType == null) return true; // any
4949
if ( ! TensorFieldValue.class.isAssignableFrom(value.getClass())) return false;
5050
TensorFieldValue tensorValue = (TensorFieldValue)value;
51-
return tensorType.isConvertibleTo(tensorValue.getDataType().getTensorType());
51+
return tensorValue.getDataType().getTensorType().isConvertibleTo(tensorType);
5252
}
5353

5454
/** Returns the type of the tensor this field can hold */

indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/OutputExpression.java

-5
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,6 @@ public DataType setOutputType(DataType outputType, VerificationContext context)
3838
return context.getFieldType(fieldName, this);
3939
}
4040

41-
@Override
42-
protected void doVerify(VerificationContext context) {
43-
context.tryOutputType(fieldName, context.getCurrentType(), this);
44-
}
45-
4641
@Override
4742
protected void doExecute(ExecutionContext context) {
4843
context.setFieldValue(fieldName, context.getCurrentValue(), this);

indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTestCase.java

+33
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import com.yahoo.tensor.Tensor;
1414
import com.yahoo.tensor.TensorType;
1515
import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext;
16+
import com.yahoo.vespa.indexinglanguage.expressions.ScriptExpression;
17+
import com.yahoo.vespa.indexinglanguage.expressions.StatementExpression;
1618
import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext;
1719
import com.yahoo.vespa.indexinglanguage.expressions.VerificationException;
1820
import org.junit.Ignore;
@@ -374,4 +376,35 @@ public void testIt() {
374376
expression.verify(new VerificationContext(adapter));
375377
}
376378

379+
@Test
380+
public void binarizeFromSeparateField() {
381+
/*
382+
field text_embeddings type tensor<float>(c{},x[768]) {
383+
indexing: attribute | summary
384+
attribute: paged
385+
}
386+
field text_embeddings_quant_binary type tensor<int8>(c{},x[96]) {
387+
indexing: input text_embeddings | binarize | pack_bits | attribute | index
388+
attribute {
389+
distance-metric: hamming
390+
}
391+
}
392+
*/
393+
394+
var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockMappedEmbedder("myDocument.my2DSparseTensor")));
395+
396+
SimpleTestAdapter adapter = new SimpleTestAdapter();
397+
TensorType textEmbeddingsType = TensorType.fromSpec("tensor<float>(c{},x[768])");
398+
adapter.createField(new Field("text_embeddings", new TensorDataType(textEmbeddingsType)));
399+
400+
TensorType textEmbeddingsQuantBinaryType = TensorType.fromSpec("tensor<int8>(c{},x[96])");
401+
adapter.createField(new Field("text_embeddings_quant_binary", new TensorDataType(textEmbeddingsQuantBinaryType)));
402+
403+
var text_embeddings_expression = (StatementExpression)tester.expressionFrom("input text_embeddings | summary text_embeddings | attribute text_embeddings");
404+
var text_embeddings_quant_binary_expression = (StatementExpression)tester.expressionFrom("input text_embeddings | binarize | pack_bits | attribute text_embeddings_quant_binary | index text_embeddings_quant_binary");
405+
406+
var script = new ScriptExpression(text_embeddings_expression, text_embeddings_quant_binary_expression);
407+
script.verify(adapter);
408+
}
409+
377410
}

indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/SelectInputTestCase.java

-17
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,6 @@ public void requireThatHashCodeAndEqualsAreImplemented() {
4949
assertEquals(exp.hashCode(), newSelectInput(foo, "bar", "baz").hashCode());
5050
}
5151

52-
@Test
53-
public void requireThatExpressionCanBeVerified() {
54-
SimpleTestAdapter adapter = new SimpleTestAdapter();
55-
adapter.createField(new Field("my_int", DataType.INT));
56-
adapter.createField(new Field("my_str", DataType.STRING));
57-
58-
Expression exp = newSelectInput(new AttributeExpression("my_int"), "my_int");
59-
assertVerify(adapter, null, exp);
60-
assertVerify(adapter, DataType.INT, exp);
61-
assertVerify(adapter, DataType.STRING, exp);
62-
63-
assertVerifyThrows(adapter, newSelectInput(new AttributeExpression("my_int"), "my_str"),
64-
"Invalid expression 'attribute my_int': Can not assign string to field 'my_int' which is int.");
65-
assertVerifyThrows(adapter, newSelectInput(new AttributeExpression("my_int"), "my_unknown"),
66-
"Invalid expression 'select_input { my_unknown: attribute my_int; }': Field 'my_unknown' not found.");
67-
}
68-
6952
@Test
7053
public void requireThatSelectedExpressionIsRun() {
7154
assertSelect(List.of("foo", "bar"), List.of("foo"), "foo");

vespajlib/src/main/java/com/yahoo/tensor/TensorType.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ public Optional<Long> sizeOfDimension(String dimension) {
224224

225225
/**
226226
* Returns whether this type can be assigned to the given type,
227-
* i.e if the given type is a generalization of this type.
227+
* i.e. if the given type is a generalization of this type.
228228
*/
229229
public boolean isAssignableTo(TensorType generalization) {
230230
return isConvertibleOrAssignableTo(generalization, false, true);

0 commit comments

Comments
 (0)