Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct argument order in tensor type assignableTo #33364

Merged
merged 2 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public boolean isValueCompatible(FieldValue value) {
if (tensorType == null) return true; // any
if ( ! TensorFieldValue.class.isAssignableFrom(value.getClass())) return false;
TensorFieldValue tensorValue = (TensorFieldValue)value;
return tensorType.isConvertibleTo(tensorValue.getDataType().getTensorType());
return tensorValue.getDataType().getTensorType().isConvertibleTo(tensorType);
}

/** Returns the type of the tensor this field can hold */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@ public DataType setOutputType(DataType outputType, VerificationContext context)
return context.getFieldType(fieldName, this);
}

@Override
protected void doVerify(VerificationContext context) {
context.tryOutputType(fieldName, context.getCurrentType(), this);
}

@Override
protected void doExecute(ExecutionContext context) {
context.setFieldValue(fieldName, context.getCurrentValue(), this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext;
import com.yahoo.vespa.indexinglanguage.expressions.ScriptExpression;
import com.yahoo.vespa.indexinglanguage.expressions.StatementExpression;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationException;
import org.junit.Ignore;
Expand Down Expand Up @@ -374,4 +376,35 @@ public void testIt() {
expression.verify(new VerificationContext(adapter));
}

@Test
public void binarizeFromSeparateField() {
/*
field text_embeddings type tensor<float>(c{},x[768]) {
indexing: attribute | summary
attribute: paged
}
field text_embeddings_quant_binary type tensor<int8>(c{},x[96]) {
indexing: input text_embeddings | binarize | pack_bits | attribute | index
attribute {
distance-metric: hamming
}
}
*/

var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockMappedEmbedder("myDocument.my2DSparseTensor")));

SimpleTestAdapter adapter = new SimpleTestAdapter();
TensorType textEmbeddingsType = TensorType.fromSpec("tensor<float>(c{},x[768])");
adapter.createField(new Field("text_embeddings", new TensorDataType(textEmbeddingsType)));

TensorType textEmbeddingsQuantBinaryType = TensorType.fromSpec("tensor<int8>(c{},x[96])");
adapter.createField(new Field("text_embeddings_quant_binary", new TensorDataType(textEmbeddingsQuantBinaryType)));

var text_embeddings_expression = (StatementExpression)tester.expressionFrom("input text_embeddings | summary text_embeddings | attribute text_embeddings");
var text_embeddings_quant_binary_expression = (StatementExpression)tester.expressionFrom("input text_embeddings | binarize | pack_bits | attribute text_embeddings_quant_binary | index text_embeddings_quant_binary");

var script = new ScriptExpression(text_embeddings_expression, text_embeddings_quant_binary_expression);
script.verify(adapter);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,6 @@ public void requireThatHashCodeAndEqualsAreImplemented() {
assertEquals(exp.hashCode(), newSelectInput(foo, "bar", "baz").hashCode());
}

@Test
public void requireThatExpressionCanBeVerified() {
SimpleTestAdapter adapter = new SimpleTestAdapter();
adapter.createField(new Field("my_int", DataType.INT));
adapter.createField(new Field("my_str", DataType.STRING));

Expression exp = newSelectInput(new AttributeExpression("my_int"), "my_int");
assertVerify(adapter, null, exp);
assertVerify(adapter, DataType.INT, exp);
assertVerify(adapter, DataType.STRING, exp);

assertVerifyThrows(adapter, newSelectInput(new AttributeExpression("my_int"), "my_str"),
"Invalid expression 'attribute my_int': Can not assign string to field 'my_int' which is int.");
assertVerifyThrows(adapter, newSelectInput(new AttributeExpression("my_int"), "my_unknown"),
"Invalid expression 'select_input { my_unknown: attribute my_int; }': Field 'my_unknown' not found.");
}

@Test
public void requireThatSelectedExpressionIsRun() {
assertSelect(List.of("foo", "bar"), List.of("foo"), "foo");
Expand Down
2 changes: 1 addition & 1 deletion vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ public Optional<Long> sizeOfDimension(String dimension) {

/**
* Returns whether this type can be assigned to the given type,
* i.e if the given type is a generalization of this type.
* i.e. if the given type is a generalization of this type.
*/
public boolean isAssignableTo(TensorType generalization) {
return isConvertibleOrAssignableTo(generalization, false, true);
Expand Down