From 680053237565a07e6a05462a52e5b91dc088ed49 Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Wed, 21 May 2025 17:01:42 +0800 Subject: [PATCH] POC: Support sparse model return token_id Signed-off-by: yuye-aws --- .../ml/common/input/nlp/TextDocsMLInput.java | 5 +- .../SparseEncodingParameters.java | 95 +++++++++++++++++++ .../engine/algorithms/TextEmbeddingModel.java | 5 + .../SparseEncodingTranslator.java | 30 ++++-- .../ml/plugin/MachineLearningPlugin.java | 4 +- 5 files changed, 127 insertions(+), 12 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java index 1a2f201dd5..a818137ad7 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java @@ -27,10 +27,7 @@ * ML input class which supports a list fo text docs. * This class can be used for TEXT_EMBEDDING model. */ -@org.opensearch.ml.common.annotation.MLInput(functionNames = { - FunctionName.TEXT_EMBEDDING, - FunctionName.SPARSE_ENCODING, - FunctionName.SPARSE_TOKENIZE }) +@org.opensearch.ml.common.annotation.MLInput(functionNames = { FunctionName.TEXT_EMBEDDING }) public class TextDocsMLInput extends MLInput { public static final String TEXT_DOCS_FIELD = "text_docs"; public static final String RESULT_FILTER_FIELD = "result_filter"; diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java new file mode 100644 index 0000000000..f5bf85d331 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.parameter.textembedding; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.MLAlgoParameter; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; + +import lombok.Builder; + +@MLAlgoParameter(algorithms = { FunctionName.SPARSE_ENCODING }) +public class SparseEncodingParameters implements MLAlgoParams { + + public static final String PARSE_FIELD_NAME = FunctionName.SPARSE_ENCODING.name(); + public static final String SPARSE_ENCODING_FORMAT_FIELD = "sparse_encoding_format"; + + @Override + public int getVersion() { + return 1; + } + + @Override + public String getWriteableName() { + return PARSE_FIELD_NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(sparseEncodingType.name()); + } + + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + SparseEncodingParameters::parse + ); + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + xContentBuilder.startObject(); + if (sparseEncodingType != null) { + xContentBuilder.field(SPARSE_ENCODING_FORMAT_FIELD, sparseEncodingType.name()); + } + xContentBuilder.endObject(); + return xContentBuilder; + } + + public enum SparseEncodingFormat { + WORD, + INT + } + + // The type of the content to be embedded + private final SparseEncodingFormat sparseEncodingType; + + @Builder(toBuilder = true) + public SparseEncodingParameters(SparseEncodingFormat sparseEncodingType) { + this.sparseEncodingType = sparseEncodingType; + } + + public SparseEncodingFormat getSparseEncodingType() { + return sparseEncodingType; + } + + public static MLAlgoParams parse(XContentParser parser) throws IOException { + SparseEncodingFormat sparseEncodingType = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + if (fieldName.equals(SPARSE_ENCODING_FORMAT_FIELD)) { + String contentType = parser.text(); + sparseEncodingType = SparseEncodingFormat.valueOf(contentType.toUpperCase(Locale.ROOT)); + } else { + parser.skipChildren(); + } + } + return new SparseEncodingParameters(sparseEncodingType); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java index 63c11ca79d..3d29d3fbef 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java @@ -12,6 +12,7 @@ import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -40,6 +41,10 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla for (String doc : textDocsInput.getDocs()) { Input input = new Input(); input.add(doc); + if (mlParams instanceof SparseEncodingParameters) { + input.add("sparse_encoding_format", ((SparseEncodingParameters) mlParams).getSparseEncodingType().name()); + } + output = getPredictor().predict(input); tensorOutputs.add(parseModelTensorOutput(output, resultFilter)); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java index baebbe1972..f67e124695 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java @@ -6,34 +6,48 @@ package org.opensearch.ml.engine.algorithms.sparse_encoding; import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY; +import static org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters.SPARSE_ENCODING_FORMAT_FIELD; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Map; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator; +import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.translate.TranslatorContext; public class SparseEncodingTranslator extends SentenceTransformerTranslator { + + @Override + public NDList processInput(TranslatorContext ctx, Input input) { + String sparse_encoding_format = input.getAsString(SPARSE_ENCODING_FORMAT_FIELD); + if (sparse_encoding_format != null) { + ctx.setAttachment(SPARSE_ENCODING_FORMAT_FIELD, sparse_encoding_format); + } + return super.processInput(ctx, input); + } + @Override public Output processOutput(TranslatorContext ctx, NDList list) { Output output = new Output(200, "OK"); + Object sparseEncodingFormatObject = ctx.getAttachment(SPARSE_ENCODING_FORMAT_FIELD); + String sparseEncodingFormatString = sparseEncodingFormatObject != null + ? sparseEncodingFormatObject.toString() + : SparseEncodingParameters.SparseEncodingFormat.WORD.name(); List outputs = new ArrayList<>(); - Iterator iterator = list.iterator(); - while (iterator.hasNext()) { - NDArray ndArray = iterator.next(); + for (NDArray ndArray : list) { String name = ndArray.getName(); - Map tokenWeightsMap = convertOutput(ndArray); + Map tokenWeightsMap = convertOutput(ndArray, sparseEncodingFormatString); Map wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(tokenWeightsMap)); ModelTensor tensor = ModelTensor.builder().name(name).dataAsMap(wrappedMap).build(); outputs.add(tensor); @@ -44,12 +58,14 @@ public Output processOutput(TranslatorContext ctx, NDList list) { return output; } - private Map convertOutput(NDArray array) { + private Map convertOutput(NDArray array, String sparseEncodingFormat) { Map map = new HashMap<>(); NDArray nonZeroIndices = array.nonzero().squeeze(); for (long index : nonZeroIndices.toLongArray()) { - String s = this.tokenizer.decode(new long[] { index }, true); + String s = sparseEncodingFormat.equals(SparseEncodingParameters.SparseEncodingFormat.INT.name()) + ? Long.toString(index) + : this.tokenizer.decode(new long[] { index }, true); if (!s.isEmpty()) { map.put(s, array.getFloat(index)); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 38a84804f1..bf1e7f27e1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -135,6 +135,7 @@ import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.settings.MLCommonsSettings; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; @@ -1038,7 +1039,8 @@ public List getNamedXContent() { RCFSummarizeParams.XCONTENT_REGISTRY, LogisticRegressionParams.XCONTENT_REGISTRY, TextEmbeddingModelConfig.XCONTENT_REGISTRY, - AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY + AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY, + SparseEncodingParameters.XCONTENT_REGISTRY ); }