Skip to content

Commit 6c6cd43

Browse files
authored
Add better support for new embeddings model (#589)
1 parent e16db83 commit 6c6cd43

File tree

5 files changed

+146
-23
lines changed

5 files changed

+146
-23
lines changed

src/platform/configuration/common/configurationService.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,8 @@ export const enum CHAT_MODEL {
530530
// WARNING
531531
// These values are used in the request and are case sensitive. Do not change them unless advised by CAPI.
532532
export const enum EMBEDDING_MODEL {
533-
TEXT3SMALL = "text-embedding-3-small"
533+
TEXT3SMALL = 'text-embedding-3-small',
534+
Metis_1024_I16_Binary = 'metis-1024-I16-Binary',
534535
}
535536

536537
export enum AuthProviderId {

src/platform/embeddings/common/embeddingsComputer.ts

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import { EmbeddingsEndpointFamily } from '../../endpoint/common/endpointProvider
1515
*/
1616
export class EmbeddingType {
1717
public static readonly text3small_512 = new EmbeddingType('text-embedding-3-small-512');
18+
public static readonly metis_1024_I16_Binary = new EmbeddingType('metis-1024-I16-Binary');
1819

1920
constructor(
2021
public readonly id: string
@@ -29,19 +30,38 @@ export class EmbeddingType {
2930
}
3031
}
3132

33+
type EmbeddingQuantization = 'float32' | 'float16' | 'binary';
34+
3235
export interface EmbeddingTypeInfo {
3336
readonly model: EMBEDDING_MODEL;
3437
readonly family: EmbeddingsEndpointFamily;
3538
readonly dimensions: number;
39+
readonly quantization: {
40+
readonly query: EmbeddingQuantization;
41+
readonly document: EmbeddingQuantization;
42+
};
3643
}
3744

38-
const wellKnownEmbeddingMetadata = {
45+
const wellKnownEmbeddingMetadata = Object.freeze<Record<string, EmbeddingTypeInfo>>({
3946
[EmbeddingType.text3small_512.id]: {
4047
model: EMBEDDING_MODEL.TEXT3SMALL,
4148
family: 'text3small',
4249
dimensions: 512,
43-
}
44-
} as const satisfies Record<string, EmbeddingTypeInfo>;
50+
quantization: {
51+
query: 'float32',
52+
document: 'float32'
53+
},
54+
},
55+
[EmbeddingType.metis_1024_I16_Binary.id]: {
56+
model: EMBEDDING_MODEL.Metis_1024_I16_Binary,
57+
family: 'metis',
58+
dimensions: 1024,
59+
quantization: {
60+
query: 'float16',
61+
document: 'binary'
62+
},
63+
},
64+
});
4565

4666
export function getWellKnownEmbeddingTypeInfo(type: EmbeddingType): EmbeddingTypeInfo | undefined {
4767
return wellKnownEmbeddingMetadata[type.id];
@@ -86,7 +106,7 @@ export interface IEmbeddingsComputer {
86106
type: EmbeddingType,
87107
inputs: readonly string[],
88108
options?: ComputeEmbeddingsOptions,
89-
cancellationToken?: CancellationToken,
109+
token?: CancellationToken,
90110
): Promise<Embeddings | undefined>;
91111
}
92112

src/platform/endpoint/common/endpointProvider.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,11 @@ export function isEmbeddingModelInformation(model: IModelAPIResponse): model is
7878
}
7979

8080
export type ChatEndpointFamily = 'gpt-4.1' | 'gpt-4o-mini' | 'copilot-base';
81-
export type EmbeddingsEndpointFamily = 'text3small';
81+
export type EmbeddingsEndpointFamily = 'text3small' | 'metis';
8282

8383
export interface IEndpointProvider {
8484
readonly _serviceBrand: undefined;
85+
8586
/**
8687
* Get the embedding endpoint information
8788
*/

src/platform/workspaceChunkSearch/node/workspaceChunkAndEmbeddingCache.ts

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import { IRange, Range } from '../../../util/vs/editor/common/core/range';
1616
import { IInstantiationService, ServicesAccessor } from '../../../util/vs/platform/instantiation/common/instantiation';
1717
import { FileChunk, FileChunkWithEmbedding } from '../../chunking/common/chunk';
1818
import { stripChunkTextMetadata } from '../../chunking/common/chunkingStringUtils';
19-
import { EmbeddingType, EmbeddingVector } from '../../embeddings/common/embeddingsComputer';
19+
import { Embedding, EmbeddingType, EmbeddingVector } from '../../embeddings/common/embeddingsComputer';
2020
import { IFileSystemService } from '../../filesystem/common/fileSystemService';
2121
import { ILogService } from '../../log/common/logService';
2222
import { FileRepresentation, IWorkspaceFileIndex } from './workspaceFileIndex';
@@ -430,7 +430,6 @@ class DbCache implements IWorkspaceChunkAndEmbeddingCache {
430430
db.exec('DELETE FROM CacheMeta;');
431431
db.prepare('INSERT INTO CacheMeta (version, embeddingModel) VALUES (?, ?)').run(this.version, embeddingType.id);
432432

433-
434433
// Load existing disk db if it exists
435434
const diskCache = await instantiationService.invokeFunction(accessor => DiskCache.readDiskCache(
436435
accessor,
@@ -456,7 +455,10 @@ class DbCache implements IWorkspaceChunkAndEmbeddingCache {
456455
chunk.range.startColumn,
457456
chunk.range.endLineNumber,
458457
chunk.range.endColumn,
459-
Float32Array.from(typeof chunk.embedding === 'string' ? DiskCache.decodeEmbedding(chunk.embedding) : chunk.embedding),
458+
packEmbedding({
459+
type: embeddingType,
460+
value: typeof chunk.embedding === 'string' ? DiskCache.decodeEmbedding(chunk.embedding) : chunk.embedding,
461+
}),
460462
chunk.chunkHash ?? ''
461463
);
462464
}
@@ -532,8 +534,7 @@ class DbCache implements IWorkspaceChunkAndEmbeddingCache {
532534
if (all.length > 0) {
533535
const out = new Map<string, FileChunkWithEmbedding>();
534536
for (const row of all) {
535-
const embeddingData = row.embedding as Uint8Array;
536-
const embedding = Array.from(new Float32Array(embeddingData.buffer, embeddingData.byteOffset, embeddingData.byteLength / Float32Array.BYTES_PER_ELEMENT));
537+
const embedding = unpackEmbedding(this.embeddingType, row.embedding as Uint8Array);
537538

538539
const chunk: FileChunkWithEmbedding = {
539540
chunk: {
@@ -542,10 +543,7 @@ class DbCache implements IWorkspaceChunkAndEmbeddingCache {
542543
rawText: undefined,
543544
range: new Range(row.range_startLineNumber as number, row.range_startColumn as number, row.range_endLineNumber as number, row.range_endColumn as number),
544545
},
545-
embedding: {
546-
type: this.embeddingType,
547-
value: embedding,
548-
},
546+
embedding,
549547
chunkHash: row.chunkHash as string,
550548
};
551549
if (chunk.chunkHash) {
@@ -576,18 +574,14 @@ class DbCache implements IWorkspaceChunkAndEmbeddingCache {
576574
contentVersionId: fileIdResult.contentVersionId as string | undefined,
577575
fileHash: undefined,
578576
value: chunks.map((row): FileChunkWithEmbedding => {
579-
const embeddingData = row.embedding as Uint8Array;
580577
return {
581578
chunk: {
582579
file: file.uri,
583580
text: row.text as string,
584581
rawText: undefined,
585582
range: new Range(row.range_startLineNumber as number, row.range_startColumn as number, row.range_endLineNumber as number, row.range_endColumn as number),
586583
},
587-
embedding: {
588-
type: this.embeddingType,
589-
value: Array.from(new Float32Array(embeddingData.buffer, embeddingData.byteOffset, embeddingData.byteLength / Float32Array.BYTES_PER_ELEMENT)),
590-
},
584+
embedding: unpackEmbedding(this.embeddingType, row.embedding as Uint8Array),
591585
chunkHash: row.chunkHash as string | undefined,
592586
};
593587
}),
@@ -643,16 +637,14 @@ class DbCache implements IWorkspaceChunkAndEmbeddingCache {
643637

644638
this.db.exec('BEGIN TRANSACTION');
645639
for (const chunk of newEntry.value ?? []) {
646-
const float32Array = Float32Array.from(chunk.embedding.value);
647-
const embeddingData = new Uint8Array(float32Array.buffer, float32Array.byteOffset, float32Array.byteLength);
648640
insertStatement.run(
649641
fileResult.lastInsertRowid as number,
650642
chunk.chunk.text,
651643
chunk.chunk.range.startLineNumber,
652644
chunk.chunk.range.startColumn,
653645
chunk.chunk.range.endLineNumber,
654646
chunk.chunk.range.endColumn,
655-
embeddingData,
647+
packEmbedding(chunk.embedding),
656648
chunk.chunkHash ?? '',
657649
);
658650
}
@@ -665,4 +657,52 @@ class DbCache implements IWorkspaceChunkAndEmbeddingCache {
665657

666658
return chunks;
667659
}
660+
}
661+
662+
/**
663+
* Packs the embedding into a binary value for efficient storage.
664+
*/
665+
export function packEmbedding(embedding: Embedding): Uint8Array {
666+
if (embedding.type.equals(EmbeddingType.metis_1024_I16_Binary)) {
667+
// Generate packed binary
668+
if (embedding.value.length % 8 !== 0) {
669+
throw new Error(`Embedding value length must be a multiple of 8 for ${embedding.type.id}, got ${embedding.value.length}`);
670+
}
671+
672+
const data = new Uint8Array(embedding.value.length / 8);
673+
for (let i = 0; i < embedding.value.length; i += 8) {
674+
let value = 0;
675+
for (let j = 0; j < 8; j++) {
676+
value |= (embedding.value[i + j] >= 0 ? 1 : 0) << j;
677+
}
678+
data[i / 8] = value;
679+
}
680+
return data;
681+
}
682+
683+
// All other formats default to float32 for now
684+
const data = Float32Array.from(embedding.value);
685+
return new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
686+
}
687+
688+
/**
689+
* Unpacks an embedding from a binary value packed with {@link packEmbedding}.
690+
*/
691+
export function unpackEmbedding(type: EmbeddingType, data: Uint8Array): Embedding {
692+
if (type.equals(EmbeddingType.metis_1024_I16_Binary)) {
693+
// Old versions may have stored the values as a float32
694+
if (data.length <= 1024) {
695+
const values = new Array(data.length * 8);
696+
for (let i = 0; i < data.length; i++) {
697+
const byte = data[i];
698+
for (let j = 0; j < 8; j++) {
699+
values[i * 8 + j] = (byte & (1 << j)) > 0 ? 0.03125 : -0.03125;
700+
}
701+
}
702+
return { type, value: values };
703+
}
704+
}
705+
706+
const float32Array = new Float32Array(data.buffer, data.byteOffset, data.byteLength / 4);
707+
return { type, value: Array.from(float32Array) };
668708
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) Microsoft Corporation. All rights reserved.
3+
* Licensed under the MIT License. See License.txt in the project root for license information.
4+
*--------------------------------------------------------------------------------------------*/
5+
6+
import assert from 'assert';
7+
import { suite, test } from 'vitest';
8+
import { Embedding, EmbeddingType } from '../../../embeddings/common/embeddingsComputer';
9+
import { packEmbedding, unpackEmbedding } from '../../node/workspaceChunkAndEmbeddingCache';
10+
11+
suite('Pack Embedding', () => {
12+
test('Text3small should pack and unpack to same values', () => {
13+
const embedding: Embedding = {
14+
type: EmbeddingType.text3small_512,
15+
// Start with float32 array so that we don't check for the very small rounding
16+
// that can happen when going from js number -> float32
17+
value: Array.from(Float32Array.from({ length: 512 }, () => Math.random())),
18+
};
19+
20+
const serialized = packEmbedding(embedding);
21+
const deserialized = unpackEmbedding(EmbeddingType.text3small_512, serialized);
22+
assert.deepStrictEqual(deserialized.value.length, embedding.value.length);
23+
assert.deepStrictEqual(deserialized.value, embedding.value);
24+
});
25+
26+
test('Metis should use binary storage', () => {
27+
const embedding: Embedding = {
28+
type: EmbeddingType.metis_1024_I16_Binary,
29+
value: Array.from({ length: 1024 }, () => Math.random() < 0.5 ? 0.03125 : -0.03125)
30+
};
31+
32+
const serialized = packEmbedding(embedding);
33+
assert.strictEqual(serialized.length, 1024 / 8);
34+
35+
const deserialized = unpackEmbedding(EmbeddingType.metis_1024_I16_Binary, serialized);
36+
assert.deepStrictEqual(deserialized.value.length, embedding.value.length);
37+
assert.deepStrictEqual(deserialized.value, embedding.value);
38+
});
39+
40+
test('Unpack should work with buffer offsets', () => {
41+
const embedding: Embedding = {
42+
type: EmbeddingType.metis_1024_I16_Binary,
43+
value: Array.from({ length: 1024 }, () => Math.random() < 0.5 ? 0.03125 : -0.03125)
44+
};
45+
46+
const serialized = packEmbedding(embedding);
47+
48+
// Now create a new buffer and write the serialized data to it at an offset
49+
const prefixAndSuffixSize = 512;
50+
const buffer = new Uint8Array(serialized.length + prefixAndSuffixSize * 2);
51+
for (let i = 0; i < serialized.length; i++) {
52+
buffer[i + prefixAndSuffixSize] = serialized[i];
53+
}
54+
55+
const serializedCopy = new Uint8Array(buffer.buffer, prefixAndSuffixSize, serialized.length);
56+
57+
const deserialized = unpackEmbedding(EmbeddingType.metis_1024_I16_Binary, serializedCopy);
58+
assert.deepStrictEqual(deserialized.value.length, embedding.value.length);
59+
assert.deepStrictEqual(deserialized.value, embedding.value);
60+
});
61+
});

0 commit comments

Comments
 (0)