Commit 6184e1d

Add text_chat_detectors_with_tools test for chat completions unary

Signed-off-by: declark1 <[email protected]>

1 parent: 53f0a91

tests/chat_completions_unary.rs

Lines changed: 174 additions & 4 deletions
@@ -32,15 +32,16 @@ use common::{
 use fms_guardrails_orchestr8::{
     clients::{
         chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME,
-        detector::{ContentAnalysisRequest, ContentAnalysisResponse},
+        detector::{ChatDetectionRequest, ContentAnalysisRequest, ContentAnalysisResponse},
         openai::{
             ChatCompletion, ChatCompletionChoice, CompletionDetectionWarning, CompletionDetections,
             CompletionInputDetections, CompletionOutputDetections, Content, ContentPart,
-            ContentType, Message, Role, TokenizeResponse,
+            ContentType, Function, Message, Role, StopReason, TokenizeResponse, Tool, ToolCall,
         },
     },
     models::{
-        DetectionWarningReason, DetectorParams, UNSUITABLE_INPUT_MESSAGE, UNSUITABLE_OUTPUT_MESSAGE,
+        DetectionResult, DetectionWarningReason, DetectorParams, UNSUITABLE_INPUT_MESSAGE,
+        UNSUITABLE_OUTPUT_MESSAGE,
     },
     orchestrator::types::Detection,
     pb::{
@@ -55,7 +56,7 @@ use serde_json::json;
 use test_log::test;
 use tracing::debug;
 
-use crate::common::openai::TOKENIZE_ENDPOINT;
+use crate::common::{detectors::TEXT_CHAT_DETECTOR_ENDPOINT, openai::TOKENIZE_ENDPOINT};
 
 pub mod common;
 
@@ -891,6 +892,175 @@ async fn output_detections() -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+#[test(tokio::test)]
+async fn text_chat_detectors_with_tools() -> Result<(), anyhow::Error> {
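+    // Define the tool offered to the model: a single get_current_weather function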
+    let tools: Vec<Tool> = serde_json::from_value(json!([{
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA"
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": [
+                            "celsius",
+                            "fahrenheit"
+                        ]
+                    }
+                },
+                "required": [
+                    "location"
+                ]
+            }
+        }
+    }]))
+    .unwrap();
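+    // Mock the OpenAI chat completions endpoint: expect the user message and
+    // the tool definitions to be forwarded by the orchestrator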
+    let mut openai_server = MockServer::new_http("openai");
+    openai_server.mock(|when, then| {
+        when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({
+            "model": "test-0B",
+            "messages": vec![Message {
+                role: Role::User,
+                content: Some(Content::Text("What's the weather in Boston today?".into())),
+                ..Default::default()
+            }],
+            "tools": tools.clone(),
+        }));
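+        // Return a completion whose single choice is an assistant tool call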
+        then.json(ChatCompletion {
+            id: "chatcmpl-test".into(),
+            object: "chat.completion".into(),
+            created: 1749227854,
+            model: "test-0B".into(),
+            choices: vec![ChatCompletionChoice {
+                index: 0,
+                message: Message {
+                    role: Role::Assistant,
+                    tool_calls: Some(vec![ToolCall {
+                        index: Some(0),
+                        id: "chatcmpl-tool-test".into(),
+                        r#type: "function".into(),
+                        function: Some(Function {
+                            name: "get_current_weather".into(),
+                            arguments: "{\"location\": \"Boston, MA\", \"unit\": \"fahrenheit\"}"
+                                .into(),
+                        }),
+                        ..Default::default()
+                    }]),
+                    ..Default::default()
+                },
+                logprobs: None,
+                finish_reason: "tool_calls".into(),
+                stop_reason: Some(StopReason::Integer(128008)),
+            }],
+            ..Default::default()
+        });
+    });
+
+    // Add assistant message with tool_calls
+    let messages = vec![
+        Message {
+            role: Role::User,
+            content: Some(Content::Text("What's the weather in Boston today?".into())),
+            ..Default::default()
+        },
+        Message {
+            role: Role::Assistant,
+            tool_calls: Some(vec![ToolCall {
+                index: Some(0),
+                id: "chatcmpl-tool-test".into(),
+                r#type: "function".into(),
+                function: Some(Function {
+                    name: "get_current_weather".into(),
+                    arguments: "{\"location\": \"Boston, MA\", \"unit\": \"fahrenheit\"}".into(),
+                }),
+                ..Default::default()
+            }]),
+            ..Default::default()
+        },
+    ];
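+    // Mock the text chat detector: it should receive the full conversation,
+    // including the assistant tool call, plus the tool definitions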
+    let mut granite_guardian_text_chat_server = MockServer::new_http("granite_guardian_text_chat");
+    granite_guardian_text_chat_server.mock(|when, then| {
+        when.post()
+            .path(TEXT_CHAT_DETECTOR_ENDPOINT)
+            .header("detector-id", "granite_guardian_text_chat")
+            .json(ChatDetectionRequest {
+                messages,
+                tools: tools.clone(),
+                detector_params: [("risk_name", "function_call")].into_iter().collect(),
+            });
+        then.json(vec![DetectionResult {
+            detection: "Yes".into(),
+            detection_type: "risk".into(),
+            score: 0.974,
+            metadata: [("confidence".into(), "High".into())].into(),
+            ..Default::default()
+        }]);
+    });
+
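+    // Build the orchestrator under test, wired to the mock OpenAI and detector servers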
+    let test_server = TestOrchestratorServer::builder()
+        .config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
+        .openai_server(&openai_server)
+        .detector_servers([&granite_guardian_text_chat_server])
+        .build()
+        .await?;
+
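+    // Send a chat completions request with the text chat detector configured
+    // as an output detector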
+    let response = test_server
+        .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
+        .json(&json!({
+            "model": "test-0B",
+            "detectors": {
+                "input": {},
+                "output": {
+                    "granite_guardian_text_chat": {
+                        "risk_name": "function_call"
+                    },
+                },
+            },
+            "messages": [Message {
+                role: Role::User,
+                content: Some(Content::Text("What's the weather in Boston today?".into())),
+                ..Default::default()
+            }],
+            "tools": tools.clone(),
+        }))
+        .send()
+        .await?;
+    assert_eq!(response.status(), StatusCode::OK);
+
+    let completion = response.json::<ChatCompletion>().await?;
+
+    // Validate detections
+    assert_eq!(
+        completion.detections,
+        Some(CompletionDetections {
+            input: vec![],
+            output: vec![CompletionOutputDetections {
+                choice_index: 0,
+                results: vec![Detection {
+                    detector_id: Some("granite_guardian_text_chat".into()),
+                    detection_type: "risk".into(),
+                    detection: "Yes".into(),
+                    score: 0.974,
+                    metadata: [(
+                        "confidence".into(),
+                        serde_json::Value::String("High".into())
+                    )]
+                    .into(),
+                    ..Default::default()
+                },],
+            }],
+        })
+    );
+
+    Ok(())
+}
+
 // Validates that requests with output detector configured returns propagated errors
 // from detector, chunker and completions server when applicable
 #[test(tokio::test)]
