Add ml-commons passthrough post process function #4111
Conversation
...ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunction.java
@q-andy fix spotless :)
Force-pushed from 1d97d7b to a89b235
Codecov Report
Coverage Diff:
## main #4111 +/- ##
============================================
+ Coverage 81.80% 81.82% +0.01%
- Complexity 8847 8866 +19
============================================
Files 761 762 +1
Lines 38099 38152 +53
Branches 4250 4263 +13
============================================
+ Hits 31168 31217 +49
+ Misses 5110 5109 -1
- Partials 1821 1826 +5
Signed-off-by: Andy Qin <[email protected]>
Most CI passed; the 2 flaky integ test failures are unrelated to this change. Could you take another look? @dhrubo-os @Zhangxunmt
```java
    }

    @Override
    public List<ModelTensor> process(Map<String, Object> input, MLResultDataType dataType) {
```
Instead of Map<String, Object> input: since this parameter represents ML Commons inference results or response data, some better name suggestions would be:
mlCommonsResponse
inferenceResponse
responseData
modelResponse
Sure, didn't realize you could rename the parameter variables of an overridden function. Changed to mlCommonsResponse
```java
/**
 * A post-processing function for calling a remote ml commons instance that preserves the original neural sparse response structure
 * to avoid double-wrapping when receiving responses from another ML-Commons instance.
 */
```
Can we add an example in the doc for better understanding?
```diff
@@ -35,6 +36,7 @@ public class MLPostProcessFunction {
     public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
     public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
     public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
+    public static final String ML_COMMONS_PASSTHROUGH = "connector.post_process.mlcommons.passthrough";
```
As this is a special kind, can we please add a comment here?
```java
if (dataTypeObj instanceof String) {
    try {
        dataType = MLResultDataType.valueOf((String) dataTypeObj);
    } catch (IllegalArgumentException e) {
```
IllegalArgumentException is a 500-level error. Should we treat this as a 5xx error?
Also, what is the reason for leaving it as null?
I left it as null instead of treating it like a 5xx because even if the model data type is invalid, there may be use cases where we can still parse and use the model data or dataAsMap. E.g. some model types like neural sparse or NER don't include a data type as part of the model response, so null is still valid.
Right now, since we're focused on neural sparse and dense, my thought process is it's better to leave it flexible so we can possibly handle different model response formats. For example, in the future we may add a new data type for dense models and might run inference against an older version of ml-commons: perhaps we can still use the data by casting it at the processor level. Updated the comment to explain this.
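The lenient parsing described above can be sketched in isolation. This is a minimal, self-contained illustration; the enum here is a stand-in for ml-commons' MLResultDataType, not the real class:

```java
class LenientDataType {
    // Stand-in for MLResultDataType; the real constants live in ml-commons.
    enum DataType { FLOAT32, FLOAT64, INT32, INT64 }

    // Returns null for an unknown or absent data type instead of failing,
    // so responses without a data type (e.g. neural sparse) still pass through.
    static DataType parse(Object dataTypeObj) {
        if (dataTypeObj instanceof String s) {
            try {
                return DataType.valueOf(s);
            } catch (IllegalArgumentException e) {
                return null; // unknown/future type: caller can still use dataAsMap
            }
        }
        return null; // field not present at all
    }
}
```

The point of the null fallback is exactly the flexibility argued for above: a future or missing data type never turns a usable response into an error.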
```java
long[] shape = null;
if (map.containsKey(ModelTensor.SHAPE_FIELD)) {
    Object shapeObj = map.get(ModelTensor.SHAPE_FIELD);
    if (shapeObj instanceof List<?> shapeList) {
```
In the else we are sending null for the shape; is that expected?
Yes; e.g. neural sparse and dense models populate different fields: sparse leaves data_type, data, and shape as null.
Dense:
```json
{
  "name": "sentence_embedding",
  "data_type": "FLOAT32",
  "shape": [768],
  "data": [...]
}
```
Sparse:
```json
{
  "name": "output",
  "dataAsMap": {
    "response": [
      { ... }
    ]
  }
}
```
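The difference between the two response shapes above comes down to null-tolerant field extraction. A minimal sketch, with the "shape" field name taken from the examples and the helper name hypothetical:

```java
import java.util.List;
import java.util.Map;

class ShapeParser {
    // Returns the tensor shape as long[] when present (dense case),
    // or null when the field is absent or not a list (sparse case).
    static long[] parseShape(Map<String, Object> tensorMap) {
        Object shapeObj = tensorMap.get("shape");
        if (shapeObj instanceof List<?> shapeList) {
            long[] shape = new long[shapeList.size()];
            for (int i = 0; i < shapeList.size(); i++) {
                shape[i] = ((Number) shapeList.get(i)).longValue();
            }
            return shape;
        }
        return null; // sparse tensors simply omit "shape"
    }
}
```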
```java
Number[] data = null;
if (map.containsKey(ModelTensor.DATA_FIELD)) {
    Object dataObj = map.get(ModelTensor.DATA_FIELD);
    if (dataObj instanceof List<?> dataList) {
```
What happens if we send data as null to the ModelTensor?
See the above comment; this is valid for neural sparse and other remote model formats. It looks like the data field is primarily used for vector/numerical info, and if a model output doesn't include that, it uses dataAsMap instead, so the data field being null is valid.
```java
// Handle data array
Number[] data = null;
if (map.containsKey(ModelTensor.DATA_FIELD)) {
```
Seems like both of the underlying blocks of logic could be put in a common method like:

```java
private <T> T[] processNumericArray(Object obj, Class<T> type) {
    if (obj instanceof List<?> list) {
        T[] result = (T[]) Array.newInstance(type, list.size());
        // ... process the list
        return result;
    }
    return null;
}
```
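For illustration, the suggested helper could be completed along these lines. This is a self-contained sketch; the class name, method name, and fail-fast cast behavior are assumptions, not the PR's actual code:

```java
import java.lang.reflect.Array;
import java.util.List;

class NumericArrays {
    // Converts a parsed JSON list into a typed array, or returns null when the
    // value is absent or not a list, matching the passthrough's null-tolerant style.
    @SuppressWarnings("unchecked")
    static <T> T[] toTypedArray(Object obj, Class<T> type) {
        if (obj instanceof List<?> list) {
            T[] result = (T[]) Array.newInstance(type, list.size());
            for (int i = 0; i < list.size(); i++) {
                result[i] = type.cast(list.get(i)); // fails fast on unexpected element types
            }
            return result;
        }
        return null;
    }
}
```

Using Array.newInstance keeps the method generic while still producing a correctly typed runtime array, which a plain `new T[]` cannot do in Java.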
After using the post-processing function, the name seems changed: it used to be "response" but after post-processing it becomes "output". Is this change intended? I'm not sure if the name is used somewhere, but better not to change it if not intentional.
Signed-off-by: Andy Qin <[email protected]>
Fixed; the name is relevant for different model types. I changed it so the name will be passed through as well. This was just a typo in the PR description.
Description
Adds a predefined post-process function to unwrap model output when making a remote call to a second ml-commons cluster. When calling a remote second ml-commons predict API with a remote connector, the remote connector output will "double wrap" the second ml-commons response. Tested compatibility with the neural-search SparseEncodingProcessor and TextEmbeddingProcessor.
.Predict call without post process function (double wrapped):
After using post process function:
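As a rough illustration of the unwrapping idea: the field names below assume the standard ml-commons predict response layout ("inference_results" containing "output" tensor lists). This is a hedged sketch of the concept, not the PR's implementation:

```java
import java.util.List;
import java.util.Map;

class PassthroughUnwrap {
    // Lifts the inner cluster's tensor maps out of a double-wrapped response:
    // the remote connector places the whole inner predict response in dataAsMap,
    // so the passthrough pulls the tensor list back out instead of re-wrapping it.
    @SuppressWarnings("unchecked")
    static List<Map<String, Object>> unwrap(Map<String, Object> dataAsMap) {
        Object results = dataAsMap.get("inference_results");
        if (results instanceof List<?> resultList && !resultList.isEmpty()) {
            Object output = ((Map<String, Object>) resultList.get(0)).get("output");
            if (output instanceof List<?> tensors) {
                return (List<Map<String, Object>>) tensors;
            }
        }
        return null; // not a wrapped ml-commons response
    }
}
```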
Related Issues
Resolves #[Issue number to be closed when this PR is merged]
Check List
- Commits are signed per the DCO using --signoff.
- By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

For more information on following the Developer Certificate of Origin and signing off your commits, please check here.