Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cosine parameter #4

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ This Plugin allows you to score Elasticsearch documents based on embedding-vecto

## General
* Updated version for ES 6.1 of [this plugin](https://github.com/lior-k/fast-elasticsearch-vector-scoring)
* Cosine support removed.
* This plugin was inspired from [This elasticsearch vector scoring plugin](https://github.com/MLnick/elasticsearch-vector-scoring) and [this discussion](https://discuss.elastic.co/t/vector-scoring/85227/6) to achieve 10 times faster processing over the original.
* lior-k gained this substantial speed improvement by using the lucene index directly
* lior-k developed it for their workplace which needs to pick KNN from a set of ~4M vectors. Their current ES setup is able to answer this in ~80ms
* Cosine support
* Vector norms can be stored with vector and can be used when cosine distance calculating, so that they are not recalculated for each request


## Elasticsearch version
* Currently designed for Elasticsearch 6.1.2.
* Currently designed for Elasticsearch 6.3.1.


## Setup
Expand Down Expand Up @@ -42,14 +43,45 @@ Place this into an elasticsearch checkout, add the plugin to the projects list i
## Usage

### Documents
* Each document you score should have a field containing the base64 representation of your vector. for example:
* Each document you score should have a field containing the base64 representation of your vector. for example vector:
```
{
"id": 1,
....
"content_vector": "v7l48eAAAAA/s4VHwAAAAD+R7I5AAAAAv8MBMAAAAAA/yEI3AAAAAL/IWkeAAAAAv7s480AAAAC/v6DUgAAAAL+wJi0gAAAAP76VqUAAAAC/sL1ZYAAAAL/dyq/gAAAAP62FVcAAAAC/tQRvYAAAAL+j6ycAAAAAP6v1KcAAAAC/bN5hQAAAAL+u9ItAAAAAP4ckTsAAAAC/pmkjYAAAAD+cYpwAAAAAP5renEAAAAC/qY0HQAAAAD+wyYGgAAAAP5WrCcAAAAA/qzjTQAAAAD++LBzAAAAAP49wNKAAAAC/vu/aIAAAAD+hqXfAAAAAP4FfNCAAAAA/pjC64AAAAL+qwT2gAAAAv6S3OGAAAAC/gfMtgAAAAD/If5ZAAAAAP5mcXOAAAAC/xYAU4AAAAL+2nlfAAAAAP7sCXOAAAAA/petBIAAAAD9soYnAAAAAv5R7X+AAAAC/pgM/IAAAAL+ojI/gAAAAP2gPz2AAAAA/3FonoAAAAL/IHg1AAAAAv6p1SmAAAAA/tvKlQAAAAD/I2OMAAAAAP3FBiCAAAAA/wEd8IAAAAL94wI9AAAAAP2Y1IIAAAAA/rnS4wAAAAL9vriVgAAAAv1QxoCAAAAC/1/qu4AAAAL+inZFAAAAAv7aGA+AAAAA/lqYVYAAAAD+kNP0AAAAAP730BiAAAAA="
}
```
[-0.09950172156095505, 0.07625244557857513, 0.017503950744867325, -0.14847373962402344, 0.1895207166671753, -0.19025510549545288, -0.10633774101734161, -0.12354782223701477, -0.06308252364397049, 0.11947114765644073, -0.0653892382979393, -0.4654960334300995, 0.057657890021800995, -0.08209892362356186, -0.03890344500541687, 0.054604820907115936, -0.0035240077413618565, -0.06045947223901749, 0.011299720034003258, -0.043770890682935715, 0.02771991491317749, 0.02623981609940529, -0.04990408569574356, 0.06557474285364151, 0.021160271018743515, 0.0531679168343544, 0.11786060035228729, 0.015350733883678913, -0.12084735184907913, 0.034496061503887177, 0.008482367731630802, 0.0433405302464962, -0.05225555971264839, -0.040460359305143356, -0.008764605969190598, 0.19139364361763, 0.02501053921878338, -0.16797123849391937, -0.08835361897945404, 0.10550480335950851, 0.04281047359108925, 0.0034949961118400097, -0.020001886412501335, -0.04299351945519447, -0.04794740304350853, 0.0029372263234108686, 0.4430026113986969, -0.18841710686683655, -0.051676105707883835, 0.08963997662067413, 0.19411885738372803, 0.004212886560708284, 0.1271815448999405, -0.006043014116585255, 0.0027108797803521156, 0.05948426574468613, -0.0038672189693897963, -0.0012325347634032369, -0.3746754825115204, -0.03635839372873306, -0.0879824087023735, 0.02211793325841427, 0.03946676850318909, 0.11700475960969925]

```
should be converted to:
```
{
"id": 1,
....
"content_vector": "v7l48eAAAAA/s4VHwAAAAD+R7I5AAAAAv8MBMAAAAAA/yEI3AAAAAL/IWkeAAAAAv7s480AAAAC/v6DUgAAAAL+wJi0gAAAAP76VqUAAAAC/sL1ZYAAAAL/dyq/gAAAAP62FVcAAAAC/tQRvYAAAAL+j6ycAAAAAP6v1KcAAAAC/bN5hQAAAAL+u9ItAAAAAP4ckTsAAAAC/pmkjYAAAAD+cYpwAAAAAP5renEAAAAC/qY0HQAAAAD+wyYGgAAAAP5WrCcAAAAA/qzjTQAAAAD++LBzAAAAAP49wNKAAAAC/vu/aIAAAAD+hqXfAAAAAP4FfNCAAAAA/pjC64AAAAL+qwT2gAAAAv6S3OGAAAAC/gfMtgAAAAD/If5ZAAAAAP5mcXOAAAAC/xYAU4AAAAL+2nlfAAAAAP7sCXOAAAAA/petBIAAAAD9soYnAAAAAv5R7X+AAAAC/pgM/IAAAAL+ojI/gAAAAP2gPz2AAAAA/3FonoAAAAL/IHg1AAAAAv6p1SmAAAAA/tvKlQAAAAD/I2OMAAAAAP3FBiCAAAAA/wEd8IAAAAL94wI9AAAAAP2Y1IIAAAAA/rnS4wAAAAL9vriVgAAAAv1QxoCAAAAC/1/qu4AAAAL+inZFAAAAAv7aGA+AAAAA/lqYVYAAAAD+kNP0AAAAAP730BiAAAAA/8AAAAq9PCQ=="
}
```

Also in this example I saved vector norm as last element of array and convert vector and its norm to base64 representation

### Converting a vector to Base64

**Python**
```
import struct
import base64
import math


def calc_vector_norm(vector):
norm = 0
for elem in vector:
norm += elem**2
return math.sqrt(norm)


def encode_v(vector):
vector.append(calc_vector_norm(vector))
return base64.b64encode(
struct.pack('>%sd' % len(vector), *vector)
).decode('utf-8')
```

* Use this field mapping:
```
PUT my_index
Expand Down Expand Up @@ -88,6 +120,8 @@ POST /_search
"source": "vector_scoring",
"lang": "binary_vector_score",
"params": {
"use_stored_vector_norm": true,
"cosine": true,
"vector_field": "content_vector",
"vector": [
-0.09217305481433868,
Expand Down Expand Up @@ -167,5 +201,7 @@ POST /_search
* The example above shows a vector of 64 dimensions
* Parameters:
1. `field_vector`: The field containing the base64 vector.
3. `vector`: The vector (comma separated) to compare to.
2. `vector`: The vector (comma separated) to compare to.
3. `cosine`: (boolean) calculate cosine distance or dot product of vectors
4. `use_stored_vector_norm`: (boolean) if `true` stored norm will be used when cosine distance will calculating

58 changes: 54 additions & 4 deletions src/main/java/com/gosololaw/elasticsearch/VectorScoringPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,16 @@ public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> co
if (!context.equals(SearchScript.CONTEXT)) {
throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
}

// we use the script "source" as the script identifier
if ("vector_scoring".equals(scriptSource)) {
SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {

private final double[] inputVector;
private final double inputVectorNorm;
final String field;
final boolean useStoredVectorNorm;
final boolean cosine;
{
final Object field = p.get("vector_field");
if (field == null)
Expand All @@ -74,9 +79,26 @@ public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> co
// get query inputVector - convert to primitive
final ArrayList<Double> tmp = (ArrayList<Double>) p.get("vector");
this.inputVector = new double[tmp.size()];

for (int i = 0; i < inputVector.length; i++) {
inputVector[i] = tmp.get(i);
}

final Object cosine = p.get("cosine");
this.cosine = cosine != null && (boolean)cosine;

if (this.cosine) {
double norm = 0.0f;
for (int i = 0; i < inputVector.length; i++) {
norm += inputVector[i] * inputVector[i];
}
this.inputVectorNorm = Math.sqrt(norm);
} else {
this.inputVectorNorm = 0;
}

final Object useStoredVectorNorm = p.get("use_stored_vector_norm");
this.useStoredVectorNorm = useStoredVectorNorm != null && (boolean)useStoredVectorNorm;
}

@Override
Expand All @@ -101,17 +123,23 @@ public double runAsDouble() {
if (!is_value) return 0;
final byte[] bytes;
try {
bytes = accessor.binaryValue().bytes;
bytes = accessor.binaryValue().bytes;
} catch (IOException e) {
return 0;
return 0;
}

final int input_vector_size = inputVector.length;

final ByteArrayDataInput doc_vector = new ByteArrayDataInput(bytes);

doc_vector.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls
final int doc_vector_length = doc_vector.readVInt(); // returns the number of bytes to read
if(doc_vector_length != input_vector_size * DOUBLE_SIZE) {
int doc_vector_length = doc_vector.readVInt(); // returns the number of bytes to read

if (useStoredVectorNorm) {
doc_vector_length = doc_vector_length - DOUBLE_SIZE;
}

if(doc_vector_length < input_vector_size * DOUBLE_SIZE) {
return 0.0;
}
final int position = doc_vector.getPosition();
Expand All @@ -124,6 +152,28 @@ public double runAsDouble() {
for (int i = 0; i < input_vector_size; i++) {
score += docVector[i] * inputVector[i];
}

if (!cosine) {
return score;
}


double docVectorNorm = 0.0f;

if (useStoredVectorNorm) {
doc_vector.skipBytes(doc_vector_length);
docVectorNorm = Double.longBitsToDouble(doc_vector.readLong());
} else {
for (int i = 0; i < input_vector_size; i++) {
docVectorNorm += docVector[i] * docVector[i];
}
docVectorNorm = Math.sqrt(docVectorNorm);
}

score /= inputVectorNorm;
score /= docVectorNorm;


return score;
}
};
Expand Down