From 854c69b0602e909d543c16424377579f0d46efac Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Wed, 30 Jan 2019 17:23:41 -0500 Subject: [PATCH 01/25] Changed script to work with Elastic 6.0.0. Inline scripts are now depreciated so using 'source' field instead --- README.md | 4 +- pom.xml | 12 +- .../plugin/VectorScoringPlugin.java | 154 +++++++++++++- .../script/VectorScoreScript.java | 193 ------------------ .../VectorScoringScriptEngineService.java | 78 ------- .../EmbeddedElasticsearchServer.java | 11 +- .../com/liorkn/elasticsearch/PluginTest.java | 2 +- 7 files changed, 159 insertions(+), 295 deletions(-) delete mode 100755 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java delete mode 100755 src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java diff --git a/README.md b/README.md index 9806c21..7088177 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ give it a try. ## Elasticsearch version -* Currently designed for Elasticsearch 5.6.0. +* Currently designed for Elasticsearch 6.0.0. * for Elasticsearch 5.2.2 use branch `es-5.2.2` * for Elasticsearch 2.4.4 use branch `es-2.4.4` @@ -146,7 +146,7 @@ func convertBase64ToArray(base64Str string) ([]float64, error) { "boost_mode": "replace", "script_score": { "script": { - "inline": "binary_vector_score", + "source": "binary_vector_score", "lang": "knn", "params": { "cosine": false, diff --git a/pom.xml b/pom.xml index d765043..bcdee7c 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 5.6.0 + 6.0.0 ElasticSearch Plugin for Binary Vector Scoring @@ -27,7 +27,7 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 5.6.0 + 6.0.0 2.4 4.4.8 4.12 @@ -65,7 +65,7 @@ org.elasticsearch.plugin - transport-netty3-client + transport-netty4-client ${elasticsearch.version} test @@ -86,12 +86,6 @@ - - - - - - org.codelibs.elasticsearch.module lang-painless diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 88a3599..b0a666e 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,13 +13,25 @@ */ package com.liorkn.elasticsearch.plugin; -import com.liorkn.elasticsearch.service.VectorScoringScriptEngineService; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.util.Collection; +import java.util.Map; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; -import org.elasticsearch.script.ScriptEngineService; - +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; +import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -27,9 +39,141 @@ */ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { - public final ScriptEngineService getScriptEngineService(Settings settings) { - return new VectorScoringScriptEngineService(settings); + @Override + public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { + return new MyExpertScriptEngine(); } + /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ + private static class MyExpertScriptEngine implements ScriptEngine { + @Override + public String getType() { + return "knn"; + } + + private static final int DOUBLE_SIZE = 8; + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if ("binary_vector_score".equals(scriptSource)) { + SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { + final String field; + final boolean cosine; + { + if (p.containsKey("vector") == false) { + throw new IllegalArgumentException("Missing parameter [vector]"); + } + if (p.containsKey("field") == false) { + throw new IllegalArgumentException("Missing parameter [field]"); + } + if (p.containsKey("cosine") == false) { + throw new IllegalArgumentException("Missing parameter [cosine]"); + } + field = p.get("field").toString(); + cosine = (boolean) p.get("cosine"); + } + + final ArrayList searchVector = (ArrayList) p.get("vector"); + double magnitude; + { + if (cosine) { + // calc magnitude + double queryVectorNorm = 0.0; + // compute query inputVector norm once + for (Double v : this.searchVector) { + queryVectorNorm += v.doubleValue() * v.doubleValue(); + } + magnitude = Math.sqrt(queryVectorNorm); + } else { + magnitude = 0.0; + } + } + + @Override + public SearchScript newInstance(LeafReaderContext context) throws IOException { + return new SearchScript(p, lookup, context) { + BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); + int currentDocid = -1; + + @Override + public void setDocument(int docid) { + // Move to desired document + try { + docAccess.advanceExact(docid); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + currentDocid = docid; + } + + @Override + public double runAsDouble() { + if (currentDocid < 0) { + return 0.0; + } + //actually run scoring + final int size = searchVector.size(); + + try { + final byte[] bytes = docAccess.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size + if (len != size * DOUBLE_SIZE) { + return 0.0; + } + + final int position = input.getPosition(); + final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); + final double[] docVector = new double[size]; + doubleBuffer.get(docVector); + double docVectorNorm = 0.0f; + double score = 0; + for (int i = 0; i < size; i++) { + // doc inputVector norm + if(cosine) { + docVectorNorm += docVector[i]*docVector[i]; + } + // dot product + score += docVector[i] * searchVector.get(i).doubleValue(); + } + if(cosine) { + // cosine similarity score + if (docVectorNorm == 0 || magnitude == 0){ + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + return score; + } + } catch (Exception e) { + return 0; + } + } + }; + } + + @Override + public boolean needs_score() { + return false; + } + }; + return context.factoryClazz.cast(factory); + } + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + @Override + public void close() { + // optionally close resources + } + } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java deleted file mode 100755 index 0b87cf3..0000000 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ /dev/null @@ -1,193 +0,0 @@ -/* -Based on: https://discuss.elastic.co/t/vector-scoring/85227/4 -and https://github.com/MLnick/elasticsearch-vector-scoring - -another slower implementation using strings: https://github.com/ginobefun/elasticsearch-feature-vector-scoring - -storing arrays is no luck - lucine index doesn't keep the array members orders -https://www.elastic.co/guide/en/elasticsearch/guide/current/complex-core-fields.html - -Delimited Payload Token Filter: https://www.elastic.co/guide/en/elasticsearch/reference/2.4/analysis-delimited-payload-tokenfilter.html - - - */ - -package com.liorkn.elasticsearch.script; - -import com.liorkn.elasticsearch.Util; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.store.ByteArrayDataInput; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptException; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.util.ArrayList; -import java.util.Base64; -import java.util.Map; - -/** - * Script that scores documents based on cosine similarity embedding vectors. - */ -public final class VectorScoreScript implements LeafSearchScript, ExecutableScript { - - public static final String SCRIPT_NAME = "binary_vector_score"; - - private static final int DOUBLE_SIZE = 8; - - // the field containing the vectors to be scored against - public final String field; - - private int docId; - private BinaryDocValues binaryEmbeddingReader; - - private final double[] inputVector; - private final double magnitude; - - private final boolean cosine; - - @Override - public long runAsLong() { - return ((Number)this.run()).longValue(); - } - @Override - public double runAsDouble() { - return ((Number)this.run()).doubleValue(); - } - @Override - public void setNextVar(String name, Object value) {} - @Override - public void setDocument(int docId) { - this.docId = docId; - } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } - - - /** - * Factory that is registered in - * {@link VectorScoringPlugin#onModule(org.elasticsearch.script.ScriptModule)} - * method when the plugin is loaded. - */ - public static class Factory { - - /** - * This method is called for every search on every shard. - * - * @param params - * list of script parameters passed with the query - * @return new native script - */ - public ExecutableScript newScript(@Nullable Map params) throws ScriptException { - return new VectorScoreScript(params); - } - - /** - * Indicates if document scores may be needed by the produced scripts. - * - * @return {@code true} if scores are needed. - */ - public boolean needsScores() { - return false; - } - - } - - - /** - * Init - * @param params index that a scored are placed in this parameter. Initialize them here. - */ - @SuppressWarnings("unchecked") - public VectorScoreScript(Map params) { - final Object cosineBool = params.get("cosine"); - cosine = cosineBool != null ? - (boolean)cosineBool : - true; - - final Object field = params.get("field"); - if (field == null) - throw new IllegalArgumentException("binary_vector_score script requires field input"); - this.field = field.toString(); - - // get query inputVector - convert to primitive - final Object vector = params.get("vector"); - if(vector != null) { - final ArrayList tmp = (ArrayList) vector; - inputVector = new double[tmp.size()]; - for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i); - } - } else { - final Object encodedVector = params.get("encoded_vector"); - if(encodedVector == null) { - throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); - } - inputVector = Util.convertBase64ToArray((String) encodedVector); - } - - if(cosine) { - // calc magnitude - double queryVectorNorm = 0.0; - // compute query inputVector norm once - for (double v : inputVector) { - queryVectorNorm += v * v; - } - magnitude = Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0; - } - } - - - /** - * Called for each document - * @return cosine similarity of the current document against the input inputVector - */ - @Override - public final Object run() { - final int size = inputVector.length; - - final byte[] bytes = binaryEmbeddingReader.get(docId).bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int len = input.readVInt(); // returns the number of bytes to read - if(len != size * DOUBLE_SIZE) { - return 0.0; - } - final int position = input.getPosition(); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); - - final double[] docVector = new double[size]; - doubleBuffer.get(docVector); - - double docVectorNorm = 0.0f; - double score = 0; - for (int i = 0; i < size; i++) { - // doc inputVector norm - if(cosine) { - docVectorNorm += docVector[i]*docVector[i]; - } - // dot product - score += docVector[i] * inputVector[i]; - } - if(cosine) { - // cosine similarity score - if (docVectorNorm == 0 || magnitude == 0){ - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - return score; - } - } - -} \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java b/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java deleted file mode 100755 index 58db087..0000000 --- a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java +++ /dev/null @@ -1,78 +0,0 @@ -package com.liorkn.elasticsearch.service; - -import com.liorkn.elasticsearch.script.VectorScoreScript; -import org.apache.lucene.index.LeafReaderContext; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.component.AbstractComponent; -import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.script.CompiledScript; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptEngineService; -import org.elasticsearch.script.SearchScript; -import org.elasticsearch.search.lookup.SearchLookup; - -import java.io.IOException; -import java.util.Map; - -/** - * Created by Lior Knaany on 5/14/17. - */ -public class VectorScoringScriptEngineService extends AbstractComponent implements ScriptEngineService{ - - public static final String NAME = "knn"; - - @Inject - public VectorScoringScriptEngineService(Settings settings) { - super(settings); - } - - @Override - public Object compile(String scriptName, String scriptSource, Map params) { - return new VectorScoreScript.Factory(); - } - - - @Override - public boolean isInlineScriptEnabled() { - return true; - } - - @Override - public String getType() { - return NAME; - } - - @Override - public String getExtension() { - return NAME; - } - - @Override - public ExecutableScript executable(CompiledScript compiledScript, @Nullable Map vars) { - VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - return scriptFactory.newScript(vars); - } - - @Override - public SearchScript search(CompiledScript compiledScript, final SearchLookup lookup, @Nullable final Map vars) { - final VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - final VectorScoreScript script = (VectorScoreScript) scriptFactory.newScript(vars); - return new SearchScript() { - @Override - public LeafSearchScript getLeafSearchScript(LeafReaderContext context) throws IOException { - script.setBinaryEmbeddingReader(context.reader().getBinaryDocValues(script.field)); - return script; - } - @Override - public boolean needsScores() { - return scriptFactory.needsScores(); - } - }; - } - - @Override - public void close() { - } -} diff --git a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java index 5627240..1cbdb9c 100644 --- a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java +++ b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java @@ -10,7 +10,7 @@ import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; import org.elasticsearch.painless.PainlessPlugin; -import org.elasticsearch.transport.Netty3Plugin; +import org.elasticsearch.transport.Netty4Plugin; import java.io.File; import java.io.IOException; @@ -41,13 +41,10 @@ private EmbeddedElasticsearchServer(String defaultDataDirectory, int port) throw Settings.Builder settings = Settings.builder() .put("http.enabled", "true") - .put("transport.type", "local") - .put("http.type", "netty3") + .put("http.type", "netty4") .put("path.data", dataDirectory) .put("path.home", DEFAULT_HOME_DIRECTORY) - .put("script.inline", "on") - .put("node.max_local_storage_nodes", 10000) - .put("script.stored", "on"); + .put("node.max_local_storage_nodes", 10000); startNodeInAvailablePort(settings); } @@ -61,7 +58,7 @@ private void startNodeInAvailablePort(Settings.Builder settings) throws NodeVali settings.put("http.port", String.valueOf(this.port)); // this a hack in order to load Groovy plug in since we want to enable the usage of scripts - node = new NodeExt(settings.build() , Arrays.asList(Netty3Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); + node = new NodeExt(settings.build() , Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); node.start(); success = true; System.out.println(EmbeddedElasticsearchServer.class.getName() + ": Using port: " + this.port); diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index b95b65c..c134ea6 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -92,7 +92,7 @@ public void test() throws Exception { " \"boost_mode\": \"replace\"," + " \"script_score\": {" + " \"script\": {" + - " \"inline\": \"binary_vector_score\"," + + " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + " \"cosine\": false," + From 7cc89d0eb2ca1f67b076f12eb2db642736ad8f9f Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 31 Jan 2019 13:59:06 -0500 Subject: [PATCH 02/25] renamed engine appropriately --- .../com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index b0a666e..3f97f65 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -41,11 +41,11 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new MyExpertScriptEngine(); + return new VectorScoringPluginEngine(); } /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class MyExpertScriptEngine implements ScriptEngine { + private static class VectorScoringPluginEngine implements ScriptEngine { @Override public String getType() { return "knn"; From ec09b19c47934712de65943060d9bf3a4ac78cf6 Mon Sep 17 00:00:00 2001 From: ran22 Date: Sun, 10 Feb 2019 16:27:02 +0200 Subject: [PATCH 03/25] optimize plugin and use float32 (#19) Note that all vectors should be float instead of doubles This introduces ~ 50% performance boost. so it's worth it --- README.md | 63 ++++----- .../java/com/liorkn/elasticsearch/Util.java | 19 +-- .../script/VectorScoreScript.java | 120 ++++++++---------- .../VectorScoringScriptEngineService.java | 1 - 4 files changed, 95 insertions(+), 108 deletions(-) diff --git a/README.md b/README.md index 9806c21..6ac333a 100644 --- a/README.md +++ b/README.md @@ -42,28 +42,29 @@ give it a try. * The vector can be of any dimension ### Converting a vector to Base64 -to convert an array of doubles to a base64 string we use these example methods: +to convert an array of float32 to a base64 string we use these example methods: **Java** ``` -public static final String convertArrayToBase64(double[] array) { - final int capacity = 8 * array.length; - final ByteBuffer bb = ByteBuffer.allocate(capacity); - for (int i = 0; i < array.length; i++) { - bb.putDouble(array[i]); - } - bb.rewind(); - final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); - return new String(encodedBB.array()); +public static float[] convertBase64ToArray(String base64Str) { + final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); + final FloatBuffer floatBuffer = ByteBuffer.wrap(decode).asFloatBuffer(); + final float[] dims = new float[floatBuffer.capacity()]; + floatBuffer.get(dims); + + return dims; } -public static double[] convertBase64ToArray(String base64Str) { - final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer(); +public static String convertArrayToBase64(float[] array) { + final int capacity = Float.BYTES * array.length; + final ByteBuffer bb = ByteBuffer.allocate(capacity); + for (float v : array) { + bb.putFloat(v); + } + bb.rewind(); + final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); - final double[] dims = new double[doubleBuffer.capacity()]; - doubleBuffer.get(dims); - return dims; + return new String(encodedBB.array()); } ``` **Python** @@ -71,14 +72,14 @@ public static double[] convertBase64ToArray(String base64Str) { import base64 import numpy as np -dbig = np.dtype('>f8') +dfloat32 = np.dtype('>f4') def decode_float_list(base64_string): bytes = base64.b64decode(base64_string) - return np.frombuffer(bytes, dtype=dbig).tolist() + return np.frombuffer(bytes, dtype=dfloat32).tolist() def encode_array(arr): - base64_str = base64.b64encode(np.array(arr).astype(dbig)).decode("utf-8") + base64_str = base64.b64encode(np.array(arr).astype(dfloat32)).decode("utf-8") return base64_str ``` @@ -87,11 +88,11 @@ def encode_array(arr): require 'base64' def decode_float_list(base64_string) - Base64.strict_decode64(base64_string).unpack('G*') + Base64.strict_decode64(base64_string).unpack('g*') end def encode_array(arr) - Base64.strict_encode64(arr.pack('G*')) + Base64.strict_encode64(arr.pack('g*')) end ``` @@ -103,12 +104,12 @@ import( "encoding/base64" ) -func convertArrayToBase64(array []float64) string { - bytes := make([]byte, 0, 8*len(array)) +func convertArrayToBase64(array []float32) string { + bytes := make([]byte, 0, 4*len(array)) for _, a := range array { - bits := math.Float64bits(a) - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, bits) + bits := math.Float32bits(a) + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, bits) bytes = append(bytes, b...) } @@ -116,18 +117,18 @@ func convertArrayToBase64(array []float64) string { return encoded } -func convertBase64ToArray(base64Str string) ([]float64, error) { +func convertBase64ToArray(base64Str string) ([]float32, error) { decoded, err := base64.StdEncoding.DecodeString(base64Str) if err != nil { return nil, err } length := len(decoded) - array := make([]float64, 0, length/8) + array := make([]float32, 0, length/4) - for i := 0; i < len(decoded); i += 8 { - bits := binary.BigEndian.Uint64(decoded[i : i+8]) - f := math.Float64frombits(bits) + for i := 0; i < len(decoded); i += 4 { + bits := binary.BigEndian.Uint64(decoded[i : i+4]) + f := math.Float32frombits(bits) array = append(array, f) } return array, nil diff --git a/src/main/java/com/liorkn/elasticsearch/Util.java b/src/main/java/com/liorkn/elasticsearch/Util.java index de81af8..045a0d0 100644 --- a/src/main/java/com/liorkn/elasticsearch/Util.java +++ b/src/main/java/com/liorkn/elasticsearch/Util.java @@ -1,7 +1,7 @@ package com.liorkn.elasticsearch; import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; import java.util.Base64; /** @@ -9,23 +9,24 @@ */ public class Util { - public static final double[] convertBase64ToArray(String base64Str) { + public static float[] convertBase64ToArray(String base64Str) { final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer(); + final FloatBuffer floatBuffer = ByteBuffer.wrap(decode).asFloatBuffer(); + final float[] dims = new float[floatBuffer.capacity()]; + floatBuffer.get(dims); - final double[] dims = new double[doubleBuffer.capacity()]; - doubleBuffer.get(dims); return dims; } - public static final String convertArrayToBase64(double[] array) { - final int capacity = 8 * array.length; + public static String convertArrayToBase64(double[] array) { + final int capacity = Float.BYTES * array.length; final ByteBuffer bb = ByteBuffer.allocate(capacity); - for (int i = 0; i < array.length; i++) { - bb.putDouble(array[i]); + for (double v : array) { + bb.putFloat((float) v); } bb.rewind(); final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); + return new String(encodedBB.array()); } } diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index 0b87cf3..c9e0401 100755 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -8,8 +8,6 @@ https://www.elastic.co/guide/en/elasticsearch/guide/current/complex-core-fields.html Delimited Payload Token Filter: https://www.elastic.co/guide/en/elasticsearch/reference/2.4/analysis-delimited-payload-tokenfilter.html - - */ package com.liorkn.elasticsearch.script; @@ -22,42 +20,78 @@ import org.elasticsearch.script.LeafSearchScript; import org.elasticsearch.script.ScriptException; -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; import java.util.ArrayList; -import java.util.Base64; import java.util.Map; + /** * Script that scores documents based on cosine similarity embedding vectors. */ public final class VectorScoreScript implements LeafSearchScript, ExecutableScript { - public static final String SCRIPT_NAME = "binary_vector_score"; - - private static final int DOUBLE_SIZE = 8; - // the field containing the vectors to be scored against public final String field; private int docId; private BinaryDocValues binaryEmbeddingReader; - private final double[] inputVector; - private final double magnitude; + private final float[] inputVector; + private final float magnitude; private final boolean cosine; + @Override + public final Object run() { + return runAsDouble(); + } + @Override public long runAsLong() { - return ((Number)this.run()).longValue(); + return (long) runAsDouble(); } + + /** + * Called for each document + * @return cosine similarity of the current document against the input inputVector + */ @Override public double runAsDouble() { - return ((Number)this.run()).doubleValue(); + final byte[] bytes = binaryEmbeddingReader.get(docId).bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + + // MUST appear hear since it affect the next calls + input.readVInt(); // returns the number of values which should be 1 + input.readVInt(); // returns the number of bytes to read + + float score = 0; + + if(cosine) { + float docVectorNorm = 0.0f; + + for (int i = 0; i < inputVector.length; i++) { + float v = Float.intBitsToFloat(input.readInt()); + docVectorNorm += v * v; // inputVector norm + score += v * inputVector[i]; // dot product + } + + if (docVectorNorm == 0 || magnitude == 0) { + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + for (int i = 0; i < inputVector.length; i++) { + float v = Float.intBitsToFloat(input.readInt()); + score += v * inputVector[i]; // dot product + } + + return score; + } } + @Override public void setNextVar(String name, Object value) {} + @Override public void setDocument(int docId) { this.docId = docId; @@ -70,7 +104,6 @@ public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { this.binaryEmbeddingReader = binaryEmbeddingReader; } - /** * Factory that is registered in * {@link VectorScoringPlugin#onModule(org.elasticsearch.script.ScriptModule)} @@ -97,10 +130,8 @@ public ExecutableScript newScript(@Nullable Map params) throws S public boolean needsScores() { return false; } - } - /** * Init * @param params index that a scored are placed in this parameter. Initialize them here. @@ -121,9 +152,9 @@ public VectorScoreScript(Map params) { final Object vector = params.get("vector"); if(vector != null) { final ArrayList tmp = (ArrayList) vector; - inputVector = new double[tmp.size()]; + inputVector = new float[tmp.size()]; for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i); + inputVector[i] = tmp.get(i).floatValue(); } } else { final Object encodedVector = params.get("encoded_vector"); @@ -135,59 +166,14 @@ public VectorScoreScript(Map params) { if(cosine) { // calc magnitude - double queryVectorNorm = 0.0; + float queryVectorNorm = 0.0f; // compute query inputVector norm once - for (double v : inputVector) { + for (float v: inputVector) { queryVectorNorm += v * v; } - magnitude = Math.sqrt(queryVectorNorm); + magnitude = (float) Math.sqrt(queryVectorNorm); } else { - magnitude = 0.0; + magnitude = 0.0f; } } - - - /** - * Called for each document - * @return cosine similarity of the current document against the input inputVector - */ - @Override - public final Object run() { - final int size = inputVector.length; - - final byte[] bytes = binaryEmbeddingReader.get(docId).bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int len = input.readVInt(); // returns the number of bytes to read - if(len != size * DOUBLE_SIZE) { - return 0.0; - } - final int position = input.getPosition(); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); - - final double[] docVector = new double[size]; - doubleBuffer.get(docVector); - - double docVectorNorm = 0.0f; - double score = 0; - for (int i = 0; i < size; i++) { - // doc inputVector norm - if(cosine) { - docVectorNorm += docVector[i]*docVector[i]; - } - // dot product - score += docVector[i] * inputVector[i]; - } - if(cosine) { - // cosine similarity score - if (docVectorNorm == 0 || magnitude == 0){ - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - return score; - } - } - } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java b/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java index 58db087..269df8a 100755 --- a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java +++ b/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java @@ -33,7 +33,6 @@ public Object compile(String scriptName, String scriptSource, Map Date: Sun, 10 Feb 2019 16:51:06 +0200 Subject: [PATCH 04/25] updated jackson version --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index d765043..7434047 100755 --- a/pom.xml +++ b/pom.xml @@ -31,7 +31,7 @@ 2.4 4.4.8 4.12 - 2.7.4 + 2.8.11.3 From f0c27c5105c7e16efad521fbe49cdacdf66a9382 Mon Sep 17 00:00:00 2001 From: Lior Knaany Date: Mon, 25 Feb 2019 21:57:47 +0200 Subject: [PATCH 05/25] added a cosine score test --- src/test/java/com/liorkn/elasticsearch/PluginTest.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index b95b65c..417db02 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -1,7 +1,9 @@ package com.liorkn.elasticsearch; import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; import org.apache.http.HttpHost; import org.apache.http.entity.ContentType; import org.apache.http.entity.StringEntity; @@ -70,7 +72,6 @@ public static void init() throws Exception { public void test() throws Exception { final Map params = new HashMap<>(); params.put("refresh", "true"); - final ObjectMapper mapper = new ObjectMapper(); final TestObject[] objs = {new TestObject(1, new double[] {0.0, 0.5, 1.0}), new TestObject(2, new double[] {0.2, 0.6, 0.99})}; @@ -95,7 +96,7 @@ public void test() throws Exception { " \"inline\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + - " \"cosine\": false," + + " \"cosine\": true," + " \"field\": \"embedding_vector\"," + " \"vector\": [" + " 0.1, 0.2, 0.3" + @@ -113,6 +114,10 @@ public void test() throws Exception { System.out.println(resBody); Assert.assertEquals("search should return status code 200", 200, res.getStatusLine().getStatusCode()); Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":" + objs.length)); + // Testing Scores + final ArrayNode hitsJson = (ArrayNode)mapper.readTree(resBody).get("hits").get("hits"); + Assert.assertEquals(0.9941734, hitsJson.get(0).get("_score").asDouble(), 0); + Assert.assertEquals(0.95618284, hitsJson.get(1).get("_score").asDouble(), 0); } @AfterClass From f688c3a5a74e63f7dfc4236054854d587e342f03 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:45:15 -0400 Subject: [PATCH 06/25] Broke it back into multiple files for simplicity, added support for encodedVector field, added runAsLong(), changed searchVector to be a primitive type, made SCRIPT_SOURCE private static class member --- .../java/com/liorkn/elasticsearch/Util.java | 35 ++-- .../engine/VectorScoringScriptEngine.java | 36 ++++ .../plugin/VectorScoringPlugin.java | 149 +---------------- .../script/VectorScoreScript.java | 157 ++++++++++++++++++ 4 files changed, 213 insertions(+), 164 deletions(-) create mode 100644 src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java create mode 100644 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java diff --git a/src/main/java/com/liorkn/elasticsearch/Util.java b/src/main/java/com/liorkn/elasticsearch/Util.java index de81af8..ed84f95 100644 --- a/src/main/java/com/liorkn/elasticsearch/Util.java +++ b/src/main/java/com/liorkn/elasticsearch/Util.java @@ -9,23 +9,22 @@ */ public class Util { - public static final double[] convertBase64ToArray(String base64Str) { - final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer(); + public static double[] convertBase64ToArray(String base64Str) { + final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); + final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer(); + final double[] dims = new double[doubleBuffer.capacity()]; + doubleBuffer.get(dims); + return dims; + } - final double[] dims = new double[doubleBuffer.capacity()]; - doubleBuffer.get(dims); - return dims; - } - - public static final String convertArrayToBase64(double[] array) { - final int capacity = 8 * array.length; - final ByteBuffer bb = ByteBuffer.allocate(capacity); - for (int i = 0; i < array.length; i++) { - bb.putDouble(array[i]); - } - bb.rewind(); - final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); - return new String(encodedBB.array()); - } + public static String convertArrayToBase64(double[] array) { + final int capacity = Double.BYTES * array.length; + final ByteBuffer bb = ByteBuffer.allocate(capacity); + for (double v : array) { + bb.putDouble((double) v); + } + bb.rewind(); + final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); + return new String(encodedBB.array()); + } } diff --git a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java new file mode 100644 index 0000000..6c00692 --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java @@ -0,0 +1,36 @@ +package com.liorkn.elasticsearch.engine; + +import com.liorkn.elasticsearch.script.VectorScoreScript; + +import java.util.Map; + +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; + +/** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ +public class VectorScoringScriptEngine implements ScriptEngine { + + public static final String NAME = "knn"; + private static final String SCRIPT_SOURCE = "binary_vector_score"; + + @Override + public String getType() { + return NAME; + } + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if (!SCRIPT_SOURCE.equals(scriptSource)) { + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + SearchScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; + return context.factoryClazz.cast(factory); + } +} diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 3f97f65..c279636 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,25 +13,15 @@ */ package com.liorkn.elasticsearch.plugin; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; +import com.liorkn.elasticsearch.engine.VectorScoringScriptEngine; + import java.util.Collection; -import java.util.Map; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.PostingsEnum; -import org.apache.lucene.index.Term; -import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; -import org.elasticsearch.script.SearchScript; -import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -41,139 +31,6 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new VectorScoringPluginEngine(); - } - - /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class VectorScoringPluginEngine implements ScriptEngine { - @Override - public String getType() { - return "knn"; - } - - private static final int DOUBLE_SIZE = 8; - - @Override - public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { - - if (context.equals(SearchScript.CONTEXT) == false) { - throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); - } - - // we use the script "source" as the script identifier - if ("binary_vector_score".equals(scriptSource)) { - SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { - final String field; - final boolean cosine; - { - if (p.containsKey("vector") == false) { - throw new IllegalArgumentException("Missing parameter [vector]"); - } - if (p.containsKey("field") == false) { - throw new IllegalArgumentException("Missing parameter [field]"); - } - if (p.containsKey("cosine") == false) { - throw new IllegalArgumentException("Missing parameter [cosine]"); - } - field = p.get("field").toString(); - cosine = (boolean) p.get("cosine"); - } - - final ArrayList searchVector = (ArrayList) p.get("vector"); - double magnitude; - { - if (cosine) { - // calc magnitude - double queryVectorNorm = 0.0; - // compute query inputVector norm once - for (Double v : this.searchVector) { - queryVectorNorm += v.doubleValue() * v.doubleValue(); - } - magnitude = Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0; - } - } - - @Override - public SearchScript newInstance(LeafReaderContext context) throws IOException { - return new SearchScript(p, lookup, context) { - BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); - int currentDocid = -1; - - @Override - public void setDocument(int docid) { - // Move to desired document - try { - docAccess.advanceExact(docid); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - currentDocid = docid; - } - - @Override - public double runAsDouble() { - if (currentDocid < 0) { - return 0.0; - } - //actually run scoring - final int size = searchVector.size(); - - try { - final byte[] bytes = docAccess.binaryValue().bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size - if (len != size * DOUBLE_SIZE) { - return 0.0; - } - - final int position = input.getPosition(); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); - - final double[] docVector = new double[size]; - doubleBuffer.get(docVector); - double docVectorNorm = 0.0f; - double score = 0; - for (int i = 0; i < size; i++) { - // doc inputVector norm - if(cosine) { - docVectorNorm += docVector[i]*docVector[i]; - } - // dot product - score += docVector[i] * searchVector.get(i).doubleValue(); - } - if(cosine) { - // cosine similarity score - if (docVectorNorm == 0 || magnitude == 0){ - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - return score; - } - } catch (Exception e) { - return 0; - } - } - }; - } - - @Override - public boolean needs_score() { - return false; - } - }; - return context.factoryClazz.cast(factory); - } - throw new IllegalArgumentException("Unknown script name " + scriptSource); - } - - @Override - public void close() { - // optionally close resources - } + return new VectorScoringScriptEngine(); } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java new file mode 100644 index 0000000..b682acf --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -0,0 +1,157 @@ +package com.liorkn.elasticsearch.script; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.ByteArrayDataInput; +import org.elasticsearch.script.SearchScript; +import org.elasticsearch.search.lookup.SearchLookup; + +import com.liorkn.elasticsearch.Util; + +public final class VectorScoreScript extends SearchScript { + + private BinaryDocValues binaryEmbeddingReader; + + private final String field; + private final boolean cosine; + + private final double[] inputVector; + private final double magnitude; + + @Override + public long runAsLong() { + return (long) runAsDouble(); + } + + @Override + public double runAsDouble() { + try { + final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + + final int len = input.readVInt(); + // in case vector is of different size + if (len != inputVector.length * Double.BYTES) { + return 0.0; + } + + float score = 0; + + if (cosine) { + double docVectorNorm = 0.0f; + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + docVectorNorm += v * v; // inputVector norm + score += v * inputVector[i]; // dot product + } + + if (docVectorNorm == 0 || magnitude == 0) { + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + score += v * inputVector[i]; // dot product + } + + return score; + } + } catch (Exception e) { + return 0.0; + } + } + + @Override + public void setDocument(int docId) { + try { + this.binaryEmbeddingReader.advanceExact(docId); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { + if(binaryEmbeddingReader == null) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + this.binaryEmbeddingReader = binaryEmbeddingReader; + } + + @SuppressWarnings("unchecked") + public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, lookup, leafContext); + + final Object cosineBool = params.get("cosine"); + this.cosine = cosineBool != null ? + (boolean)cosineBool : + true; + + final Object field = params.get("field"); + if (field == null) + throw new IllegalArgumentException("binary_vector_score script requires field input"); + this.field = field.toString(); + + // get query inputVector - convert to primitive + final Object vector = params.get("vector"); + if(vector != null) { + final ArrayList tmp = (ArrayList) vector; + inputVector = new double[tmp.size()]; + for (int i = 0; i < inputVector.length; i++) { + inputVector[i] = tmp.get(i).doubleValue(); + } + } else { + final Object encodedVector = params.get("encoded_vector"); + if(encodedVector == null) { + throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); + } + inputVector = Util.convertBase64ToArray((String) encodedVector); + } + + if (this.cosine) { + // calc magnitude + double queryVectorNorm = 0.0f; + // compute query inputVector norm once + for (double v: this.inputVector) { + queryVectorNorm += v * v; + } + this.magnitude = (double) Math.sqrt(queryVectorNorm); + } else { + this.magnitude = 0.0f; + } + + try { + this.binaryEmbeddingReader = leafContext.reader().getBinaryDocValues(this.field); + } catch (IOException e) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + } + + public static class VectorScoreScriptFactory implements LeafFactory { + private final Map params; + private final SearchLookup lookup; + + public VectorScoreScriptFactory(Map params, SearchLookup lookup) { + this.params = params; + this.lookup = lookup; + } + + public boolean needs_score() { + return false; + } + + @Override + public SearchScript newInstance(LeafReaderContext ctx) throws IOException { + return new VectorScoreScript(this.params, this.lookup, ctx); + } + } +} \ No newline at end of file From 2aa0f87d74b7b0606285b44f9f02ff7848ce1b7c Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:48:28 -0400 Subject: [PATCH 07/25] Removed unused import --- .../java/com/liorkn/elasticsearch/script/VectorScoreScript.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index b682acf..f2bb77c 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -3,7 +3,6 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Map; import org.apache.lucene.index.BinaryDocValues; From 44fafbd6b499c1b1b752e423c22895fe75dca6ad Mon Sep 17 00:00:00 2001 From: lior-k Date: Thu, 4 Apr 2019 16:54:42 +0300 Subject: [PATCH 08/25] Update README.md Fixed an issue with the go example code --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6ac333a..9ad9f9b 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ func convertBase64ToArray(base64Str string) ([]float32, error) { array := make([]float32, 0, length/4) for i := 0; i < len(decoded); i += 4 { - bits := binary.BigEndian.Uint64(decoded[i : i+4]) + bits := binary.BigEndian.Uint32(decoded[i : i+4]) f := math.Float32frombits(bits) array = append(array, f) } From 36784580854593855d7ffcf8c16a635f8e6fac60 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 13:37:13 -0400 Subject: [PATCH 09/25] Removed unused method --- .../com/liorkn/elasticsearch/script/VectorScoreScript.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index f2bb77c..53d97dd 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -78,13 +78,6 @@ public void setDocument(int docId) { throw new UncheckedIOException(e); } } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } @SuppressWarnings("unchecked") public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { From 63d84a3ebe89ba012358b06a1db85bce8e549273 Mon Sep 17 00:00:00 2001 From: Lior Knaany Date: Tue, 23 Apr 2019 21:36:12 +0300 Subject: [PATCH 10/25] changed a double usage in a test to float. as part of the move to floats (for performance reasons), this was missed --- src/main/java/com/liorkn/elasticsearch/Util.java | 2 +- src/test/java/com/liorkn/elasticsearch/PluginTest.java | 4 ++-- src/test/java/com/liorkn/elasticsearch/TestObject.java | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/Util.java b/src/main/java/com/liorkn/elasticsearch/Util.java index 045a0d0..53cf1f6 100644 --- a/src/main/java/com/liorkn/elasticsearch/Util.java +++ b/src/main/java/com/liorkn/elasticsearch/Util.java @@ -18,7 +18,7 @@ public static float[] convertBase64ToArray(String base64Str) { return dims; } - public static String convertArrayToBase64(double[] array) { + public static String convertArrayToBase64(float[] array) { final int capacity = Float.BYTES * array.length; final ByteBuffer bb = ByteBuffer.allocate(capacity); for (double v : array) { diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index 417db02..3a33746 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -72,8 +72,8 @@ public static void init() throws Exception { public void test() throws Exception { final Map params = new HashMap<>(); params.put("refresh", "true"); - final TestObject[] objs = {new TestObject(1, new double[] {0.0, 0.5, 1.0}), - new TestObject(2, new double[] {0.2, 0.6, 0.99})}; + final TestObject[] objs = {new TestObject(1, new float[] {0.0f, 0.5f, 1.0f}), + new TestObject(2, new float[] {0.2f, 0.6f, 0.99f})}; for (int i = 0; i < objs.length; i++) { final TestObject t = objs[i]; diff --git a/src/test/java/com/liorkn/elasticsearch/TestObject.java b/src/test/java/com/liorkn/elasticsearch/TestObject.java index f37d98a..a8d0391 100644 --- a/src/test/java/com/liorkn/elasticsearch/TestObject.java +++ b/src/test/java/com/liorkn/elasticsearch/TestObject.java @@ -10,7 +10,7 @@ public class TestObject { int jobId; String embeddingVector; - double[] vector; + float[] vector; public int getJobId() { return jobId; @@ -20,11 +20,11 @@ public String getEmbeddingVector() { return embeddingVector; } - public double[] getVector() { + public float[] getVector() { return vector; } - public TestObject(int jobId, double[] vector) { + public TestObject(int jobId, float[] vector) { this.jobId = jobId; this.vector = vector; this.embeddingVector = Util.convertArrayToBase64(vector); From d9e04a1968a75c3cbd184f6ada419629eb549bda Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Tue, 23 Apr 2019 15:03:16 -0400 Subject: [PATCH 11/25] Fixed vector to be float instead of double --- .../java/com/liorkn/elasticsearch/Util.java | 38 ++++++++++--------- .../script/VectorScoreScript.java | 20 +++++----- .../com/liorkn/elasticsearch/PluginTest.java | 4 +- .../com/liorkn/elasticsearch/TestObject.java | 8 ++-- 4 files changed, 36 insertions(+), 34 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/Util.java b/src/main/java/com/liorkn/elasticsearch/Util.java index ed84f95..53cf1f6 100644 --- a/src/main/java/com/liorkn/elasticsearch/Util.java +++ b/src/main/java/com/liorkn/elasticsearch/Util.java @@ -1,7 +1,7 @@ package com.liorkn.elasticsearch; import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; import java.util.Base64; /** @@ -9,22 +9,24 @@ */ public class Util { - public static double[] convertBase64ToArray(String base64Str) { - final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer(); - final double[] dims = new double[doubleBuffer.capacity()]; - doubleBuffer.get(dims); - return dims; - } + public static float[] convertBase64ToArray(String base64Str) { + final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); + final FloatBuffer floatBuffer = ByteBuffer.wrap(decode).asFloatBuffer(); + final float[] dims = new float[floatBuffer.capacity()]; + floatBuffer.get(dims); - public static String convertArrayToBase64(double[] array) { - final int capacity = Double.BYTES * array.length; - final ByteBuffer bb = ByteBuffer.allocate(capacity); - for (double v : array) { - bb.putDouble((double) v); - } - bb.rewind(); - final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); - return new String(encodedBB.array()); - } + return dims; + } + + public static String convertArrayToBase64(float[] array) { + final int capacity = Float.BYTES * array.length; + final ByteBuffer bb = ByteBuffer.allocate(capacity); + for (double v : array) { + bb.putFloat((float) v); + } + bb.rewind(); + final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); + + return new String(encodedBB.array()); + } } diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index 53d97dd..cfceeb6 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -20,8 +20,8 @@ public final class VectorScoreScript extends SearchScript { private final String field; private final boolean cosine; - private final double[] inputVector; - private final double magnitude; + private final float[] inputVector; + private final float magnitude; @Override public long runAsLong() { @@ -45,9 +45,9 @@ public double runAsDouble() { float score = 0; if (cosine) { - double docVectorNorm = 0.0f; + float docVectorNorm = 0.0f; for (int i = 0; i < inputVector.length; i++) { - double v = Double.longBitsToDouble(input.readLong()); + float v = Float.intBitsToFloat(input.readInt()); docVectorNorm += v * v; // inputVector norm score += v * inputVector[i]; // dot product } @@ -59,7 +59,7 @@ public double runAsDouble() { } } else { for (int i = 0; i < inputVector.length; i++) { - double v = Double.longBitsToDouble(input.readLong()); + float v = Float.intBitsToFloat(input.readInt()); score += v * inputVector[i]; // dot product } @@ -97,9 +97,9 @@ public VectorScoreScript(Map params, SearchLookup lookup, LeafRe final Object vector = params.get("vector"); if(vector != null) { final ArrayList tmp = (ArrayList) vector; - inputVector = new double[tmp.size()]; + inputVector = new float[tmp.size()]; for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i).doubleValue(); + inputVector[i] = tmp.get(i).floatValue(); } } else { final Object encodedVector = params.get("encoded_vector"); @@ -111,12 +111,12 @@ public VectorScoreScript(Map params, SearchLookup lookup, LeafRe if (this.cosine) { // calc magnitude - double queryVectorNorm = 0.0f; + float queryVectorNorm = 0.0f; // compute query inputVector norm once - for (double v: this.inputVector) { + for (float v: this.inputVector) { queryVectorNorm += v * v; } - this.magnitude = (double) Math.sqrt(queryVectorNorm); + this.magnitude = (float) Math.sqrt(queryVectorNorm); } else { this.magnitude = 0.0f; } diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index c134ea6..d856ad8 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -71,8 +71,8 @@ public void test() throws Exception { final Map params = new HashMap<>(); params.put("refresh", "true"); final ObjectMapper mapper = new ObjectMapper(); - final TestObject[] objs = {new TestObject(1, new double[] {0.0, 0.5, 1.0}), - new TestObject(2, new double[] {0.2, 0.6, 0.99})}; + final TestObject[] objs = {new TestObject(1, new float[] {0.0f, 0.5f, 1.0f}), + new TestObject(2, new float[] {0.2f, 0.6f, 0.99f})}; for (int i = 0; i < objs.length; i++) { final TestObject t = objs[i]; diff --git a/src/test/java/com/liorkn/elasticsearch/TestObject.java b/src/test/java/com/liorkn/elasticsearch/TestObject.java index f37d98a..8338e31 100644 --- a/src/test/java/com/liorkn/elasticsearch/TestObject.java +++ b/src/test/java/com/liorkn/elasticsearch/TestObject.java @@ -10,7 +10,7 @@ public class TestObject { int jobId; String embeddingVector; - double[] vector; + float[] vector; public int getJobId() { return jobId; @@ -20,13 +20,13 @@ public String getEmbeddingVector() { return embeddingVector; } - public double[] getVector() { + public float[] getVector() { return vector; } - public TestObject(int jobId, double[] vector) { + public TestObject(int jobId, float[] vector) { this.jobId = jobId; this.vector = vector; this.embeddingVector = Util.convertArrayToBase64(vector); } -} +} \ No newline at end of file From 71fae4e9a6b4b25f9c9f0adf2986d8476a17e085 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Wed, 30 Jan 2019 17:23:41 -0500 Subject: [PATCH 12/25] Changed script to work with Elastic 6.0.0. Inline scripts are now depreciated so using 'source' field instead --- README.md | 4 +- pom.xml | 12 +- .../plugin/VectorScoringPlugin.java | 154 ++++++++++++++- .../script/VectorScoreScript.java | 179 ------------------ .../VectorScoringScriptEngineService.java | 77 -------- .../EmbeddedElasticsearchServer.java | 11 +- .../com/liorkn/elasticsearch/PluginTest.java | 2 +- 7 files changed, 159 insertions(+), 280 deletions(-) delete mode 100755 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java delete mode 100755 src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java diff --git a/README.md b/README.md index 9ad9f9b..5a8f048 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ give it a try. ## Elasticsearch version -* Currently designed for Elasticsearch 5.6.0. +* Currently designed for Elasticsearch 6.0.0. * for Elasticsearch 5.2.2 use branch `es-5.2.2` * for Elasticsearch 2.4.4 use branch `es-2.4.4` @@ -147,7 +147,7 @@ func convertBase64ToArray(base64Str string) ([]float32, error) { "boost_mode": "replace", "script_score": { "script": { - "inline": "binary_vector_score", + "source": "binary_vector_score", "lang": "knn", "params": { "cosine": false, diff --git a/pom.xml b/pom.xml index 7434047..79d7430 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 5.6.0 + 6.0.0 ElasticSearch Plugin for Binary Vector Scoring @@ -27,7 +27,7 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 5.6.0 + 6.0.0 2.4 4.4.8 4.12 @@ -65,7 +65,7 @@ org.elasticsearch.plugin - transport-netty3-client + transport-netty4-client ${elasticsearch.version} test @@ -86,12 +86,6 @@ - - - - - - org.codelibs.elasticsearch.module lang-painless diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 88a3599..b0a666e 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,13 +13,25 @@ */ package com.liorkn.elasticsearch.plugin; -import com.liorkn.elasticsearch.service.VectorScoringScriptEngineService; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.util.Collection; +import java.util.Map; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; -import org.elasticsearch.script.ScriptEngineService; - +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; +import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -27,9 +39,141 @@ */ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { - public final ScriptEngineService getScriptEngineService(Settings settings) { - return new VectorScoringScriptEngineService(settings); + @Override + public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { + return new MyExpertScriptEngine(); } + /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ + private static class MyExpertScriptEngine implements ScriptEngine { + @Override + public String getType() { + return "knn"; + } + + private static final int DOUBLE_SIZE = 8; + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if ("binary_vector_score".equals(scriptSource)) { + SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { + final String field; + final boolean cosine; + { + if (p.containsKey("vector") == false) { + throw new IllegalArgumentException("Missing parameter [vector]"); + } + if (p.containsKey("field") == false) { + throw new IllegalArgumentException("Missing parameter [field]"); + } + if (p.containsKey("cosine") == false) { + throw new IllegalArgumentException("Missing parameter [cosine]"); + } + field = p.get("field").toString(); + cosine = (boolean) p.get("cosine"); + } + + final ArrayList searchVector = (ArrayList) p.get("vector"); + double magnitude; + { + if (cosine) { + // calc magnitude + double queryVectorNorm = 0.0; + // compute query inputVector norm once + for (Double v : this.searchVector) { + queryVectorNorm += v.doubleValue() * v.doubleValue(); + } + magnitude = Math.sqrt(queryVectorNorm); + } else { + magnitude = 0.0; + } + } + + @Override + public SearchScript newInstance(LeafReaderContext context) throws IOException { + return new SearchScript(p, lookup, context) { + BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); + int currentDocid = -1; + + @Override + public void setDocument(int docid) { + // Move to desired document + try { + docAccess.advanceExact(docid); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + currentDocid = docid; + } + + @Override + public double runAsDouble() { + if (currentDocid < 0) { + return 0.0; + } + //actually run scoring + final int size = searchVector.size(); + + try { + final byte[] bytes = docAccess.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size + if (len != size * DOUBLE_SIZE) { + return 0.0; + } + + final int position = input.getPosition(); + final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); + final double[] docVector = new double[size]; + doubleBuffer.get(docVector); + double docVectorNorm = 0.0f; + double score = 0; + for (int i = 0; i < size; i++) { + // doc inputVector norm + if(cosine) { + docVectorNorm += docVector[i]*docVector[i]; + } + // dot product + score += docVector[i] * searchVector.get(i).doubleValue(); + } + if(cosine) { + // cosine similarity score + if (docVectorNorm == 0 || magnitude == 0){ + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + return score; + } + } catch (Exception e) { + return 0; + } + } + }; + } + + @Override + public boolean needs_score() { + return false; + } + }; + return context.factoryClazz.cast(factory); + } + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + @Override + public void close() { + // optionally close resources + } + } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java deleted file mode 100755 index c9e0401..0000000 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ /dev/null @@ -1,179 +0,0 @@ -/* -Based on: https://discuss.elastic.co/t/vector-scoring/85227/4 -and https://github.com/MLnick/elasticsearch-vector-scoring - -another slower implementation using strings: https://github.com/ginobefun/elasticsearch-feature-vector-scoring - -storing arrays is no luck - lucine index doesn't keep the array members orders -https://www.elastic.co/guide/en/elasticsearch/guide/current/complex-core-fields.html - -Delimited Payload Token Filter: https://www.elastic.co/guide/en/elasticsearch/reference/2.4/analysis-delimited-payload-tokenfilter.html - */ - -package com.liorkn.elasticsearch.script; - -import com.liorkn.elasticsearch.Util; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.store.ByteArrayDataInput; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptException; - -import java.util.ArrayList; -import java.util.Map; - - -/** - * Script that scores documents based on cosine similarity embedding vectors. - */ -public final class VectorScoreScript implements LeafSearchScript, ExecutableScript { - - // the field containing the vectors to be scored against - public final String field; - - private int docId; - private BinaryDocValues binaryEmbeddingReader; - - private final float[] inputVector; - private final float magnitude; - - private final boolean cosine; - - @Override - public final Object run() { - return runAsDouble(); - } - - @Override - public long runAsLong() { - return (long) runAsDouble(); - } - - /** - * Called for each document - * @return cosine similarity of the current document against the input inputVector - */ - @Override - public double runAsDouble() { - final byte[] bytes = binaryEmbeddingReader.get(docId).bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - - // MUST appear hear since it affect the next calls - input.readVInt(); // returns the number of values which should be 1 - input.readVInt(); // returns the number of bytes to read - - float score = 0; - - if(cosine) { - float docVectorNorm = 0.0f; - - for (int i = 0; i < inputVector.length; i++) { - float v = Float.intBitsToFloat(input.readInt()); - docVectorNorm += v * v; // inputVector norm - score += v * inputVector[i]; // dot product - } - - if (docVectorNorm == 0 || magnitude == 0) { - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - for (int i = 0; i < inputVector.length; i++) { - float v = Float.intBitsToFloat(input.readInt()); - score += v * inputVector[i]; // dot product - } - - return score; - } - } - - @Override - public void setNextVar(String name, Object value) {} - - @Override - public void setDocument(int docId) { - this.docId = docId; - } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } - - /** - * Factory that is registered in - * {@link VectorScoringPlugin#onModule(org.elasticsearch.script.ScriptModule)} - * method when the plugin is loaded. - */ - public static class Factory { - - /** - * This method is called for every search on every shard. - * - * @param params - * list of script parameters passed with the query - * @return new native script - */ - public ExecutableScript newScript(@Nullable Map params) throws ScriptException { - return new VectorScoreScript(params); - } - - /** - * Indicates if document scores may be needed by the produced scripts. - * - * @return {@code true} if scores are needed. - */ - public boolean needsScores() { - return false; - } - } - - /** - * Init - * @param params index that a scored are placed in this parameter. Initialize them here. - */ - @SuppressWarnings("unchecked") - public VectorScoreScript(Map params) { - final Object cosineBool = params.get("cosine"); - cosine = cosineBool != null ? - (boolean)cosineBool : - true; - - final Object field = params.get("field"); - if (field == null) - throw new IllegalArgumentException("binary_vector_score script requires field input"); - this.field = field.toString(); - - // get query inputVector - convert to primitive - final Object vector = params.get("vector"); - if(vector != null) { - final ArrayList tmp = (ArrayList) vector; - inputVector = new float[tmp.size()]; - for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i).floatValue(); - } - } else { - final Object encodedVector = params.get("encoded_vector"); - if(encodedVector == null) { - throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); - } - inputVector = Util.convertBase64ToArray((String) encodedVector); - } - - if(cosine) { - // calc magnitude - float queryVectorNorm = 0.0f; - // compute query inputVector norm once - for (float v: inputVector) { - queryVectorNorm += v * v; - } - magnitude = (float) Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0f; - } - } -} \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java b/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java deleted file mode 100755 index 269df8a..0000000 --- a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java +++ /dev/null @@ -1,77 +0,0 @@ -package com.liorkn.elasticsearch.service; - -import com.liorkn.elasticsearch.script.VectorScoreScript; -import org.apache.lucene.index.LeafReaderContext; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.component.AbstractComponent; -import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.script.CompiledScript; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptEngineService; -import org.elasticsearch.script.SearchScript; -import org.elasticsearch.search.lookup.SearchLookup; - -import java.io.IOException; -import java.util.Map; - -/** - * Created by Lior Knaany on 5/14/17. - */ -public class VectorScoringScriptEngineService extends AbstractComponent implements ScriptEngineService{ - - public static final String NAME = "knn"; - - @Inject - public VectorScoringScriptEngineService(Settings settings) { - super(settings); - } - - @Override - public Object compile(String scriptName, String scriptSource, Map params) { - return new VectorScoreScript.Factory(); - } - - @Override - public boolean isInlineScriptEnabled() { - return true; - } - - @Override - public String getType() { - return NAME; - } - - @Override - public String getExtension() { - return NAME; - } - - @Override - public ExecutableScript executable(CompiledScript compiledScript, @Nullable Map vars) { - VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - return scriptFactory.newScript(vars); - } - - @Override - public SearchScript search(CompiledScript compiledScript, final SearchLookup lookup, @Nullable final Map vars) { - final VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - final VectorScoreScript script = (VectorScoreScript) scriptFactory.newScript(vars); - return new SearchScript() { - @Override - public LeafSearchScript getLeafSearchScript(LeafReaderContext context) throws IOException { - script.setBinaryEmbeddingReader(context.reader().getBinaryDocValues(script.field)); - return script; - } - @Override - public boolean needsScores() { - return scriptFactory.needsScores(); - } - }; - } - - @Override - public void close() { - } -} diff --git a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java index 5627240..1cbdb9c 100644 --- a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java +++ b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java @@ -10,7 +10,7 @@ import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; import org.elasticsearch.painless.PainlessPlugin; -import org.elasticsearch.transport.Netty3Plugin; +import org.elasticsearch.transport.Netty4Plugin; import java.io.File; import java.io.IOException; @@ -41,13 +41,10 @@ private EmbeddedElasticsearchServer(String defaultDataDirectory, int port) throw Settings.Builder settings = Settings.builder() .put("http.enabled", "true") - .put("transport.type", "local") - .put("http.type", "netty3") + .put("http.type", "netty4") .put("path.data", dataDirectory) .put("path.home", DEFAULT_HOME_DIRECTORY) - .put("script.inline", "on") - .put("node.max_local_storage_nodes", 10000) - .put("script.stored", "on"); + .put("node.max_local_storage_nodes", 10000); startNodeInAvailablePort(settings); } @@ -61,7 +58,7 @@ private void startNodeInAvailablePort(Settings.Builder settings) throws NodeVali settings.put("http.port", String.valueOf(this.port)); // this a hack in order to load Groovy plug in since we want to enable the usage of scripts - node = new NodeExt(settings.build() , Arrays.asList(Netty3Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); + node = new NodeExt(settings.build() , Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); node.start(); success = true; System.out.println(EmbeddedElasticsearchServer.class.getName() + ": Using port: " + this.port); diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index 3a33746..b7d5747 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -93,7 +93,7 @@ public void test() throws Exception { " \"boost_mode\": \"replace\"," + " \"script_score\": {" + " \"script\": {" + - " \"inline\": \"binary_vector_score\"," + + " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + " \"cosine\": true," + From 343ee3155b159902af57d9964727fef37104346a Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 31 Jan 2019 13:59:06 -0500 Subject: [PATCH 13/25] renamed engine appropriately --- .../com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index b0a666e..3f97f65 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -41,11 +41,11 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new MyExpertScriptEngine(); + return new VectorScoringPluginEngine(); } /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class MyExpertScriptEngine implements ScriptEngine { + private static class VectorScoringPluginEngine implements ScriptEngine { @Override public String getType() { return "knn"; From 9fdb85fb81f0955ac0bccbabc08bcad89b7163a1 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:45:15 -0400 Subject: [PATCH 14/25] Broke it back into multiple files for simplicity, added support for encodedVector field, added runAsLong(), changed searchVector to be a primitive type, made SCRIPT_SOURCE private static class member --- .../engine/VectorScoringScriptEngine.java | 36 ++++ .../plugin/VectorScoringPlugin.java | 149 +---------------- .../script/VectorScoreScript.java | 157 ++++++++++++++++++ 3 files changed, 196 insertions(+), 146 deletions(-) create mode 100644 src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java create mode 100644 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java diff --git a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java new file mode 100644 index 0000000..6c00692 --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java @@ -0,0 +1,36 @@ +package com.liorkn.elasticsearch.engine; + +import com.liorkn.elasticsearch.script.VectorScoreScript; + +import java.util.Map; + +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; + +/** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ +public class VectorScoringScriptEngine implements ScriptEngine { + + public static final String NAME = "knn"; + private static final String SCRIPT_SOURCE = "binary_vector_score"; + + @Override + public String getType() { + return NAME; + } + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if (!SCRIPT_SOURCE.equals(scriptSource)) { + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + SearchScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; + return context.factoryClazz.cast(factory); + } +} diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 3f97f65..c279636 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,25 +13,15 @@ */ package com.liorkn.elasticsearch.plugin; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; +import com.liorkn.elasticsearch.engine.VectorScoringScriptEngine; + import java.util.Collection; -import java.util.Map; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.PostingsEnum; -import org.apache.lucene.index.Term; -import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; -import org.elasticsearch.script.SearchScript; -import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -41,139 +31,6 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new VectorScoringPluginEngine(); - } - - /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class VectorScoringPluginEngine implements ScriptEngine { - @Override - public String getType() { - return "knn"; - } - - private static final int DOUBLE_SIZE = 8; - - @Override - public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { - - if (context.equals(SearchScript.CONTEXT) == false) { - throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); - } - - // we use the script "source" as the script identifier - if ("binary_vector_score".equals(scriptSource)) { - SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { - final String field; - final boolean cosine; - { - if (p.containsKey("vector") == false) { - throw new IllegalArgumentException("Missing parameter [vector]"); - } - if (p.containsKey("field") == false) { - throw new IllegalArgumentException("Missing parameter [field]"); - } - if (p.containsKey("cosine") == false) { - throw new IllegalArgumentException("Missing parameter [cosine]"); - } - field = p.get("field").toString(); - cosine = (boolean) p.get("cosine"); - } - - final ArrayList searchVector = (ArrayList) p.get("vector"); - double magnitude; - { - if (cosine) { - // calc magnitude - double queryVectorNorm = 0.0; - // compute query inputVector norm once - for (Double v : this.searchVector) { - queryVectorNorm += v.doubleValue() * v.doubleValue(); - } - magnitude = Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0; - } - } - - @Override - public SearchScript newInstance(LeafReaderContext context) throws IOException { - return new SearchScript(p, lookup, context) { - BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); - int currentDocid = -1; - - @Override - public void setDocument(int docid) { - // Move to desired document - try { - docAccess.advanceExact(docid); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - currentDocid = docid; - } - - @Override - public double runAsDouble() { - if (currentDocid < 0) { - return 0.0; - } - //actually run scoring - final int size = searchVector.size(); - - try { - final byte[] bytes = docAccess.binaryValue().bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size - if (len != size * DOUBLE_SIZE) { - return 0.0; - } - - final int position = input.getPosition(); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); - - final double[] docVector = new double[size]; - doubleBuffer.get(docVector); - double docVectorNorm = 0.0f; - double score = 0; - for (int i = 0; i < size; i++) { - // doc inputVector norm - if(cosine) { - docVectorNorm += docVector[i]*docVector[i]; - } - // dot product - score += docVector[i] * searchVector.get(i).doubleValue(); - } - if(cosine) { - // cosine similarity score - if (docVectorNorm == 0 || magnitude == 0){ - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - return score; - } - } catch (Exception e) { - return 0; - } - } - }; - } - - @Override - public boolean needs_score() { - return false; - } - }; - return context.factoryClazz.cast(factory); - } - throw new IllegalArgumentException("Unknown script name " + scriptSource); - } - - @Override - public void close() { - // optionally close resources - } + return new VectorScoringScriptEngine(); } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java new file mode 100644 index 0000000..b682acf --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -0,0 +1,157 @@ +package com.liorkn.elasticsearch.script; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.ByteArrayDataInput; +import org.elasticsearch.script.SearchScript; +import org.elasticsearch.search.lookup.SearchLookup; + +import com.liorkn.elasticsearch.Util; + +public final class VectorScoreScript extends SearchScript { + + private BinaryDocValues binaryEmbeddingReader; + + private final String field; + private final boolean cosine; + + private final double[] inputVector; + private final double magnitude; + + @Override + public long runAsLong() { + return (long) runAsDouble(); + } + + @Override + public double runAsDouble() { + try { + final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + + final int len = input.readVInt(); + // in case vector is of different size + if (len != inputVector.length * Double.BYTES) { + return 0.0; + } + + float score = 0; + + if (cosine) { + double docVectorNorm = 0.0f; + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + docVectorNorm += v * v; // inputVector norm + score += v * inputVector[i]; // dot product + } + + if (docVectorNorm == 0 || magnitude == 0) { + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + score += v * inputVector[i]; // dot product + } + + return score; + } + } catch (Exception e) { + return 0.0; + } + } + + @Override + public void setDocument(int docId) { + try { + this.binaryEmbeddingReader.advanceExact(docId); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { + if(binaryEmbeddingReader == null) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + this.binaryEmbeddingReader = binaryEmbeddingReader; + } + + @SuppressWarnings("unchecked") + public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, lookup, leafContext); + + final Object cosineBool = params.get("cosine"); + this.cosine = cosineBool != null ? + (boolean)cosineBool : + true; + + final Object field = params.get("field"); + if (field == null) + throw new IllegalArgumentException("binary_vector_score script requires field input"); + this.field = field.toString(); + + // get query inputVector - convert to primitive + final Object vector = params.get("vector"); + if(vector != null) { + final ArrayList tmp = (ArrayList) vector; + inputVector = new double[tmp.size()]; + for (int i = 0; i < inputVector.length; i++) { + inputVector[i] = tmp.get(i).doubleValue(); + } + } else { + final Object encodedVector = params.get("encoded_vector"); + if(encodedVector == null) { + throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); + } + inputVector = Util.convertBase64ToArray((String) encodedVector); + } + + if (this.cosine) { + // calc magnitude + double queryVectorNorm = 0.0f; + // compute query inputVector norm once + for (double v: this.inputVector) { + queryVectorNorm += v * v; + } + this.magnitude = (double) Math.sqrt(queryVectorNorm); + } else { + this.magnitude = 0.0f; + } + + try { + this.binaryEmbeddingReader = leafContext.reader().getBinaryDocValues(this.field); + } catch (IOException e) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + } + + public static class VectorScoreScriptFactory implements LeafFactory { + private final Map params; + private final SearchLookup lookup; + + public VectorScoreScriptFactory(Map params, SearchLookup lookup) { + this.params = params; + this.lookup = lookup; + } + + public boolean needs_score() { + return false; + } + + @Override + public SearchScript newInstance(LeafReaderContext ctx) throws IOException { + return new VectorScoreScript(this.params, this.lookup, ctx); + } + } +} \ No newline at end of file From b97a7c6f297603f79c27b219d011b6db749c46b5 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:48:28 -0400 Subject: [PATCH 15/25] Removed unused import --- .../java/com/liorkn/elasticsearch/script/VectorScoreScript.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index b682acf..f2bb77c 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -3,7 +3,6 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Map; import org.apache.lucene.index.BinaryDocValues; From 4479ef2957da0c82b49de6f7aff40522eb250c46 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 13:37:13 -0400 Subject: [PATCH 16/25] Removed unused method --- .../com/liorkn/elasticsearch/script/VectorScoreScript.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index f2bb77c..53d97dd 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -78,13 +78,6 @@ public void setDocument(int docId) { throw new UncheckedIOException(e); } } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } @SuppressWarnings("unchecked") public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { From 028c8b3717a2437a3fcd2d154479232ffe15381b Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Tue, 23 Apr 2019 15:03:16 -0400 Subject: [PATCH 17/25] Fixed vector to be float instead of double --- .../script/VectorScoreScript.java | 20 +++++++++---------- .../com/liorkn/elasticsearch/PluginTest.java | 9 ++------- .../com/liorkn/elasticsearch/TestObject.java | 2 +- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index 53d97dd..cfceeb6 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -20,8 +20,8 @@ public final class VectorScoreScript extends SearchScript { private final String field; private final boolean cosine; - private final double[] inputVector; - private final double magnitude; + private final float[] inputVector; + private final float magnitude; @Override public long runAsLong() { @@ -45,9 +45,9 @@ public double runAsDouble() { float score = 0; if (cosine) { - double docVectorNorm = 0.0f; + float docVectorNorm = 0.0f; for (int i = 0; i < inputVector.length; i++) { - double v = Double.longBitsToDouble(input.readLong()); + float v = Float.intBitsToFloat(input.readInt()); docVectorNorm += v * v; // inputVector norm score += v * inputVector[i]; // dot product } @@ -59,7 +59,7 @@ public double runAsDouble() { } } else { for (int i = 0; i < inputVector.length; i++) { - double v = Double.longBitsToDouble(input.readLong()); + float v = Float.intBitsToFloat(input.readInt()); score += v * inputVector[i]; // dot product } @@ -97,9 +97,9 @@ public VectorScoreScript(Map params, SearchLookup lookup, LeafRe final Object vector = params.get("vector"); if(vector != null) { final ArrayList tmp = (ArrayList) vector; - inputVector = new double[tmp.size()]; + inputVector = new float[tmp.size()]; for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i).doubleValue(); + inputVector[i] = tmp.get(i).floatValue(); } } else { final Object encodedVector = params.get("encoded_vector"); @@ -111,12 +111,12 @@ public VectorScoreScript(Map params, SearchLookup lookup, LeafRe if (this.cosine) { // calc magnitude - double queryVectorNorm = 0.0f; + float queryVectorNorm = 0.0f; // compute query inputVector norm once - for (double v: this.inputVector) { + for (float v: this.inputVector) { queryVectorNorm += v * v; } - this.magnitude = (double) Math.sqrt(queryVectorNorm); + this.magnitude = (float) Math.sqrt(queryVectorNorm); } else { this.magnitude = 0.0f; } diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index b7d5747..d856ad8 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -1,9 +1,7 @@ package com.liorkn.elasticsearch; import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; import org.apache.http.HttpHost; import org.apache.http.entity.ContentType; import org.apache.http.entity.StringEntity; @@ -72,6 +70,7 @@ public static void init() throws Exception { public void test() throws Exception { final Map params = new HashMap<>(); params.put("refresh", "true"); + final ObjectMapper mapper = new ObjectMapper(); final TestObject[] objs = {new TestObject(1, new float[] {0.0f, 0.5f, 1.0f}), new TestObject(2, new float[] {0.2f, 0.6f, 0.99f})}; @@ -96,7 +95,7 @@ public void test() throws Exception { " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + - " \"cosine\": true," + + " \"cosine\": false," + " \"field\": \"embedding_vector\"," + " \"vector\": [" + " 0.1, 0.2, 0.3" + @@ -114,10 +113,6 @@ public void test() throws Exception { System.out.println(resBody); Assert.assertEquals("search should return status code 200", 200, res.getStatusLine().getStatusCode()); Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":" + objs.length)); - // Testing Scores - final ArrayNode hitsJson = (ArrayNode)mapper.readTree(resBody).get("hits").get("hits"); - Assert.assertEquals(0.9941734, hitsJson.get(0).get("_score").asDouble(), 0); - Assert.assertEquals(0.95618284, hitsJson.get(1).get("_score").asDouble(), 0); } @AfterClass diff --git a/src/test/java/com/liorkn/elasticsearch/TestObject.java b/src/test/java/com/liorkn/elasticsearch/TestObject.java index a8d0391..8338e31 100644 --- a/src/test/java/com/liorkn/elasticsearch/TestObject.java +++ b/src/test/java/com/liorkn/elasticsearch/TestObject.java @@ -29,4 +29,4 @@ public TestObject(int jobId, float[] vector) { this.vector = vector; this.embeddingVector = Util.convertArrayToBase64(vector); } -} +} \ No newline at end of file From aa8be17edbfa4fd2f54fc7471d8253cdca360eb2 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Wed, 30 Jan 2019 17:23:41 -0500 Subject: [PATCH 18/25] Changed script to work with Elastic 6.0.0. Inline scripts are now depreciated so using 'source' field instead --- README.md | 4 +- pom.xml | 12 +- .../plugin/VectorScoringPlugin.java | 154 ++++++++++++++- .../script/VectorScoreScript.java | 179 ------------------ .../VectorScoringScriptEngineService.java | 77 -------- .../EmbeddedElasticsearchServer.java | 11 +- .../com/liorkn/elasticsearch/PluginTest.java | 2 +- 7 files changed, 159 insertions(+), 280 deletions(-) delete mode 100755 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java delete mode 100755 src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java diff --git a/README.md b/README.md index 9ad9f9b..5a8f048 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ give it a try. ## Elasticsearch version -* Currently designed for Elasticsearch 5.6.0. +* Currently designed for Elasticsearch 6.0.0. * for Elasticsearch 5.2.2 use branch `es-5.2.2` * for Elasticsearch 2.4.4 use branch `es-2.4.4` @@ -147,7 +147,7 @@ func convertBase64ToArray(base64Str string) ([]float32, error) { "boost_mode": "replace", "script_score": { "script": { - "inline": "binary_vector_score", + "source": "binary_vector_score", "lang": "knn", "params": { "cosine": false, diff --git a/pom.xml b/pom.xml index 7434047..79d7430 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 5.6.0 + 6.0.0 ElasticSearch Plugin for Binary Vector Scoring @@ -27,7 +27,7 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 5.6.0 + 6.0.0 2.4 4.4.8 4.12 @@ -65,7 +65,7 @@ org.elasticsearch.plugin - transport-netty3-client + transport-netty4-client ${elasticsearch.version} test @@ -86,12 +86,6 @@ - - - - - - org.codelibs.elasticsearch.module lang-painless diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 88a3599..b0a666e 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,13 +13,25 @@ */ package com.liorkn.elasticsearch.plugin; -import com.liorkn.elasticsearch.service.VectorScoringScriptEngineService; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.util.Collection; +import java.util.Map; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; -import org.elasticsearch.script.ScriptEngineService; - +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; +import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -27,9 +39,141 @@ */ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { - public final ScriptEngineService getScriptEngineService(Settings settings) { - return new VectorScoringScriptEngineService(settings); + @Override + public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { + return new MyExpertScriptEngine(); } + /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ + private static class MyExpertScriptEngine implements ScriptEngine { + @Override + public String getType() { + return "knn"; + } + + private static final int DOUBLE_SIZE = 8; + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if ("binary_vector_score".equals(scriptSource)) { + SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { + final String field; + final boolean cosine; + { + if (p.containsKey("vector") == false) { + throw new IllegalArgumentException("Missing parameter [vector]"); + } + if (p.containsKey("field") == false) { + throw new IllegalArgumentException("Missing parameter [field]"); + } + if (p.containsKey("cosine") == false) { + throw new IllegalArgumentException("Missing parameter [cosine]"); + } + field = p.get("field").toString(); + cosine = (boolean) p.get("cosine"); + } + + final ArrayList searchVector = (ArrayList) p.get("vector"); + double magnitude; + { + if (cosine) { + // calc magnitude + double queryVectorNorm = 0.0; + // compute query inputVector norm once + for (Double v : this.searchVector) { + queryVectorNorm += v.doubleValue() * v.doubleValue(); + } + magnitude = Math.sqrt(queryVectorNorm); + } else { + magnitude = 0.0; + } + } + + @Override + public SearchScript newInstance(LeafReaderContext context) throws IOException { + return new SearchScript(p, lookup, context) { + BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); + int currentDocid = -1; + + @Override + public void setDocument(int docid) { + // Move to desired document + try { + docAccess.advanceExact(docid); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + currentDocid = docid; + } + + @Override + public double runAsDouble() { + if (currentDocid < 0) { + return 0.0; + } + //actually run scoring + final int size = searchVector.size(); + + try { + final byte[] bytes = docAccess.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size + if (len != size * DOUBLE_SIZE) { + return 0.0; + } + + final int position = input.getPosition(); + final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); + final double[] docVector = new double[size]; + doubleBuffer.get(docVector); + double docVectorNorm = 0.0f; + double score = 0; + for (int i = 0; i < size; i++) { + // doc inputVector norm + if(cosine) { + docVectorNorm += docVector[i]*docVector[i]; + } + // dot product + score += docVector[i] * searchVector.get(i).doubleValue(); + } + if(cosine) { + // cosine similarity score + if (docVectorNorm == 0 || magnitude == 0){ + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + return score; + } + } catch (Exception e) { + return 0; + } + } + }; + } + + @Override + public boolean needs_score() { + return false; + } + }; + return context.factoryClazz.cast(factory); + } + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + @Override + public void close() { + // optionally close resources + } + } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java deleted file mode 100755 index c9e0401..0000000 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ /dev/null @@ -1,179 +0,0 @@ -/* -Based on: https://discuss.elastic.co/t/vector-scoring/85227/4 -and https://github.com/MLnick/elasticsearch-vector-scoring - -another slower implementation using strings: https://github.com/ginobefun/elasticsearch-feature-vector-scoring - -storing arrays is no luck - lucine index doesn't keep the array members orders -https://www.elastic.co/guide/en/elasticsearch/guide/current/complex-core-fields.html - -Delimited Payload Token Filter: https://www.elastic.co/guide/en/elasticsearch/reference/2.4/analysis-delimited-payload-tokenfilter.html - */ - -package com.liorkn.elasticsearch.script; - -import com.liorkn.elasticsearch.Util; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.store.ByteArrayDataInput; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptException; - -import java.util.ArrayList; -import java.util.Map; - - -/** - * Script that scores documents based on cosine similarity embedding vectors. - */ -public final class VectorScoreScript implements LeafSearchScript, ExecutableScript { - - // the field containing the vectors to be scored against - public final String field; - - private int docId; - private BinaryDocValues binaryEmbeddingReader; - - private final float[] inputVector; - private final float magnitude; - - private final boolean cosine; - - @Override - public final Object run() { - return runAsDouble(); - } - - @Override - public long runAsLong() { - return (long) runAsDouble(); - } - - /** - * Called for each document - * @return cosine similarity of the current document against the input inputVector - */ - @Override - public double runAsDouble() { - final byte[] bytes = binaryEmbeddingReader.get(docId).bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - - // MUST appear hear since it affect the next calls - input.readVInt(); // returns the number of values which should be 1 - input.readVInt(); // returns the number of bytes to read - - float score = 0; - - if(cosine) { - float docVectorNorm = 0.0f; - - for (int i = 0; i < inputVector.length; i++) { - float v = Float.intBitsToFloat(input.readInt()); - docVectorNorm += v * v; // inputVector norm - score += v * inputVector[i]; // dot product - } - - if (docVectorNorm == 0 || magnitude == 0) { - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - for (int i = 0; i < inputVector.length; i++) { - float v = Float.intBitsToFloat(input.readInt()); - score += v * inputVector[i]; // dot product - } - - return score; - } - } - - @Override - public void setNextVar(String name, Object value) {} - - @Override - public void setDocument(int docId) { - this.docId = docId; - } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } - - /** - * Factory that is registered in - * {@link VectorScoringPlugin#onModule(org.elasticsearch.script.ScriptModule)} - * method when the plugin is loaded. - */ - public static class Factory { - - /** - * This method is called for every search on every shard. - * - * @param params - * list of script parameters passed with the query - * @return new native script - */ - public ExecutableScript newScript(@Nullable Map params) throws ScriptException { - return new VectorScoreScript(params); - } - - /** - * Indicates if document scores may be needed by the produced scripts. - * - * @return {@code true} if scores are needed. - */ - public boolean needsScores() { - return false; - } - } - - /** - * Init - * @param params index that a scored are placed in this parameter. Initialize them here. - */ - @SuppressWarnings("unchecked") - public VectorScoreScript(Map params) { - final Object cosineBool = params.get("cosine"); - cosine = cosineBool != null ? - (boolean)cosineBool : - true; - - final Object field = params.get("field"); - if (field == null) - throw new IllegalArgumentException("binary_vector_score script requires field input"); - this.field = field.toString(); - - // get query inputVector - convert to primitive - final Object vector = params.get("vector"); - if(vector != null) { - final ArrayList tmp = (ArrayList) vector; - inputVector = new float[tmp.size()]; - for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i).floatValue(); - } - } else { - final Object encodedVector = params.get("encoded_vector"); - if(encodedVector == null) { - throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); - } - inputVector = Util.convertBase64ToArray((String) encodedVector); - } - - if(cosine) { - // calc magnitude - float queryVectorNorm = 0.0f; - // compute query inputVector norm once - for (float v: inputVector) { - queryVectorNorm += v * v; - } - magnitude = (float) Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0f; - } - } -} \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java b/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java deleted file mode 100755 index 269df8a..0000000 --- a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java +++ /dev/null @@ -1,77 +0,0 @@ -package com.liorkn.elasticsearch.service; - -import com.liorkn.elasticsearch.script.VectorScoreScript; -import org.apache.lucene.index.LeafReaderContext; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.component.AbstractComponent; -import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.script.CompiledScript; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptEngineService; -import org.elasticsearch.script.SearchScript; -import org.elasticsearch.search.lookup.SearchLookup; - -import java.io.IOException; -import java.util.Map; - -/** - * Created by Lior Knaany on 5/14/17. - */ -public class VectorScoringScriptEngineService extends AbstractComponent implements ScriptEngineService{ - - public static final String NAME = "knn"; - - @Inject - public VectorScoringScriptEngineService(Settings settings) { - super(settings); - } - - @Override - public Object compile(String scriptName, String scriptSource, Map params) { - return new VectorScoreScript.Factory(); - } - - @Override - public boolean isInlineScriptEnabled() { - return true; - } - - @Override - public String getType() { - return NAME; - } - - @Override - public String getExtension() { - return NAME; - } - - @Override - public ExecutableScript executable(CompiledScript compiledScript, @Nullable Map vars) { - VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - return scriptFactory.newScript(vars); - } - - @Override - public SearchScript search(CompiledScript compiledScript, final SearchLookup lookup, @Nullable final Map vars) { - final VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - final VectorScoreScript script = (VectorScoreScript) scriptFactory.newScript(vars); - return new SearchScript() { - @Override - public LeafSearchScript getLeafSearchScript(LeafReaderContext context) throws IOException { - script.setBinaryEmbeddingReader(context.reader().getBinaryDocValues(script.field)); - return script; - } - @Override - public boolean needsScores() { - return scriptFactory.needsScores(); - } - }; - } - - @Override - public void close() { - } -} diff --git a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java index 5627240..1cbdb9c 100644 --- a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java +++ b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java @@ -10,7 +10,7 @@ import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; import org.elasticsearch.painless.PainlessPlugin; -import org.elasticsearch.transport.Netty3Plugin; +import org.elasticsearch.transport.Netty4Plugin; import java.io.File; import java.io.IOException; @@ -41,13 +41,10 @@ private EmbeddedElasticsearchServer(String defaultDataDirectory, int port) throw Settings.Builder settings = Settings.builder() .put("http.enabled", "true") - .put("transport.type", "local") - .put("http.type", "netty3") + .put("http.type", "netty4") .put("path.data", dataDirectory) .put("path.home", DEFAULT_HOME_DIRECTORY) - .put("script.inline", "on") - .put("node.max_local_storage_nodes", 10000) - .put("script.stored", "on"); + .put("node.max_local_storage_nodes", 10000); startNodeInAvailablePort(settings); } @@ -61,7 +58,7 @@ private void startNodeInAvailablePort(Settings.Builder settings) throws NodeVali settings.put("http.port", String.valueOf(this.port)); // this a hack in order to load Groovy plug in since we want to enable the usage of scripts - node = new NodeExt(settings.build() , Arrays.asList(Netty3Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); + node = new NodeExt(settings.build() , Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); node.start(); success = true; System.out.println(EmbeddedElasticsearchServer.class.getName() + ": Using port: " + this.port); diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index 3a33746..b7d5747 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -93,7 +93,7 @@ public void test() throws Exception { " \"boost_mode\": \"replace\"," + " \"script_score\": {" + " \"script\": {" + - " \"inline\": \"binary_vector_score\"," + + " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + " \"cosine\": true," + From 51032b5fab67d50deb8b4a50b26c3ad0a179be11 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 31 Jan 2019 13:59:06 -0500 Subject: [PATCH 19/25] renamed engine appropriately --- .../com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index b0a666e..3f97f65 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -41,11 +41,11 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new MyExpertScriptEngine(); + return new VectorScoringPluginEngine(); } /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class MyExpertScriptEngine implements ScriptEngine { + private static class VectorScoringPluginEngine implements ScriptEngine { @Override public String getType() { return "knn"; From a0365c412798b3483764836cdc63362995facef3 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:45:15 -0400 Subject: [PATCH 20/25] Broke it back into multiple files for simplicity, added support for encodedVector field, added runAsLong(), changed searchVector to be a primitive type, made SCRIPT_SOURCE private static class member --- .../engine/VectorScoringScriptEngine.java | 36 ++++ .../plugin/VectorScoringPlugin.java | 149 +---------------- .../script/VectorScoreScript.java | 157 ++++++++++++++++++ 3 files changed, 196 insertions(+), 146 deletions(-) create mode 100644 src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java create mode 100644 src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java diff --git a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java new file mode 100644 index 0000000..6c00692 --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java @@ -0,0 +1,36 @@ +package com.liorkn.elasticsearch.engine; + +import com.liorkn.elasticsearch.script.VectorScoreScript; + +import java.util.Map; + +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; + +/** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ +public class VectorScoringScriptEngine implements ScriptEngine { + + public static final String NAME = "knn"; + private static final String SCRIPT_SOURCE = "binary_vector_score"; + + @Override + public String getType() { + return NAME; + } + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if (!SCRIPT_SOURCE.equals(scriptSource)) { + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + SearchScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; + return context.factoryClazz.cast(factory); + } +} diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 3f97f65..c279636 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,25 +13,15 @@ */ package com.liorkn.elasticsearch.plugin; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; +import com.liorkn.elasticsearch.engine.VectorScoringScriptEngine; + import java.util.Collection; -import java.util.Map; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.PostingsEnum; -import org.apache.lucene.index.Term; -import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; -import org.elasticsearch.script.SearchScript; -import java.util.ArrayList; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. If you change the name of this plugin, make sure to update @@ -41,139 +31,6 @@ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new VectorScoringPluginEngine(); - } - - /** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. */ - private static class VectorScoringPluginEngine implements ScriptEngine { - @Override - public String getType() { - return "knn"; - } - - private static final int DOUBLE_SIZE = 8; - - @Override - public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { - - if (context.equals(SearchScript.CONTEXT) == false) { - throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); - } - - // we use the script "source" as the script identifier - if ("binary_vector_score".equals(scriptSource)) { - SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { - final String field; - final boolean cosine; - { - if (p.containsKey("vector") == false) { - throw new IllegalArgumentException("Missing parameter [vector]"); - } - if (p.containsKey("field") == false) { - throw new IllegalArgumentException("Missing parameter [field]"); - } - if (p.containsKey("cosine") == false) { - throw new IllegalArgumentException("Missing parameter [cosine]"); - } - field = p.get("field").toString(); - cosine = (boolean) p.get("cosine"); - } - - final ArrayList searchVector = (ArrayList) p.get("vector"); - double magnitude; - { - if (cosine) { - // calc magnitude - double queryVectorNorm = 0.0; - // compute query inputVector norm once - for (Double v : this.searchVector) { - queryVectorNorm += v.doubleValue() * v.doubleValue(); - } - magnitude = Math.sqrt(queryVectorNorm); - } else { - magnitude = 0.0; - } - } - - @Override - public SearchScript newInstance(LeafReaderContext context) throws IOException { - return new SearchScript(p, lookup, context) { - BinaryDocValues docAccess = context.reader().getBinaryDocValues(field); - int currentDocid = -1; - - @Override - public void setDocument(int docid) { - // Move to desired document - try { - docAccess.advanceExact(docid); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - currentDocid = docid; - } - - @Override - public double runAsDouble() { - if (currentDocid < 0) { - return 0.0; - } - //actually run scoring - final int size = searchVector.size(); - - try { - final byte[] bytes = docAccess.binaryValue().bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int len = input.readVInt(); // returns the number of bytes to read//if submitted vector is different size - if (len != size * DOUBLE_SIZE) { - return 0.0; - } - - final int position = input.getPosition(); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); - - final double[] docVector = new double[size]; - doubleBuffer.get(docVector); - double docVectorNorm = 0.0f; - double score = 0; - for (int i = 0; i < size; i++) { - // doc inputVector norm - if(cosine) { - docVectorNorm += docVector[i]*docVector[i]; - } - // dot product - score += docVector[i] * searchVector.get(i).doubleValue(); - } - if(cosine) { - // cosine similarity score - if (docVectorNorm == 0 || magnitude == 0){ - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - return score; - } - } catch (Exception e) { - return 0; - } - } - }; - } - - @Override - public boolean needs_score() { - return false; - } - }; - return context.factoryClazz.cast(factory); - } - throw new IllegalArgumentException("Unknown script name " + scriptSource); - } - - @Override - public void close() { - // optionally close resources - } + return new VectorScoringScriptEngine(); } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java new file mode 100644 index 0000000..b682acf --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -0,0 +1,157 @@ +package com.liorkn.elasticsearch.script; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.ByteArrayDataInput; +import org.elasticsearch.script.SearchScript; +import org.elasticsearch.search.lookup.SearchLookup; + +import com.liorkn.elasticsearch.Util; + +public final class VectorScoreScript extends SearchScript { + + private BinaryDocValues binaryEmbeddingReader; + + private final String field; + private final boolean cosine; + + private final double[] inputVector; + private final double magnitude; + + @Override + public long runAsLong() { + return (long) runAsDouble(); + } + + @Override + public double runAsDouble() { + try { + final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + + input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls + + final int len = input.readVInt(); + // in case vector is of different size + if (len != inputVector.length * Double.BYTES) { + return 0.0; + } + + float score = 0; + + if (cosine) { + double docVectorNorm = 0.0f; + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + docVectorNorm += v * v; // inputVector norm + score += v * inputVector[i]; // dot product + } + + if (docVectorNorm == 0 || magnitude == 0) { + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + for (int i = 0; i < inputVector.length; i++) { + double v = Double.longBitsToDouble(input.readLong()); + score += v * inputVector[i]; // dot product + } + + return score; + } + } catch (Exception e) { + return 0.0; + } + } + + @Override + public void setDocument(int docId) { + try { + this.binaryEmbeddingReader.advanceExact(docId); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { + if(binaryEmbeddingReader == null) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + this.binaryEmbeddingReader = binaryEmbeddingReader; + } + + @SuppressWarnings("unchecked") + public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, lookup, leafContext); + + final Object cosineBool = params.get("cosine"); + this.cosine = cosineBool != null ? + (boolean)cosineBool : + true; + + final Object field = params.get("field"); + if (field == null) + throw new IllegalArgumentException("binary_vector_score script requires field input"); + this.field = field.toString(); + + // get query inputVector - convert to primitive + final Object vector = params.get("vector"); + if(vector != null) { + final ArrayList tmp = (ArrayList) vector; + inputVector = new double[tmp.size()]; + for (int i = 0; i < inputVector.length; i++) { + inputVector[i] = tmp.get(i).doubleValue(); + } + } else { + final Object encodedVector = params.get("encoded_vector"); + if(encodedVector == null) { + throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); + } + inputVector = Util.convertBase64ToArray((String) encodedVector); + } + + if (this.cosine) { + // calc magnitude + double queryVectorNorm = 0.0f; + // compute query inputVector norm once + for (double v: this.inputVector) { + queryVectorNorm += v * v; + } + this.magnitude = (double) Math.sqrt(queryVectorNorm); + } else { + this.magnitude = 0.0f; + } + + try { + this.binaryEmbeddingReader = leafContext.reader().getBinaryDocValues(this.field); + } catch (IOException e) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + } + + public static class VectorScoreScriptFactory implements LeafFactory { + private final Map params; + private final SearchLookup lookup; + + public VectorScoreScriptFactory(Map params, SearchLookup lookup) { + this.params = params; + this.lookup = lookup; + } + + public boolean needs_score() { + return false; + } + + @Override + public SearchScript newInstance(LeafReaderContext ctx) throws IOException { + return new VectorScoreScript(this.params, this.lookup, ctx); + } + } +} \ No newline at end of file From ac7efaea24ec28cfc5749a8a9cf6a423810d0790 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 09:48:28 -0400 Subject: [PATCH 21/25] Removed unused import --- .../java/com/liorkn/elasticsearch/script/VectorScoreScript.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index b682acf..f2bb77c 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -3,7 +3,6 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Map; import org.apache.lucene.index.BinaryDocValues; From 79a8753ba1ace5b9173ce4797377a05761097821 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 4 Apr 2019 13:37:13 -0400 Subject: [PATCH 22/25] Removed unused method --- .../com/liorkn/elasticsearch/script/VectorScoreScript.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index f2bb77c..53d97dd 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -78,13 +78,6 @@ public void setDocument(int docId) { throw new UncheckedIOException(e); } } - - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } @SuppressWarnings("unchecked") public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { From 3489091a337633a1cd65cfa40d66e83b71b7a9a7 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Tue, 23 Apr 2019 15:03:16 -0400 Subject: [PATCH 23/25] Fixed vector to be float instead of double --- .../script/VectorScoreScript.java | 20 +++++++++---------- .../com/liorkn/elasticsearch/PluginTest.java | 1 + .../com/liorkn/elasticsearch/TestObject.java | 2 +- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index 53d97dd..cfceeb6 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -20,8 +20,8 @@ public final class VectorScoreScript extends SearchScript { private final String field; private final boolean cosine; - private final double[] inputVector; - private final double magnitude; + private final float[] inputVector; + private final float magnitude; @Override public long runAsLong() { @@ -45,9 +45,9 @@ public double runAsDouble() { float score = 0; if (cosine) { - double docVectorNorm = 0.0f; + float docVectorNorm = 0.0f; for (int i = 0; i < inputVector.length; i++) { - double v = Double.longBitsToDouble(input.readLong()); + float v = Float.intBitsToFloat(input.readInt()); docVectorNorm += v * v; // inputVector norm score += v * inputVector[i]; // dot product } @@ -59,7 +59,7 @@ public double runAsDouble() { } } else { for (int i = 0; i < inputVector.length; i++) { - double v = Double.longBitsToDouble(input.readLong()); + float v = Float.intBitsToFloat(input.readInt()); score += v * inputVector[i]; // dot product } @@ -97,9 +97,9 @@ public VectorScoreScript(Map params, SearchLookup lookup, LeafRe final Object vector = params.get("vector"); if(vector != null) { final ArrayList tmp = (ArrayList) vector; - inputVector = new double[tmp.size()]; + inputVector = new float[tmp.size()]; for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i).doubleValue(); + inputVector[i] = tmp.get(i).floatValue(); } } else { final Object encodedVector = params.get("encoded_vector"); @@ -111,12 +111,12 @@ public VectorScoreScript(Map params, SearchLookup lookup, LeafRe if (this.cosine) { // calc magnitude - double queryVectorNorm = 0.0f; + float queryVectorNorm = 0.0f; // compute query inputVector norm once - for (double v: this.inputVector) { + for (float v: this.inputVector) { queryVectorNorm += v * v; } - this.magnitude = (double) Math.sqrt(queryVectorNorm); + this.magnitude = (float) Math.sqrt(queryVectorNorm); } else { this.magnitude = 0.0f; } diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index b7d5747..4380eff 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -72,6 +72,7 @@ public static void init() throws Exception { public void test() throws Exception { final Map params = new HashMap<>(); params.put("refresh", "true"); + final ObjectMapper mapper = new ObjectMapper(); final TestObject[] objs = {new TestObject(1, new float[] {0.0f, 0.5f, 1.0f}), new TestObject(2, new float[] {0.2f, 0.6f, 0.99f})}; diff --git a/src/test/java/com/liorkn/elasticsearch/TestObject.java b/src/test/java/com/liorkn/elasticsearch/TestObject.java index a8d0391..8338e31 100644 --- a/src/test/java/com/liorkn/elasticsearch/TestObject.java +++ b/src/test/java/com/liorkn/elasticsearch/TestObject.java @@ -29,4 +29,4 @@ public TestObject(int jobId, float[] vector) { this.vector = vector; this.embeddingVector = Util.convertArrayToBase64(vector); } -} +} \ No newline at end of file From c5b2843a830f93284ff131059a085b997ccb8d6e Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Tue, 23 Apr 2019 15:03:16 -0400 Subject: [PATCH 24/25] Fixed vector to be float instead of double --- src/test/java/com/liorkn/elasticsearch/PluginTest.java | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index 4380eff..d856ad8 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -1,9 +1,7 @@ package com.liorkn.elasticsearch; import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; import org.apache.http.HttpHost; import org.apache.http.entity.ContentType; import org.apache.http.entity.StringEntity; @@ -97,7 +95,7 @@ public void test() throws Exception { " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + - " \"cosine\": true," + + " \"cosine\": false," + " \"field\": \"embedding_vector\"," + " \"vector\": [" + " 0.1, 0.2, 0.3" + @@ -115,10 +113,6 @@ public void test() throws Exception { System.out.println(resBody); Assert.assertEquals("search should return status code 200", 200, res.getStatusLine().getStatusCode()); Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":" + objs.length)); - // Testing Scores - final ArrayNode hitsJson = (ArrayNode)mapper.readTree(resBody).get("hits").get("hits"); - Assert.assertEquals(0.9941734, hitsJson.get(0).get("_score").asDouble(), 0); - Assert.assertEquals(0.95618284, hitsJson.get(1).get("_score").asDouble(), 0); } @AfterClass From 1f1d5cdca3dc6d78c7f79c91320c270044f1e297 Mon Sep 17 00:00:00 2001 From: Iakov Gorelik Date: Thu, 25 Apr 2019 14:28:40 -0400 Subject: [PATCH 25/25] Fixed testing and checking of sizes --- .../liorkn/elasticsearch/script/VectorScoreScript.java | 2 +- src/test/java/com/liorkn/elasticsearch/PluginTest.java | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java index cfceeb6..be06fd4 100644 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -38,7 +38,7 @@ public double runAsDouble() { final int len = input.readVInt(); // in case vector is of different size - if (len != inputVector.length * Double.BYTES) { + if (len != inputVector.length * Float.BYTES) { return 0.0; } diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index d856ad8..8427932 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -2,6 +2,8 @@ import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; + import org.apache.http.HttpHost; import org.apache.http.entity.ContentType; import org.apache.http.entity.StringEntity; @@ -70,7 +72,6 @@ public static void init() throws Exception { public void test() throws Exception { final Map params = new HashMap<>(); params.put("refresh", "true"); - final ObjectMapper mapper = new ObjectMapper(); final TestObject[] objs = {new TestObject(1, new float[] {0.0f, 0.5f, 1.0f}), new TestObject(2, new float[] {0.2f, 0.6f, 0.99f})}; @@ -95,7 +96,7 @@ public void test() throws Exception { " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + - " \"cosine\": false," + + " \"cosine\": true," + " \"field\": \"embedding_vector\"," + " \"vector\": [" + " 0.1, 0.2, 0.3" + @@ -113,6 +114,10 @@ public void test() throws Exception { System.out.println(resBody); Assert.assertEquals("search should return status code 200", 200, res.getStatusLine().getStatusCode()); Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":" + objs.length)); + // Testing Scores + final ArrayNode hitsJson = (ArrayNode)mapper.readTree(resBody).get("hits").get("hits"); + Assert.assertEquals(0.9941734, hitsJson.get(0).get("_score").asDouble(), 0); + Assert.assertEquals(0.95618284, hitsJson.get(1).get("_score").asDouble(), 0); } @AfterClass