Skip to content

Commit c31819c

Browse files
committed
fix test failures
add feature flag Signed-off-by: Jing Zhang <[email protected]>
1 parent 1fbd9d7 commit c31819c

File tree

12 files changed

+190
-32
lines changed

12 files changed

+190
-32
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,9 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
362362
}
363363

364364
private boolean neededStreamParameterInPayload(Map<String, String> parameters) {
365+
if (parameters == null) {
366+
return false;
367+
}
365368
boolean isStream = parameters.containsKey("stream");
366369
if (!isStream) {
367370
return false;

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,4 +342,8 @@ private MLCommonsSettings() {}
342342
/** This setting sets the remote metadata service name */
343343
public static final Setting<String> REMOTE_METADATA_SERVICE_NAME = Setting
344344
.simpleString("plugins.ml_commons." + REMOTE_METADATA_SERVICE_NAME_KEY, Setting.Property.NodeScope, Setting.Property.Final);
345+
346+
/** This setting is to enable/disable streaming feature. */
347+
public static final Setting<Boolean> ML_COMMONS_STREAM_ENABLED = Setting
348+
.boolSetting("plugins.ml_commons.stream_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
345349
}

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED;
1717
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED;
1818
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
19+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED;
1920

2021
import java.util.ArrayList;
2122
import java.util.List;
@@ -40,6 +41,7 @@ public class MLFeatureEnabledSetting {
4041

4142
// This is to identify if this node is in multi-tenancy or not.
4243
private volatile Boolean isMultiTenancyEnabled;
44+
private volatile Boolean isStreamEnabled;
4345

4446
private volatile Boolean isMcpServerEnabled;
4547

@@ -55,6 +57,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
5557
isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings);
5658
isMultiTenancyEnabled = ML_COMMONS_MULTI_TENANCY_ENABLED.get(settings);
5759
isMcpServerEnabled = ML_COMMONS_MCP_SERVER_ENABLED.get(settings);
60+
isStreamEnabled = ML_COMMONS_STREAM_ENABLED.get(settings);
5861

5962
clusterService
6063
.getClusterSettings()
@@ -74,6 +77,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
7477
.getClusterSettings()
7578
.addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, it -> isBatchInferenceEnabled = it);
7679
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_SERVER_ENABLED, it -> isMcpServerEnabled = it);
80+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STREAM_ENABLED, it -> isStreamEnabled = it);
7781
}
7882

7983
/**
@@ -144,6 +148,13 @@ public boolean isMcpServerEnabled() {
144148
return isMcpServerEnabled;
145149
}
146150

151+
/** Whether the streaming feature is enabled. If disabled, APIs in ml-commons will block stream.
152+
* @return whether the streaming is enabled.
153+
*/
154+
public boolean isStreamEnabled() {
155+
return isStreamEnabled;
156+
}
157+
147158
public void addListener(SettingsChangeListener listener) {
148159
listeners.add(listener);
149160
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,11 @@ public Map<String, String> getConnectorCredential(Connector connector) {
154154

155155
public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
156156
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
157-
predictable.initModel(mlModel, params, encryptor, streamManager, threadPool);
157+
if (mlModel.getAlgorithm() == FunctionName.REMOTE) {
158+
predictable.initModel(mlModel, params, encryptor, streamManager, threadPool);
159+
} else {
160+
predictable.initModel(mlModel, params, encryptor);
161+
}
158162
return predictable;
159163
}
160164

ml-algorithms/src/main/java/org/opensearch/ml/engine/arrow/RemoteModelStreamProducer.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.engine.arrow;
27

38
import java.nio.charset.StandardCharsets;

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public class AbstractConnectorExecutorTest {
2121
@Before
2222
public void setUp() {
2323
MockitoAnnotations.initMocks(this);
24+
when(mockConnector.getAccessKey()).thenReturn("access_key");
25+
when(mockConnector.getSecretKey()).thenReturn("secret_key");
26+
when(mockConnector.getSessionToken()).thenReturn("session_token");
27+
when(mockConnector.getRegion()).thenReturn("us-east-1-test");
2428
executor = new AwsConnectorExecutor(mockConnector);
2529
connectorClientConfig = new ConnectorClientConfig();
2630
}

0 commit comments

Comments
 (0)