Skip to content

Commit 0b1643f

Browse files
committed
refresh ml key
Signed-off-by: Jing Zhang <[email protected]>
1 parent 782e97d commit 0b1643f

File tree

7 files changed

+69
-4
lines changed

7 files changed

+69
-4
lines changed

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,4 +350,8 @@ private MLCommonsSettings() {}
350350
// Feature flag for enabling telemetry static metric collection job -- MLStatsJobProcessor
351351
public static final Setting<Boolean> ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED = Setting
352352
.boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Final);
353+
354+
// This setting is for refreshing ML key from the ml-config index to the memory.
355+
public static final Setting<Boolean> ML_COMMONS_KEY_REFRESH_ENABLED = Setting
356+
.boolSetting("plugins.ml_commons.key_refresh_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
353357
}

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
1111
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
1212
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
13+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_KEY_REFRESH_ENABLED;
1314
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
1415
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED;
1516
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED;
@@ -50,6 +51,7 @@ public class MLFeatureEnabledSetting {
5051

5152
private volatile Boolean isMetricCollectionEnabled;
5253
private volatile Boolean isStaticMetricCollectionEnabled;
54+
private volatile Boolean isKeyRefreshEnabled;
5355

5456
private final List<SettingsChangeListener> listeners = new ArrayList<>();
5557

@@ -66,6 +68,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
6668
isRagSearchPipelineEnabled = ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
6769
isMetricCollectionEnabled = ML_COMMONS_METRIC_COLLECTION_ENABLED.get(settings);
6870
isStaticMetricCollectionEnabled = ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED.get(settings);
71+
isKeyRefreshEnabled = ML_COMMONS_KEY_REFRESH_ENABLED.get(settings);
6972

7073
clusterService
7174
.getClusterSettings()
@@ -88,6 +91,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
8891
clusterService
8992
.getClusterSettings()
9093
.addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> isRagSearchPipelineEnabled = it);
94+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_KEY_REFRESH_ENABLED, it -> isKeyRefreshEnabled = it);
9195
}
9296

9397
/**
@@ -178,6 +182,14 @@ public boolean isStaticMetricCollectionEnabled() {
178182
return isStaticMetricCollectionEnabled;
179183
}
180184

185+
/**
186+
* Whether the key refresh feature is enabled.
187+
* @return whether the key refresh is enabled.
188+
*/
189+
public boolean isKeyRefreshEnabled() {
190+
return isKeyRefreshEnabled;
191+
}
192+
181193
@VisibleForTesting
182194
public void notifyMultiTenancyListeners(boolean isEnabled) {
183195
for (SettingsChangeListener listener : listeners) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,18 @@ public String generateMasterKey() {
121121
return Base64.getEncoder().encodeToString(keyBytes);
122122
}
123123

124+
// Refresh the key from the config index.
125+
public void refreshMasterKey(ActionListener<Boolean> listener) {
126+
try {
127+
// Currently, we only handle no tenant case.
128+
initMasterKeyFromIndex(null);
129+
listener.onResponse(true);
130+
} catch (Exception e) {
131+
log.info("Refreshing ML key failed.");
132+
listener.onFailure(e);
133+
}
134+
}
135+
124136
private JceMasterKey createJceMasterKey(String tenantId) {
125137
byte[] bytes = Base64.getDecoder().decode(tenantMasterKeys.get(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID)));
126138
return JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING");
@@ -130,6 +142,10 @@ private void initMasterKey(String tenantId) {
130142
if (tenantMasterKeys.containsKey(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID))) {
131143
return;
132144
}
145+
initMasterKeyFromIndex(tenantId);
146+
}
147+
148+
private void initMasterKeyFromIndex(String tenantId) {
133149
String masterKeyId = MASTER_KEY;
134150
if (tenantId != null) {
135151
masterKeyId = MASTER_KEY + "_" + hashString(tenantId);

plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
import org.opensearch.cluster.service.ClusterService;
2424
import org.opensearch.common.inject.Inject;
2525
import org.opensearch.common.settings.Settings;
26+
import org.opensearch.core.action.ActionListener;
2627
import org.opensearch.core.common.io.stream.StreamInput;
2728
import org.opensearch.core.xcontent.NamedXContentRegistry;
2829
import org.opensearch.ml.common.MLTask;
2930
import org.opensearch.ml.common.MLTaskState;
3031
import org.opensearch.ml.common.MLTaskType;
32+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
3133
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
3234
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
3335
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest;
@@ -36,6 +38,7 @@
3638
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse;
3739
import org.opensearch.ml.engine.MLEngine;
3840
import org.opensearch.ml.engine.ModelHelper;
41+
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
3942
import org.opensearch.ml.engine.utils.FileUtils;
4043
import org.opensearch.ml.model.MLModelCacheHelper;
4144
import org.opensearch.ml.model.MLModelManager;
@@ -66,6 +69,8 @@ public class TransportSyncUpOnNodeAction extends
6669
private volatile Integer mlTaskTimeout;
6770

6871
private final MLModelCacheHelper mlModelCacheHelper;
72+
private final EncryptorImpl encryptor;
73+
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
6974

7075
@Inject
7176
public TransportSyncUpOnNodeAction(
@@ -80,7 +85,9 @@ public TransportSyncUpOnNodeAction(
8085
Client client,
8186
NamedXContentRegistry xContentRegistry,
8287
MLEngine mlEngine,
83-
MLModelCacheHelper mlModelCacheHelper
88+
MLModelCacheHelper mlModelCacheHelper,
89+
EncryptorImpl encryptor,
90+
MLFeatureEnabledSetting mlFeatureEnabledSetting
8491
) {
8592
super(
8693
MLSyncUpAction.NAME,
@@ -103,6 +110,8 @@ public TransportSyncUpOnNodeAction(
103110
this.xContentRegistry = xContentRegistry;
104111
this.mlEngine = mlEngine;
105112
this.mlModelCacheHelper = mlModelCacheHelper;
113+
this.encryptor = encryptor;
114+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
106115

107116
this.mlTaskTimeout = ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS.get(settings);
108117
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS, it -> { mlTaskTimeout = it; });
@@ -182,6 +191,14 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU
182191

183192
cleanUpLocalCache(runningDeployModelTasks);
184193
cleanUpLocalCacheFiles();
194+
if (mlFeatureEnabledSetting.isKeyRefreshEnabled()) {
195+
encryptor.refreshMasterKey(ActionListener.wrap(r -> {
196+
if (r) {
197+
log.debug("Refresh ML key completed.");
198+
return;
199+
}
200+
}, e -> { log.error("Failed to refresh ML key", e); }));
201+
}
185202

186203
return new MLSyncUpNodeResponse(
187204
clusterService.localNode(),

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1156,7 +1156,8 @@ public List<Setting<?>> getSettings() {
11561156
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
11571157
MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED,
11581158
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
1159-
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED
1159+
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
1160+
MLCommonsSettings.ML_COMMONS_KEY_REFRESH_ENABLED
11601161
);
11611162
return settings;
11621163
}

plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.mockito.ArgumentMatchers.anyLong;
1212
import static org.mockito.ArgumentMatchers.anyString;
1313
import static org.mockito.ArgumentMatchers.eq;
14+
import static org.mockito.Mockito.doNothing;
1415
import static org.mockito.Mockito.never;
1516
import static org.mockito.Mockito.times;
1617
import static org.mockito.Mockito.verify;
@@ -57,13 +58,15 @@
5758
import org.opensearch.ml.common.MLTask;
5859
import org.opensearch.ml.common.MLTaskType;
5960
import org.opensearch.ml.common.model.MLModelState;
61+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
6062
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
6163
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest;
6264
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
6365
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
6466
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse;
6567
import org.opensearch.ml.engine.MLEngine;
6668
import org.opensearch.ml.engine.ModelHelper;
69+
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
6770
import org.opensearch.ml.model.MLModelCacheHelper;
6871
import org.opensearch.ml.model.MLModelManager;
6972
import org.opensearch.ml.task.MLTaskCache;
@@ -119,11 +122,19 @@ public class TransportSyncUpOnNodeActionTests extends OpenSearchTestCase {
119122
@Mock
120123
private MLModelCacheHelper mlModelCacheHelper;
121124

125+
@Mock
126+
private EncryptorImpl encryptor;
127+
128+
@Mock
129+
MLFeatureEnabledSetting mlFeatureEnabledSetting;
130+
122131
@Before
123132
public void setup() throws IOException {
124133
MockitoAnnotations.openMocks(this);
125134
mockSettings(true);
126135
when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster"));
136+
when(mlFeatureEnabledSetting.isKeyRefreshEnabled()).thenReturn(false);
137+
doNothing().when(encryptor).refreshMasterKey(any());
127138
action = new TransportSyncUpOnNodeAction(
128139
transportService,
129140
settings,
@@ -136,7 +147,9 @@ public void setup() throws IOException {
136147
client,
137148
xContentRegistry,
138149
mlEngine,
139-
mlModelCacheHelper
150+
mlModelCacheHelper,
151+
encryptor,
152+
mlFeatureEnabledSetting
140153
);
141154
runningDeployModelTasks = new HashMap<>();
142155
runningDeployModelTasks.put("model1", ImmutableSet.of("node1"));

plugin/src/test/java/org/opensearch/ml/settings/MLFeatureEnabledSettingTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
1515
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
1616
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
17+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_KEY_REFRESH_ENABLED;
1718
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
1819
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED;
1920
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED;
@@ -70,7 +71,8 @@ public void setUp() {
7071
ML_COMMONS_MCP_SERVER_ENABLED,
7172
ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
7273
ML_COMMONS_METRIC_COLLECTION_ENABLED,
73-
ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED
74+
ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
75+
ML_COMMONS_KEY_REFRESH_ENABLED
7476
)
7577
)
7678
);

0 commit comments

Comments
 (0)