Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -350,4 +350,8 @@ private MLCommonsSettings() {}
// Feature flag for enabling telemetry static metric collection job -- MLStatsJobProcessor
public static final Setting<Boolean> ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED = Setting
.boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Final);

// This setting is for refreshing ML key from the ml-config index to the memory.
public static final Setting<Boolean> ML_COMMONS_KEY_REFRESH_ENABLED = Setting
.boolSetting("plugins.ml_commons.key_refresh_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_KEY_REFRESH_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED;
Expand Down Expand Up @@ -48,6 +49,7 @@ public class MLFeatureEnabledSetting {

private volatile Boolean isMetricCollectionEnabled;
private volatile Boolean isStaticMetricCollectionEnabled;
private volatile Boolean isKeyRefreshEnabled;

private final List<SettingsChangeListener> listeners = new ArrayList<>();

Expand All @@ -64,6 +66,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
isRagSearchPipelineEnabled = ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
isMetricCollectionEnabled = ML_COMMONS_METRIC_COLLECTION_ENABLED.get(settings);
isStaticMetricCollectionEnabled = ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED.get(settings);
isKeyRefreshEnabled = ML_COMMONS_KEY_REFRESH_ENABLED.get(settings);

clusterService
.getClusterSettings()
Expand All @@ -86,6 +89,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> isRagSearchPipelineEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_KEY_REFRESH_ENABLED, it -> isKeyRefreshEnabled = it);
}

/**
Expand Down Expand Up @@ -176,6 +180,14 @@ public boolean isStaticMetricCollectionEnabled() {
return isStaticMetricCollectionEnabled;
}

/**
* Whether the key refresh feature is enabled.
* @return whether the key refresh is enabled.
*/
public boolean isKeyRefreshEnabled() {
return isKeyRefreshEnabled;
}

@VisibleForTesting
public void notifyMultiTenancyListeners(boolean isEnabled) {
for (SettingsChangeListener listener : listeners) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public void setUp() {
MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED,
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
MLCommonsSettings.ML_COMMONS_KEY_REFRESH_ENABLED
)
);
when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings);
Expand All @@ -65,6 +66,7 @@ public void testDefaults_allFeaturesEnabled() {
.put("plugins.ml_commons.rag_pipeline_feature_enabled", true)
.put("plugins.ml_commons.metrics_collection_enabled", true)
.put("plugins.ml_commons.metrics_static_collection_enabled", true)
.put("plugins.ml_commons.key_refresh_enabled", true)
.build();

MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
Expand All @@ -81,6 +83,7 @@ public void testDefaults_allFeaturesEnabled() {
assertTrue(setting.isRagSearchPipelineEnabled());
assertTrue(setting.isMetricCollectionEnabled());
assertTrue(setting.isStaticMetricCollectionEnabled());
assertTrue(setting.isKeyRefreshEnabled());
}

@Test
Expand All @@ -99,6 +102,7 @@ public void testDefaults_someFeaturesDisabled() {
.put("plugins.ml_commons.rag_pipeline_feature_enabled", false)
.put("plugins.ml_commons.metrics_collection_enabled", false)
.put("plugins.ml_commons.metrics_static_collection_enabled", false)
.put("plugins.ml_commons.key_refresh_enabled", true)
.build();

MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
Expand All @@ -115,6 +119,7 @@ public void testDefaults_someFeaturesDisabled() {
assertFalse(setting.isRagSearchPipelineEnabled());
assertFalse(setting.isMetricCollectionEnabled());
assertFalse(setting.isStaticMetricCollectionEnabled());
assertTrue(setting.isKeyRefreshEnabled());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ public String generateMasterKey() {
return Base64.getEncoder().encodeToString(keyBytes);
}

// Refresh the key from the config index.
public void refreshMasterKey(ActionListener<Boolean> listener) {
try {
// Currently, we only handle no tenant case.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that because the tenantId is not available when running the syncUp on each node? otherwise just fill in the tenantId in the initMasterKeyFromIndex(tenantId) it should refresh the masterKey.

initMasterKeyFromIndex(null);
listener.onResponse(true);
} catch (Exception e) {
log.info("Refreshing ML key failed.");
listener.onFailure(e);
}
}

private JceMasterKey createJceMasterKey(String tenantId) {
byte[] bytes = Base64.getDecoder().decode(tenantMasterKeys.get(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID)));
return JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING");
Expand All @@ -130,6 +142,10 @@ private void initMasterKey(String tenantId) {
if (tenantMasterKeys.containsKey(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID))) {
return;
}
initMasterKeyFromIndex(tenantId);
}

private void initMasterKeyFromIndex(String tenantId) {
String masterKeyId = MASTER_KEY;
if (tenantId != null) {
masterKeyId = MASTER_KEY + "_" + hashString(tenantId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,64 @@ public void encrypt_GetSourceAsMapIsNull_ShouldThrowResourceNotFound() throws Ex
encryptor.encrypt("test", TENANT_ID);
}

@Test
public void refreshMLKey() throws IOException {
doAnswer(invocation -> {
ActionListener<Boolean> actionListener = (ActionListener) invocation.getArgument(0);
actionListener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLConfigIndex(any());

GetResponse response = prepareMLConfigResponse(null);

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());

ActionListener<Boolean> refreshMLKeyListener = ActionListener.wrap(r -> {
if (r) {
Assert.assertNull(encryptor.getMasterKey(null));
String encrypted = encryptor.encrypt("test", null);
String decrypted = encryptor.decrypt(encrypted, null);
Assert.assertEquals("test", decrypted);
Assert.assertEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null));
return;
}
}, e -> { throw new MLException(e); });
Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler);
((EncryptorImpl) encryptor).refreshMasterKey(refreshMLKeyListener);
}

@Test
public void refreshMLKeyException() throws IOException {
exceptionRule.expect(MLException.class);
exceptionRule.expectMessage("test exception");
doThrow(new RuntimeException("test exception")).when(mlIndicesHandler).initMLConfigIndex(any());

GetResponse response = prepareMLConfigResponse(null);

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());

ActionListener<Boolean> refreshMLKeyListener = ActionListener.wrap(r -> {
if (r) {
Assert.assertNull(encryptor.getMasterKey(null));
String encrypted = encryptor.encrypt("test", null);
String decrypted = encryptor.decrypt(encrypted, null);
Assert.assertEquals("test", decrypted);
Assert.assertEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null));
return;
}
}, e -> { throw new MLException(e); });
Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler);
((EncryptorImpl) encryptor).refreshMasterKey(refreshMLKeyListener);
}

// Helper method to prepare a valid IndexResponse
private IndexResponse prepareIndexResponse() {
ShardId shardId = new ShardId(ML_CONFIG_INDEX, "index_uuid", 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest;
Expand All @@ -36,6 +38,7 @@
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.utils.FileUtils;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
Expand Down Expand Up @@ -66,6 +69,8 @@ public class TransportSyncUpOnNodeAction extends
private volatile Integer mlTaskTimeout;

private final MLModelCacheHelper mlModelCacheHelper;
private final EncryptorImpl encryptor;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportSyncUpOnNodeAction(
Expand All @@ -80,7 +85,9 @@ public TransportSyncUpOnNodeAction(
Client client,
NamedXContentRegistry xContentRegistry,
MLEngine mlEngine,
MLModelCacheHelper mlModelCacheHelper
MLModelCacheHelper mlModelCacheHelper,
EncryptorImpl encryptor,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(
MLSyncUpAction.NAME,
Expand All @@ -103,6 +110,8 @@ public TransportSyncUpOnNodeAction(
this.xContentRegistry = xContentRegistry;
this.mlEngine = mlEngine;
this.mlModelCacheHelper = mlModelCacheHelper;
this.encryptor = encryptor;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;

this.mlTaskTimeout = ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS, it -> { mlTaskTimeout = it; });
Expand Down Expand Up @@ -182,6 +191,14 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU

cleanUpLocalCache(runningDeployModelTasks);
cleanUpLocalCacheFiles();
if (mlFeatureEnabledSetting.isKeyRefreshEnabled()) {
encryptor.refreshMasterKey(ActionListener.wrap(r -> {
if (r) {
log.debug("Refresh ML key completed.");
return;
}
}, e -> { log.error("Failed to refresh ML key", e); }));
}

return new MLSyncUpNodeResponse(
clusterService.localNode(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ public MLSyncUpCron(
@Override
public void run() {
initMLConfig();
if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_INDEX)) {
// no need to run sync up job if no model index
if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_INDEX) && !mlFeatureEnabledSetting.isKeyRefreshEnabled()) {
// no need to run sync up job if no model index and ML key refresh disabled
log.debug("Skipping sync up job - ML model index not found");
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED,
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
MLCommonsSettings.ML_COMMONS_KEY_REFRESH_ENABLED
);
return settings;
}
Expand Down
Loading
Loading