Skip to content

Commit 82e1faa

Browse files
author
Kaival Parikh
committed
Implement off-heap quantized scoring
1 parent 77f0d1f commit 82e1faa

File tree

9 files changed

+509
-40
lines changed

9 files changed

+509
-40
lines changed

lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,12 @@ private FlatVectorScorerUtil() {}
3737
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
3838
return IMPL.getLucene99FlatVectorsScorer();
3939
}
40+
41+
/**
42+
* Returns a FlatVectorsScorer that supports the quantized Lucene99 format. Scorers retrieved
43+
* through this method may be optimized on certain platforms.
44+
*/
45+
public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
46+
return IMPL.getLucene99ScalarQuantizedVectorsScorer();
47+
}
4048
}

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
package org.apache.lucene.codecs.lucene99;
1919

2020
import java.io.IOException;
21-
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
2221
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
2322
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
2423
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
24+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
2626
import org.apache.lucene.index.SegmentReadState;
2727
import org.apache.lucene.index.SegmentWriteState;
@@ -70,7 +70,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
7070

7171
final byte bits;
7272
final boolean compress;
73-
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
73+
final FlatVectorsScorer flatVectorScorer;
7474

7575
/** Constructs a format using default graph construction parameters */
7676
public Lucene99ScalarQuantizedVectorsFormat() {
@@ -117,8 +117,7 @@ public Lucene99ScalarQuantizedVectorsFormat(
117117
this.bits = (byte) bits;
118118
this.confidenceInterval = confidenceInterval;
119119
this.compress = compress;
120-
this.flatVectorScorer =
121-
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
120+
this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer();
122121
}
123122

124123
public static float calculateDefaultConfidenceInterval(int vectorDimension) {

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
2121
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
22+
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
2223
import org.apache.lucene.store.IndexInput;
2324

2425
/** Default provider returning scalar implementations. */
@@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() {
4041
return DefaultFlatVectorScorer.INSTANCE;
4142
}
4243

44+
@Override
45+
public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
46+
return new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
47+
}
48+
4349
@Override
4450
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) {
4551
return new PostingDecodingUtil(input);

lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ public static VectorizationProvider getInstance() {
122122
/** Returns a FlatVectorsScorer that supports the Lucene99 format. */
123123
public abstract FlatVectorsScorer getLucene99FlatVectorsScorer();
124124

125+
/** Returns a FlatVectorsScorer that supports the quantized Lucene99 format. */
126+
public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer();
127+
125128
/** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */
126129
public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException;
127130

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.internal.vectorization;
19+
20+
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
21+
import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
22+
import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery;
23+
import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.factory;
24+
import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.getSegment;
25+
26+
import java.io.IOException;
27+
import java.lang.foreign.Arena;
28+
import java.lang.foreign.MemorySegment;
29+
import org.apache.lucene.index.VectorSimilarityFunction;
30+
import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.FloatToFloatFunction;
31+
import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.MemorySegmentScorer;
32+
import org.apache.lucene.store.MemorySegmentAccessInput;
33+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
34+
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
35+
import org.apache.lucene.util.quantization.ScalarQuantizer;
36+
37+
class Lucene99MemorySegmentScalarQuantizedScorer
38+
extends RandomVectorScorer.AbstractRandomVectorScorer {
39+
40+
private final VectorSimilarityFunction function;
41+
private final QuantizedByteVectorValues values;
42+
private final MemorySegmentAccessInput input;
43+
private final MemorySegmentScorer scorer;
44+
private final FloatToFloatFunction scaler;
45+
private final float constMultiplier;
46+
private final int vectorByteSize;
47+
private final int entrySize;
48+
private final MemorySegment query;
49+
private final float queryOffset;
50+
private final byte[][] docScratch;
51+
52+
public Lucene99MemorySegmentScalarQuantizedScorer(
53+
VectorSimilarityFunction function,
54+
QuantizedByteVectorValues values,
55+
MemorySegmentAccessInput input,
56+
float[] target) {
57+
58+
super(values);
59+
this.function = function;
60+
this.values = values;
61+
this.input = input;
62+
this.scorer = factory(function, values, false);
63+
this.scaler = factory(function);
64+
65+
ScalarQuantizer quantizer = values.getScalarQuantizer();
66+
this.constMultiplier = quantizer.getConstantMultiplier();
67+
this.vectorByteSize = values.getVectorByteLength();
68+
this.entrySize = vectorByteSize + Float.BYTES;
69+
70+
byte[] targetBytes = new byte[target.length];
71+
this.queryOffset = quantizeQuery(target, targetBytes, function, quantizer);
72+
this.query = Arena.ofAuto().allocateFrom(JAVA_BYTE, targetBytes);
73+
74+
this.docScratch = new byte[1][];
75+
}
76+
77+
@Override
78+
public float score(int node) throws IOException {
79+
MemorySegment segment = getSegment(input, entrySize, node, docScratch);
80+
MemorySegment doc = segment.reinterpret(vectorByteSize);
81+
float docOffset = segment.get(JAVA_FLOAT, vectorByteSize);
82+
return scaler.scale(scorer.score(query, doc) * constMultiplier + queryOffset + docOffset);
83+
}
84+
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.internal.vectorization;
19+
20+
import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
21+
import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.factory;
22+
import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.getSegment;
23+
24+
import java.io.IOException;
25+
import java.lang.foreign.MemorySegment;
26+
import org.apache.lucene.index.VectorSimilarityFunction;
27+
import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.FloatToFloatFunction;
28+
import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.MemorySegmentScorer;
29+
import org.apache.lucene.store.MemorySegmentAccessInput;
30+
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
31+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
32+
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
33+
34+
class Lucene99MemorySegmentScalarQuantizedScorerSupplier implements RandomVectorScorerSupplier {
35+
36+
private final VectorSimilarityFunction function;
37+
private final QuantizedByteVectorValues values;
38+
private final MemorySegmentAccessInput input;
39+
private final MemorySegmentScorer scorer;
40+
private final FloatToFloatFunction scaler;
41+
private final float constMultiplier;
42+
private final int vectorByteSize;
43+
private final int entrySize;
44+
45+
public Lucene99MemorySegmentScalarQuantizedScorerSupplier(
46+
VectorSimilarityFunction function,
47+
QuantizedByteVectorValues values,
48+
MemorySegmentAccessInput input) {
49+
50+
this.function = function;
51+
this.values = values;
52+
this.input = input;
53+
this.scorer = factory(function, values, true);
54+
this.scaler = factory(function);
55+
this.constMultiplier = values.getScalarQuantizer().getConstantMultiplier();
56+
this.vectorByteSize = values.getVectorByteLength();
57+
this.entrySize = vectorByteSize + Float.BYTES;
58+
}
59+
60+
@Override
61+
public UpdateableRandomVectorScorer scorer() {
62+
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) {
63+
64+
private final MemorySegment[] doc = new MemorySegment[1];
65+
private final float[] docOffset = new float[1];
66+
private final byte[][] docScratch = new byte[1][];
67+
private final byte[][] queryScratch = new byte[1][];
68+
69+
@Override
70+
public void setScoringOrdinal(int node) throws IOException {
71+
MemorySegment segment = getSegment(input, entrySize, node, docScratch);
72+
doc[0] = segment.reinterpret(vectorByteSize);
73+
docOffset[0] = segment.get(JAVA_FLOAT, vectorByteSize);
74+
}
75+
76+
@Override
77+
public float score(int node) throws IOException {
78+
MemorySegment segment = getSegment(input, entrySize, node, queryScratch);
79+
MemorySegment query = segment.reinterpret(vectorByteSize);
80+
float queryOffset = segment.get(JAVA_FLOAT, vectorByteSize);
81+
return scaler.scale(
82+
scorer.score(query, doc[0]) * constMultiplier + queryOffset + docOffset[0]);
83+
}
84+
};
85+
}
86+
87+
@Override
88+
public RandomVectorScorerSupplier copy() throws IOException {
89+
return new Lucene99MemorySegmentScalarQuantizedScorerSupplier(function, values, input);
90+
}
91+
}
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.internal.vectorization;
19+
20+
import java.io.IOException;
21+
import java.lang.foreign.MemorySegment;
22+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
23+
import org.apache.lucene.index.KnnVectorValues;
24+
import org.apache.lucene.index.VectorSimilarityFunction;
25+
import org.apache.lucene.store.MemorySegmentAccessInput;
26+
import org.apache.lucene.util.VectorUtil;
27+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
28+
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
29+
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
30+
31+
public class Lucene99MemorySegmentScalarQuantizedVectorScorer implements FlatVectorsScorer {
32+
33+
public static final Lucene99MemorySegmentScalarQuantizedVectorScorer INSTANCE =
34+
new Lucene99MemorySegmentScalarQuantizedVectorScorer();
35+
36+
private static final FlatVectorsScorer NON_QUANTIZED_DELEGATE =
37+
Lucene99MemorySegmentFlatVectorsScorer.INSTANCE;
38+
39+
@Override
40+
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
41+
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
42+
throws IOException {
43+
if (vectorValues instanceof QuantizedByteVectorValues values
44+
&& values.getSlice() instanceof MemorySegmentAccessInput input) {
45+
return new Lucene99MemorySegmentScalarQuantizedScorerSupplier(
46+
similarityFunction, values, input);
47+
}
48+
// It is possible to get to this branch during initial indexing and flush
49+
return NON_QUANTIZED_DELEGATE.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
50+
}
51+
52+
@Override
53+
public RandomVectorScorer getRandomVectorScorer(
54+
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
55+
throws IOException {
56+
if (vectorValues instanceof QuantizedByteVectorValues values
57+
&& values.getSlice() instanceof MemorySegmentAccessInput input) {
58+
checkDimensions(target.length, vectorValues.dimension());
59+
return new Lucene99MemorySegmentScalarQuantizedScorer(
60+
similarityFunction, values, input, target);
61+
}
62+
// It is possible to get to this branch during initial indexing and flush
63+
return NON_QUANTIZED_DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
64+
}
65+
66+
@Override
67+
public RandomVectorScorer getRandomVectorScorer(
68+
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
69+
throws IOException {
70+
return NON_QUANTIZED_DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
71+
}
72+
73+
@Override
74+
public String toString() {
75+
return getClass().getSimpleName() + "()";
76+
}
77+
78+
private static void checkDimensions(int queryLen, int fieldLen) {
79+
if (queryLen != fieldLen) {
80+
throw new IllegalArgumentException(
81+
"vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
82+
}
83+
}
84+
85+
static MemorySegment getSegment(
86+
MemorySegmentAccessInput input, int entrySize, int node, byte[][] scratch)
87+
throws IOException {
88+
long pos = (long) entrySize * node;
89+
MemorySegment segment = input.segmentSliceOrNull(pos, entrySize);
90+
if (segment == null) {
91+
if (scratch[0] == null) {
92+
scratch[0] = new byte[entrySize];
93+
}
94+
input.readBytes(pos, scratch[0], 0, entrySize);
95+
segment = MemorySegment.ofArray(scratch[0]);
96+
}
97+
return segment;
98+
}
99+
100+
@FunctionalInterface
101+
interface MemorySegmentScorer {
102+
float score(MemorySegment query, MemorySegment doc);
103+
}
104+
105+
@FunctionalInterface
106+
interface FloatToFloatFunction {
107+
float scale(float score);
108+
}
109+
110+
static MemorySegmentScorer factory(
111+
VectorSimilarityFunction function,
112+
QuantizedByteVectorValues values,
113+
boolean isScorerSupplier) {
114+
return switch (function) {
115+
case EUCLIDEAN -> {
116+
if (values.getScalarQuantizer().getBits() < 7) {
117+
// TODO
118+
throw new UnsupportedOperationException();
119+
}
120+
yield PanamaVectorUtilSupport::squareDistance;
121+
}
122+
case DOT_PRODUCT, COSINE, MAXIMUM_INNER_PRODUCT -> {
123+
if (values.getScalarQuantizer().getBits() <= 4) {
124+
if (values.getVectorByteLength() != values.dimension()) {
125+
if (isScorerSupplier) {
126+
yield (query, doc) -> PanamaVectorUtilSupport.int4DotProduct(query, true, doc, true);
127+
}
128+
yield (query, doc) -> PanamaVectorUtilSupport.int4DotProduct(query, false, doc, true);
129+
}
130+
yield (query, doc) -> PanamaVectorUtilSupport.int4DotProduct(query, false, doc, false);
131+
}
132+
yield PanamaVectorUtilSupport::dotProduct;
133+
}
134+
};
135+
}
136+
137+
static FloatToFloatFunction factory(VectorSimilarityFunction function) {
138+
return switch (function) {
139+
case EUCLIDEAN -> score -> (1 / (1f + score));
140+
case DOT_PRODUCT, COSINE -> score -> Math.max((1f + score) / 2, 0);
141+
case MAXIMUM_INNER_PRODUCT -> VectorUtil::scaleMaxInnerProductScore;
142+
};
143+
}
144+
}

0 commit comments

Comments
 (0)