Description
Background
In ml-commons, remote models are accessed via a connector, which defines how to construct and send inference requests. This includes specifying endpoints, authentication, and the structure of the request payload, known as the request_body template.
Currently, when invoking the predict API with text_embedding, the system does not fully support passing user-defined parameters from the predict call through the connector to the remote endpoint. For example, sending a request such as:
POST {{base_url}}/_plugins/_ml/_predict/text_embedding/{{model_id}}
{
  "text_docs": ["hello world"],
  "parameters": {
    "sparse_embedding_format": "TOKEN_ID"
  }
}
produces the same output as omitting the parameters field, meaning the value of sparse_embedding_format does not affect the result. This limitation stems from the current invocation method and parameter handling associated with the text_embedding function, which does not fully propagate user-supplied parameters into the final HTTP payload sent for inference.
This design aims to enhance the parameter passing mechanism, ensuring that user-provided parameters during predict calls are correctly integrated with the connector’s request template and properly formatted in the outgoing request to the remote endpoint.
Implementation
The connector’s request_body is a JSON string template with placeholders. It contains some parameters hardcoded as fixed values (static parameters) and others as placeholders to be replaced by predict-call parameters (dynamic parameters). When creating the connector, users can optionally define default values for dynamic parameters in the "parameters" field. During a predict call, if a dynamic parameter is not provided, the system first checks whether a default value exists: if a default is found, it is used; if not, the parameter is removed from the final request payload. Static parameters’ values remain fixed and cannot be overridden.
Example format:
{
  "input": "${parameters.input}",
  "parameters": {
    "sparseEmbeddingFormat": "${parameters.sparseEmbeddingFormat}",
    "embeddingContentType": null
  }
}
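The resolution order described above can be sketched as follows. This is an illustrative sketch, not the actual ml-commons implementation; `resolve`, `_MISSING`, and the sample template are hypothetical names chosen for the example:

```python
import json
import re

# Matches a string value that is exactly one dynamic placeholder.
PLACEHOLDER = re.compile(r"^\$\{parameters\.(\w+)\}$")

_MISSING = object()  # sentinel: no value found anywhere, drop the field


def resolve(node, predict_params, connector_defaults):
    """Substitute dynamic placeholders in a parsed request_body template.

    Predict-call parameters take precedence over connector defaults;
    placeholders with no value in either are removed from the payload.
    Static values pass through unchanged.
    """
    if isinstance(node, dict):
        out = {}
        for key, value in node.items():
            resolved = resolve(value, predict_params, connector_defaults)
            if resolved is not _MISSING:
                out[key] = resolved
        return out
    if isinstance(node, str):
        match = PLACEHOLDER.match(node)
        if match:
            name = match.group(1)
            if name in predict_params:
                return predict_params[name]
            if name in connector_defaults:
                return connector_defaults[name]
            return _MISSING
    return node


template = json.loads(
    '{"input": "${parameters.input}",'
    ' "parameters": {"sparseEmbeddingFormat": "${parameters.sparseEmbeddingFormat}",'
    ' "embeddingContentType": null}}'
)
body = resolve(template, {"input": ["hello world"]}, {})
# sparseEmbeddingFormat has no predict value and no default, so it is dropped;
# the static embeddingContentType passes through unchanged.
```

With a predict-call value such as `{"sparseEmbeddingFormat": "TOKEN_ID"}`, the placeholder resolves to that value instead; with only a connector default defined, the default is used.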
Example Usage Overview: SageMaker Model Deployment, Connector Creation, and API Invocation
The following section outlines how to use this feature within the existing workflow: deploying a model on SageMaker, configuring a connector with the request payload format for remote inference, and invoking the predict API. It focuses on the changes and additions related to the parameter passing feature in the deployment and connector setup code rather than providing a full end-to-end tutorial. Code and configuration snippets demonstrate how parameters flow and are managed between the client, connector, and remote endpoint.
Deploy sparse model to SageMaker
Prepare the model file for SageMaker.
%%writefile handler/neural_sparse_handler.py
import os
import itertools
import json

import torch
from ts.torch_handler.base_handler import BaseHandler
from sentence_transformers.sparse_encoder import SparseEncoder

model_id = os.environ.get(
    "MODEL_ID", "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
)
max_bs = int(os.environ.get("MAX_BS", 32))
trust_remote_code = model_id.endswith("gte")


class SparseEncodingModelHandler(BaseHandler):
    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, context):
        self.manifest = context.manifest
        properties = context.system_properties
        print(f"Initializing SparseEncodingModelHandler with model_id: {model_id}")
        # Load model and tokenizer on GPU when available, otherwise CPU.
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )
        print(f"Using device: {self.device}")
        self.model = SparseEncoder(
            model_id, device=self.device, trust_remote_code=trust_remote_code
        )
        self.initialized = True

    def preprocess(self, requests):
        input_sentences = []
        batch_idx = []
        parameters = {}
        for request in requests:
            request_body = request.get("body")
            if isinstance(request_body, bytearray):
                request_body = request_body.decode("utf-8")
            request_json = json.loads(request_body)
            if isinstance(request_json, dict) and "input" in request_json:
                inputs = request_json["input"]
                parameters = request_json.get("parameters", {})
                if isinstance(inputs, list):
                    input_sentences += inputs
                    batch_idx.append(len(inputs))
                else:
                    input_sentences.append(inputs)
                    batch_idx.append(1)
            elif isinstance(request_json, list):
                input_sentences += request_json
                batch_idx.append(len(request_json))
            else:
                input_sentences.append(request_json)
                batch_idx.append(1)
        return input_sentences, batch_idx, parameters

    def handle(self, data, context):
        input_sentences, batch_idx, parameters = self.preprocess(data)
        model_output = self.model.encode_document(input_sentences, batch_size=max_bs)
        # Forward the predict-call parameters (e.g. sparse_embedding_format)
        # into decoding so they can affect the returned embedding format.
        sparse_embedding = list(map(dict, self.model.decode(model_output, parameters)))
        # Split the flat embedding list back into one group per request.
        ends = list(itertools.accumulate(batch_idx))
        outputs = [sparse_embedding[s:e] for s, e in zip([0] + ends[:-1], ends)]
        return outputs
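The slicing at the end of handle regroups the flat list of per-sentence results into one sublist per incoming request. A standalone sketch of that step (split_by_batch is an illustrative helper name, not part of the handler):

```python
import itertools


def split_by_batch(flat_results, batch_idx):
    """Regroup a flat list of per-sentence results into per-request batches.

    batch_idx records how many sentences each request contributed, so the
    cumulative sums give the slice boundaries used in handle() above.
    """
    ends = list(itertools.accumulate(batch_idx))   # e.g. [2, 1] -> [2, 3]
    starts = [0] + ends[:-1]                       # -> [0, 2]
    return [flat_results[s:e] for s, e in zip(starts, ends)]
```

For example, two requests contributing two sentences and one sentence respectively split `["e1", "e2", "e3"]` with `batch_idx = [2, 1]` into `[["e1", "e2"], ["e3"]]`.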
Create connector
POST /_plugins/_ml/connectors/_create
{
  "name": "test",
  "description": "Test connector for SageMaker model",
  "version": 1,
  "protocol": "aws_sigv4",
  "credential": {
    "access_key": "your access key",
    "secret_key": "your secret key"
  },
  "parameters": {
    "region": "{region}",
    "service_name": "sagemaker",
    "input_docs_processed_step_size": 2
  },
  "actions": [
    {
      "action_type": "predict",
      "method": "POST",
      "headers": {
        "content-type": "application/json"
      },
      "url": "https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{predictor.endpoint_name}/invocations",
      "request_body": "{ \"input\": ${parameters.input}, \"parameters\": {\"sparse_embedding_format\": \"${parameters.sparseEmbeddingFormat}\" }}"
    }
  ],
  "client_config": {
    "max_retry_times": -1,
    "max_connection": 60,
    "retry_backoff_millis": 10
  }
}
Register model
POST /_plugins/_ml/models/_register?deploy=true
{
  "name": "test",
  "function_name": "remote",
  "version": "1.0.0",
  "connector_id": "{connector id}",
  "description": "Test connector for SageMaker model"
}
Invoking Text Embedding Prediction API
POST {{base_url}}/_plugins/_ml/_predict/text_embedding/{{model_id}}
{
  "text_docs": ["hello world"],
  "parameters": {
    "sparseEmbeddingFormat": "TOKEN_ID"
  }
}