Skip to content

Commit 667be7e

Browse files
authored
fix: missing field in response, steps with comma, and parsing nested jsons (opensearch-project#4138)
Signed-off-by: Pavan Yekbote <[email protected]>
1 parent 78f014b commit 667be7e

File tree

6 files changed

+124
-20
lines changed

6 files changed

+124
-20
lines changed

common/src/main/java/org/opensearch/ml/common/utils/ToolUtils.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,11 @@ public static Map<String, String> extractInputParameters(Map<String, String> par
9393
StringSubstitutor stringSubstitutor = new StringSubstitutor(parameters, "${parameters.", "}");
9494
String input = stringSubstitutor.replace(parameters.get("input"));
9595
extractedParameters.put("input", input);
96-
Map<String, String> inputParameters = gson
97-
.fromJson(input, TypeToken.getParameterized(Map.class, String.class, String.class).getType());
98-
extractedParameters.putAll(inputParameters);
96+
Map<String, Object> parsedInputParameters = gson
97+
.fromJson(input, TypeToken.getParameterized(Map.class, String.class, Object.class).getType());
98+
extractedParameters.putAll(StringUtils.getParameterMap(parsedInputParameters));
9999
} catch (Exception exception) {
100-
log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
100+
log.info("Failed to extract parameters from key 'input'. Falling back to raw input string.", exception);
101101
}
102102
}
103103
return extractedParameters;

common/src/test/java/org/opensearch/ml/common/utils/ToolUtilsTest.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,67 @@ public void testFilterToolOutput_ComplexNestedPath() {
280280
// Should contain only the targeted deep value
281281
assertEquals("targetValue", result);
282282
}
283+
284+
@Test
285+
public void testExtractInputParameters_WithJsonInput() {
286+
Map<String, String> parameters = new HashMap<>();
287+
parameters.put("param1", "value1");
288+
parameters.put("input", "{\"key1\": \"jsonValue1\", \"key2\": \"jsonValue2\"}");
289+
290+
Map<String, Object> attributes = new HashMap<>();
291+
292+
Map<String, String> result = ToolUtils.extractInputParameters(parameters, attributes);
293+
294+
assertEquals(4, result.size());
295+
assertEquals("value1", result.get("param1"));
296+
assertEquals("{\"key1\": \"jsonValue1\", \"key2\": \"jsonValue2\"}", result.get("input"));
297+
assertEquals("jsonValue1", result.get("key1"));
298+
assertEquals("jsonValue2", result.get("key2"));
299+
}
300+
301+
@Test
302+
public void testExtractInputParameters_WithParameterSubstitution() {
303+
Map<String, String> parameters = new HashMap<>();
304+
parameters.put("param1", "substitutedValue");
305+
parameters.put("input", "{\"message\": \"Hello ${parameters.param1}\"}");
306+
307+
Map<String, Object> attributes = new HashMap<>();
308+
309+
Map<String, String> result = ToolUtils.extractInputParameters(parameters, attributes);
310+
311+
assertEquals(3, result.size());
312+
assertEquals("substitutedValue", result.get("param1"));
313+
assertEquals("{\"message\": \"Hello substitutedValue\"}", result.get("input"));
314+
assertEquals("Hello substitutedValue", result.get("message"));
315+
}
316+
317+
@Test
318+
public void testExtractInputParameters_WithInvalidJson() {
319+
Map<String, String> parameters = new HashMap<>();
320+
parameters.put("param1", "value1");
321+
parameters.put("input", "invalid json string");
322+
323+
Map<String, Object> attributes = new HashMap<>();
324+
325+
Map<String, String> result = ToolUtils.extractInputParameters(parameters, attributes);
326+
327+
assertEquals(2, result.size());
328+
assertEquals("value1", result.get("param1"));
329+
assertEquals("invalid json string", result.get("input"));
330+
}
331+
332+
@Test
333+
public void testExtractInputParameters_NoInputParameter() {
334+
Map<String, String> parameters = new HashMap<>();
335+
parameters.put("param1", "value1");
336+
parameters.put("param2", "value2");
337+
338+
Map<String, Object> attributes = new HashMap<>();
339+
340+
Map<String, String> result = ToolUtils.extractInputParameters(parameters, attributes);
341+
342+
assertEquals(2, result.size());
343+
assertEquals("value1", result.get("param1"));
344+
assertEquals("value2", result.get("param2"));
345+
}
283346
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,9 +361,11 @@ public static Map<String, String> parseLLMOutput(
361361
if (functionCalling != null) {
362362
toolCalls = functionCalling.handle(tmpModelTensorOutput, parameters);
363363
// TODO: support multiple tool calls here
364-
toolName = toolCalls.getFirst().get("tool_name");
365-
toolInput = toolCalls.getFirst().get("tool_input");
366-
toolCallId = toolCalls.getFirst().get("tool_call_id");
364+
if (!toolCalls.isEmpty()) {
365+
toolName = toolCalls.getFirst().get("tool_name");
366+
toolInput = toolCalls.getFirst().get("tool_input");
367+
toolCallId = toolCalls.getFirst().get("tool_call_id");
368+
}
367369
} else {
368370
String toolCallsPath = parameters.get(TOOL_CALLS_PATH);
369371
if (toolCallsPath.startsWith("_llm_response.")) {
@@ -372,9 +374,11 @@ public static Map<String, String> parseLLMOutput(
372374
} else {
373375
toolCalls = JsonPath.read(dataAsMap, toolCallsPath);
374376
}
375-
toolName = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_NAME));
376-
toolInput = StringUtils.toJson(JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_INPUT)));
377-
toolCallId = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALL_ID_PATH));
377+
if (!toolCalls.isEmpty()) {
378+
toolName = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_NAME));
379+
toolInput = StringUtils.toJson(JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_INPUT)));
380+
toolCallId = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALL_ID_PATH));
381+
}
378382
}
379383
String toolCallsMsgPath = parameters.get(INTERACTION_TEMPLATE_ASSISTANT_TOOL_CALLS_PATH);
380384
String toolCallsMsgExcludePath = parameters.get(INTERACTION_TEMPLATE_ASSISTANT_TOOL_CALLS_EXCLUDE_PATH);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT;
3535

3636
import java.util.ArrayList;
37-
import java.util.Arrays;
3837
import java.util.HashMap;
3938
import java.util.List;
4039
import java.util.Locale;
@@ -401,10 +400,10 @@ private void executePlanningLoop(
401400

402401
planListener.whenComplete(llmOutput -> {
403402
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) llmOutput.getOutput();
404-
Map<String, String> parseLLMOutput = parseLLMOutput(allParams, modelTensorOutput);
403+
Map<String, Object> parseLLMOutput = parseLLMOutput(allParams, modelTensorOutput);
405404

406405
if (parseLLMOutput.get(RESULT_FIELD) != null) {
407-
String finalResult = parseLLMOutput.get(RESULT_FIELD);
406+
String finalResult = (String) parseLLMOutput.get(RESULT_FIELD);
408407
saveAndReturnFinalResult(
409408
(ConversationIndexMemory) memory,
410409
parentInteractionId,
@@ -415,8 +414,7 @@ private void executePlanningLoop(
415414
finalListener
416415
);
417416
} else {
418-
// todo: optimize double conversion of steps (string to list to string)
419-
List<String> steps = Arrays.stream(parseLLMOutput.get(STEPS_FIELD).split(", ")).toList();
417+
List<String> steps = (List<String>) parseLLMOutput.get(STEPS_FIELD);
420418
addSteps(steps, allParams, STEPS_FIELD);
421419

422420
String stepToExecute = steps.getFirst();
@@ -546,8 +544,8 @@ private void executePlanningLoop(
546544
}
547545

548546
@VisibleForTesting
549-
Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
550-
Map<String, String> modelOutput = new HashMap<>();
547+
Map<String, Object> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
548+
Map<String, Object> modelOutput = new HashMap<>();
551549
Map<String, ?> dataAsMap = modelTensorOutput.getMlModelOutputs().getFirst().getMlModelTensors().getFirst().getDataAsMap();
552550
String llmResponse;
553551
if (dataAsMap.size() == 1 && dataAsMap.containsKey(RESPONSE_FIELD)) {
@@ -571,7 +569,7 @@ Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelTensorOut
571569

572570
if (parsedJson.containsKey(STEPS_FIELD)) {
573571
List<String> steps = (List<String>) parsedJson.get(STEPS_FIELD);
574-
modelOutput.put(STEPS_FIELD, String.join(", ", steps));
572+
modelOutput.put(STEPS_FIELD, steps);
575573
}
576574

577575
if (parsedJson.containsKey(RESULT_FIELD)) {

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,4 +1861,41 @@ public void testCreateTool_WithNullRuntimeResources() {
18611861

18621862
verify(factory).create(argThat(toolParamsMap -> ((Map<String, Object>) toolParamsMap).get("param1").equals("value1")));
18631863
}
1864+
1865+
@Test
1866+
public void testParseLLMOutput_PathNotFoundExceptionWithEmptyToolCalls() {
1867+
Map<String, String> parameters = new HashMap<>();
1868+
parameters.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
1869+
parameters.put(TOOL_CALLS_PATH, "$.output.message.content[*].toolUse");
1870+
parameters.put(TOOL_CALLS_TOOL_NAME, "name");
1871+
parameters.put(TOOL_CALLS_TOOL_INPUT, "input");
1872+
parameters.put(TOOL_CALL_ID_PATH, "toolUseId");
1873+
parameters.put(LLM_FINISH_REASON_PATH, "$.stopReason");
1874+
parameters.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
1875+
1876+
Map<String, Object> dataAsMap = Map
1877+
.of("output", Map.of("message", Map.of("content", Collections.emptyList(), "role", "assistant")), "stopReason", "end_turn");
1878+
1879+
ModelTensorOutput modelTensorOutput = ModelTensorOutput
1880+
.builder()
1881+
.mlModelOutputs(
1882+
List
1883+
.of(
1884+
ModelTensors
1885+
.builder()
1886+
.mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build()))
1887+
.build()
1888+
)
1889+
)
1890+
.build();
1891+
1892+
Map<String, String> output = AgentUtils
1893+
.parseLLMOutput(parameters, modelTensorOutput, null, Set.of("test_tool"), new ArrayList<>(), null);
1894+
1895+
Assert.assertEquals("", output.get(THOUGHT));
1896+
Assert.assertEquals("", output.get(ACTION_INPUT));
1897+
Assert.assertEquals("", output.get(TOOL_CALL_ID));
1898+
Assert.assertTrue(output.containsKey(FINAL_ANSWER));
1899+
Assert.assertTrue(output.get(FINAL_ANSWER).contains("[]"));
1900+
}
18641901
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,9 +525,11 @@ public void testParseLLMOutput() {
525525
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
526526
ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
527527

528-
Map<String, String> result = mlPlanExecuteAndReflectAgentRunner.parseLLMOutput(allParams, modelTensorOutput);
528+
Map<String, Object> result = mlPlanExecuteAndReflectAgentRunner.parseLLMOutput(allParams, modelTensorOutput);
529529

530-
assertEquals("step1, step2", result.get(MLPlanExecuteAndReflectAgentRunner.STEPS_FIELD));
530+
List<String> expectedSteps = Arrays.asList("step1", "step2");
531+
List<String> actualSteps = (List<String>) result.get(MLPlanExecuteAndReflectAgentRunner.STEPS_FIELD);
532+
assertEquals(expectedSteps, actualSteps);
531533
assertEquals("final result", result.get(MLPlanExecuteAndReflectAgentRunner.RESULT_FIELD));
532534

533535
modelTensor = ModelTensor.builder().dataAsMap(Map.of(MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD, "random response")).build();

0 commit comments

Comments
 (0)