Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
Expand Down Expand Up @@ -62,57 +66,140 @@
value = 3,
jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
public class VectorScorerBenchmark {
private static final float EPSILON = 1e-4f;

@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
public int size;

@Param({"0", "1", "2", "4", "6", "8", "16", "20", "32", "50", "64", "100", "128", "255", "256"})
public int padBytes;

Directory dir;
IndexInput in;
KnnVectorValues vectorValues;
IndexInput bytesIn;
IndexInput floatsIn;
KnnVectorValues byteVectorValues;
KnnVectorValues floatVectorValues;
byte[] vec1, vec2;
UpdateableRandomVectorScorer scorer;
float[] floatsA, floatsB;
float expectedBytes, expectedFloats;
UpdateableRandomVectorScorer byteScorer;
UpdateableRandomVectorScorer floatScorer;

@Setup(Level.Iteration)
public void init() throws IOException {
Random random = ThreadLocalRandom.current();

vec1 = new byte[size];
vec2 = new byte[size];
ThreadLocalRandom.current().nextBytes(vec1);
ThreadLocalRandom.current().nextBytes(vec2);
random.nextBytes(vec1);
random.nextBytes(vec2);
expectedBytes = DOT_PRODUCT.compare(vec1, vec2);

// random float arrays for float methods
floatsA = new float[size];
floatsB = new float[size];
for (int i = 0; i < size; ++i) {
floatsA[i] = random.nextFloat();
floatsB[i] = random.nextFloat();
}
expectedFloats = DOT_PRODUCT.compare(floatsA, floatsB);

dir = new MMapDirectory(Files.createTempDirectory("VectorScorerBenchmark"));
try (IndexOutput out = dir.createOutput("vector.data", IOContext.DEFAULT)) {
try (IndexOutput out = dir.createOutput("byteVector.data", IOContext.DEFAULT)) {
out.writeBytes(new byte[padBytes], 0, padBytes);

out.writeBytes(vec1, 0, vec1.length);
out.writeBytes(vec2, 0, vec2.length);
}
in = dir.openInput("vector.data", IOContext.DEFAULT);
vectorValues = vectorValues(size, 2, in, DOT_PRODUCT);
scorer =
try (IndexOutput out = dir.createOutput("floatVector.data", IOContext.DEFAULT)) {
out.writeBytes(new byte[padBytes], 0, padBytes);

byte[] buffer = new byte[size * Float.BYTES];
ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(floatsA);
out.writeBytes(buffer, 0, buffer.length);
ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(floatsB);
out.writeBytes(buffer, 0, buffer.length);
}

bytesIn = dir.openInput("byteVector.data", IOContext.DEFAULT);
byteVectorValues = byteVectorValues(DOT_PRODUCT);
byteScorer =
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
.getRandomVectorScorerSupplier(DOT_PRODUCT, byteVectorValues)
.scorer();
byteScorer.setScoringOrdinal(0);

floatsIn = dir.openInput("floatVector.data", IOContext.DEFAULT);
floatVectorValues = floatVectorValues(DOT_PRODUCT);
floatScorer =
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
.getRandomVectorScorerSupplier(DOT_PRODUCT, vectorValues)
.getRandomVectorScorerSupplier(DOT_PRODUCT, floatVectorValues)
.scorer();
scorer.setScoringOrdinal(0);
floatScorer.setScoringOrdinal(0);
}

/**
 * Releases the resources opened by {@code init()}: the directory and both vector data inputs.
 *
 * <p>NOTE(review): {@code init()} is annotated {@code @Setup(Level.Iteration)} while this method
 * uses the default {@code @TearDown} level - confirm that resources created by earlier iterations
 * are not left open.
 *
 * @throws IOException if closing any of the resources fails
 */
@TearDown
public void teardown() throws IOException {
  // floatsIn is opened next to bytesIn in init() and has to be released with it.
  IOUtils.close(dir, bytesIn, floatsIn);
}

/**
 * Benchmarks scoring ordinal 1 with {@code byteScorer} (whose scoring ordinal is set to 0 in
 * {@code init()}) and sanity-checks the score against {@code expectedBytes}.
 *
 * @return the score reported by {@code byteScorer}
 * @throws IOException if the scorer fails while scoring
 * @throws RuntimeException if the score differs from {@code expectedBytes} by more than {@code
 *     EPSILON}
 */
@Benchmark
public float binaryDotProductDefault() throws IOException {
  float result = byteScorer.score(1);
  if (Math.abs(result - expectedBytes) > EPSILON) {
    // expectedBytes is the precomputed expectation; result is what the scorer actually produced.
    throw new RuntimeException("Expected " + expectedBytes + " but got " + result);
  }
  return result;
}

/**
 * Same measurement as {@code binaryDotProductDefault()}, but run in a fork whose JVM arguments
 * are prepended with {@code --add-modules=jdk.incubator.vector}.
 *
 * @return the score reported by {@code byteScorer}
 * @throws IOException if the scorer fails while scoring
 * @throws RuntimeException if the score differs from {@code expectedBytes} by more than {@code
 *     EPSILON}
 */
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float binaryDotProductMemSeg() throws IOException {
  float result = byteScorer.score(1);
  if (Math.abs(result - expectedBytes) > EPSILON) {
    // expectedBytes is the precomputed expectation; result is what the scorer actually produced.
    throw new RuntimeException("Expected " + expectedBytes + " but got " + result);
  }
  return result;
}

/**
 * Benchmarks scoring ordinal 1 with {@code floatScorer} (whose scoring ordinal is set to 0 in
 * {@code init()}) and sanity-checks the score against {@code expectedFloats}.
 *
 * @return the score reported by {@code floatScorer}
 * @throws IOException if the scorer fails while scoring
 * @throws RuntimeException if the score differs from {@code expectedFloats} by more than {@code
 *     EPSILON}
 */
@Benchmark
public float floatDotProductDefault() throws IOException {
  float result = floatScorer.score(1);
  if (Math.abs(result - expectedFloats) > EPSILON) {
    // expectedFloats is the precomputed expectation; result is what the scorer actually produced.
    throw new RuntimeException("Expected " + expectedFloats + " but got " + result);
  }
  return result;
}

/**
 * Same measurement as {@code floatDotProductDefault()}, but run in a fork whose JVM arguments are
 * prepended with {@code --add-modules=jdk.incubator.vector}.
 *
 * @return the score reported by {@code floatScorer}
 * @throws IOException if the scorer fails while scoring
 * @throws RuntimeException if the score differs from {@code expectedFloats} by more than {@code
 *     EPSILON}
 */
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float floatDotProductMemSeg() throws IOException {
  float result = floatScorer.score(1);
  if (Math.abs(result - expectedFloats) > EPSILON) {
    // expectedFloats is the precomputed expectation; result is what the scorer actually produced.
    throw new RuntimeException("Expected " + expectedFloats + " but got " + result);
  }
  return result;
}

static KnnVectorValues vectorValues(
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
KnnVectorValues byteVectorValues(VectorSimilarityFunction sim) throws IOException {
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim);
size,
2,
bytesIn.slice("test", padBytes, size * 2L),
size,
new ThrowingFlatVectorScorer(),
sim);
}

/**
 * Builds off-heap float vector values over the two vectors stored in {@code floatsIn}.
 *
 * <p>The slice starts at {@code padBytes}, skipping the padding written at the head of the file,
 * and spans exactly two vectors of {@code size} floats each.
 *
 * @param sim the similarity function handed to the vector values
 * @return dense off-heap vector values of dimension {@code size} holding two vectors
 * @throws IOException if slicing {@code floatsIn} or creating the vector values fails
 */
KnnVectorValues floatVectorValues(VectorSimilarityFunction sim) throws IOException {
  final int vectorCount = 2;
  final int bytesPerVector = size * Float.BYTES;
  // Widen to long before multiplying, exactly as the slice length is a long.
  final long dataLength = bytesPerVector * 2L;
  IndexInput vectorSlice = floatsIn.slice("test", padBytes, dataLength);
  FlatVectorsScorer throwingScorer = new ThrowingFlatVectorScorer();
  return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
      size, vectorCount, vectorSlice, bytesPerVector, throwingScorer, sim);
}

static final class ThrowingFlatVectorScorer implements FlatVectorsScorer {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,18 @@ public long ramBytesUsed() {
return total;
}

/**
 * Aligns the output's file pointer before vector data is written, for optimal vectorized
 * performance.
 *
 * <p>{@code BYTE} vectors are aligned to {@code Float.BYTES}; {@code FLOAT32} vectors are aligned
 * to 64 bytes.
 *
 * @param output the output whose file pointer is aligned
 * @param encoding the encoding of the vectors about to be written
 * @return the value returned by {@code output.alignFilePointer}, which callers record as the
 *     vector data offset
 * @throws IOException if aligning the file pointer fails
 */
private static long alignOutput(IndexOutput output, VectorEncoding encoding) throws IOException {
  final int alignmentBytes =
      switch (encoding) {
        case BYTE -> Float.BYTES;
        case FLOAT32 -> 64;
      };
  return output.alignFilePointer(alignmentBytes);
}

private void writeField(FieldWriter<?> fieldData, int maxDoc) throws IOException {
// write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
long vectorDataOffset = alignOutput(vectorData, fieldData.fieldInfo.getVectorEncoding());
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectors(fieldData);
case FLOAT32 -> writeFloat32Vectors(fieldData);
Expand Down Expand Up @@ -190,43 +199,39 @@ private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocM
mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField);

// write vector values
long vectorDataOffset =
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
};
long vectorDataOffset = alignOutput(vectorData, fieldData.fieldInfo.getVectorEncoding());
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;

writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField);
}

private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
private void writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (int ordinal : ordMap) {
float[] vector = (float[]) fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
return vectorDataOffset;
}

private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
private void writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
for (int ordinal : ordMap) {
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector, vector.length);
}
return vectorDataOffset;
}

@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
// Since we know we will not be searching for additional indexing, we can just write the
// vectors directly to the new segment.
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
long vectorDataOffset = alignOutput(vectorData, fieldInfo.getVectorEncoding());
// No need to use temporary file as we don't have to re-open for reading
DocsWithFieldSet docsWithField =
switch (fieldInfo.getVectorEncoding()) {
Expand All @@ -252,7 +257,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
@Override
public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
long vectorDataOffset = alignOutput(vectorData, fieldInfo.getVectorEncoding());
IndexOutput tempVectorData =
segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context);
Expand Down
Loading