@@ -32,15 +32,16 @@ use common::{
3232use fms_guardrails_orchestr8:: {
3333 clients:: {
3434 chunker:: MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME ,
35- detector:: { ContentAnalysisRequest , ContentAnalysisResponse } ,
35+ detector:: { ChatDetectionRequest , ContentAnalysisRequest , ContentAnalysisResponse } ,
3636 openai:: {
3737 ChatCompletion , ChatCompletionChoice , CompletionDetectionWarning , CompletionDetections ,
3838 CompletionInputDetections , CompletionOutputDetections , Content , ContentPart ,
39- ContentType , Message , Role , TokenizeResponse ,
39+ ContentType , Function , Message , Role , StopReason , TokenizeResponse , Tool , ToolCall ,
4040 } ,
4141 } ,
4242 models:: {
43- DetectionWarningReason , DetectorParams , UNSUITABLE_INPUT_MESSAGE , UNSUITABLE_OUTPUT_MESSAGE ,
43+ DetectionResult , DetectionWarningReason , DetectorParams , UNSUITABLE_INPUT_MESSAGE ,
44+ UNSUITABLE_OUTPUT_MESSAGE ,
4445 } ,
4546 orchestrator:: types:: Detection ,
4647 pb:: {
@@ -55,7 +56,7 @@ use serde_json::json;
5556use test_log:: test;
5657use tracing:: debug;
5758
58- use crate :: common:: openai:: TOKENIZE_ENDPOINT ;
59+ use crate :: common:: { detectors :: TEXT_CHAT_DETECTOR_ENDPOINT , openai:: TOKENIZE_ENDPOINT } ;
5960
6061pub mod common;
6162
@@ -891,6 +892,175 @@ async fn output_detections() -> Result<(), anyhow::Error> {
891892 Ok ( ( ) )
892893}
893894
895+ #[ test( tokio:: test) ]
896+ async fn text_chat_detectors_with_tools ( ) -> Result < ( ) , anyhow:: Error > {
897+ let tools: Vec < Tool > = serde_json:: from_value ( json ! ( [ {
898+ "type" : "function" ,
899+ "function" : {
900+ "name" : "get_current_weather" ,
901+ "description" : "Get the current weather in a given location" ,
902+ "parameters" : {
903+ "type" : "object" ,
904+ "properties" : {
905+ "location" : {
906+ "type" : "string" ,
907+ "description" : "The city and state, e.g. San Francisco, CA"
908+ } ,
909+ "unit" : {
910+ "type" : "string" ,
911+ "enum" : [
912+ "celsius" ,
913+ "fahrenheit"
914+ ]
915+ }
916+ } ,
917+ "required" : [
918+ "location"
919+ ]
920+ }
921+ }
922+ } ] ) )
923+ . unwrap ( ) ;
924+ let mut openai_server = MockServer :: new_http ( "openai" ) ;
925+ openai_server. mock ( |when, then| {
926+ when. post ( ) . path ( CHAT_COMPLETIONS_ENDPOINT ) . json ( json ! ( {
927+ "model" : "test-0B" ,
928+ "messages" : vec![ Message {
929+ role: Role :: User ,
930+ content: Some ( Content :: Text ( "What's the weather in Boston today?" . into( ) ) ) ,
931+ ..Default :: default ( )
932+ } ] ,
933+ "tools" : tools. clone( ) ,
934+ } ) ) ;
935+ then. json ( ChatCompletion {
936+ id : "chatcmpl-test" . into ( ) ,
937+ object : "chat.completion" . into ( ) ,
938+ created : 1749227854 ,
939+ model : "test-0B" . into ( ) ,
940+ choices : vec ! [ ChatCompletionChoice {
941+ index: 0 ,
942+ message: Message {
943+ role: Role :: Assistant ,
944+ tool_calls: Some ( vec![ ToolCall {
945+ index: Some ( 0 ) ,
946+ id: "chatcmpl-tool-test" . into( ) ,
947+ r#type: "function" . into( ) ,
948+ function: Some ( Function {
949+ name: "get_current_weather" . into( ) ,
950+ arguments: "{\" location\" : \" Boston, MA\" , \" unit\" : \" fahrenheit\" }"
951+ . into( ) ,
952+ } ) ,
953+ ..Default :: default ( )
954+ } ] ) ,
955+ ..Default :: default ( )
956+ } ,
957+ logprobs: None ,
958+ finish_reason: "tool_calls" . into( ) ,
959+ stop_reason: Some ( StopReason :: Integer ( 128008 ) ) ,
960+ } ] ,
961+ ..Default :: default ( )
962+ } ) ;
963+ } ) ;
964+
965+ // Add assistant message with tool_calls
966+ let messages = vec ! [
967+ Message {
968+ role: Role :: User ,
969+ content: Some ( Content :: Text ( "What's the weather in Boston today?" . into( ) ) ) ,
970+ ..Default :: default ( )
971+ } ,
972+ Message {
973+ role: Role :: Assistant ,
974+ tool_calls: Some ( vec![ ToolCall {
975+ index: Some ( 0 ) ,
976+ id: "chatcmpl-tool-test" . into( ) ,
977+ r#type: "function" . into( ) ,
978+ function: Some ( Function {
979+ name: "get_current_weather" . into( ) ,
980+ arguments: "{\" location\" : \" Boston, MA\" , \" unit\" : \" fahrenheit\" }" . into( ) ,
981+ } ) ,
982+ ..Default :: default ( )
983+ } ] ) ,
984+ ..Default :: default ( )
985+ } ,
986+ ] ;
987+ let mut granite_guardian_text_chat_server = MockServer :: new_http ( "granite_guardian_text_chat" ) ;
988+ granite_guardian_text_chat_server. mock ( |when, then| {
989+ when. post ( )
990+ . path ( TEXT_CHAT_DETECTOR_ENDPOINT )
991+ . header ( "detector-id" , "granite_guardian_text_chat" )
992+ . json ( ChatDetectionRequest {
993+ messages,
994+ tools : tools. clone ( ) ,
995+ detector_params : [ ( "risk_name" , "function_call" ) ] . into_iter ( ) . collect ( ) ,
996+ } ) ;
997+ then. json ( vec ! [ DetectionResult {
998+ detection: "Yes" . into( ) ,
999+ detection_type: "risk" . into( ) ,
1000+ score: 0.974 ,
1001+ metadata: [ ( "confidence" . into( ) , "High" . into( ) ) ] . into( ) ,
1002+ ..Default :: default ( )
1003+ } ] ) ;
1004+ } ) ;
1005+
1006+ let test_server = TestOrchestratorServer :: builder ( )
1007+ . config_path ( ORCHESTRATOR_CONFIG_FILE_PATH )
1008+ . openai_server ( & openai_server)
1009+ . detector_servers ( [ & granite_guardian_text_chat_server] )
1010+ . build ( )
1011+ . await ?;
1012+
1013+ let response = test_server
1014+ . post ( ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT )
1015+ . json ( & json ! ( {
1016+ "model" : "test-0B" ,
1017+ "detectors" : {
1018+ "input" : { } ,
1019+ "output" : {
1020+ "granite_guardian_text_chat" : {
1021+ "risk_name" : "function_call"
1022+ } ,
1023+ } ,
1024+ } ,
1025+ "messages" : [ Message {
1026+ role: Role :: User ,
1027+ content: Some ( Content :: Text ( "What's the weather in Boston today?" . into( ) ) ) ,
1028+ ..Default :: default ( )
1029+ } ] ,
1030+ "tools" : tools. clone( ) ,
1031+ } ) )
1032+ . send ( )
1033+ . await ?;
1034+ assert_eq ! ( response. status( ) , StatusCode :: OK ) ;
1035+
1036+ let completion = response. json :: < ChatCompletion > ( ) . await ?;
1037+
1038+ // Validate detections
1039+ assert_eq ! (
1040+ completion. detections,
1041+ Some ( CompletionDetections {
1042+ input: vec![ ] ,
1043+ output: vec![ CompletionOutputDetections {
1044+ choice_index: 0 ,
1045+ results: vec![ Detection {
1046+ detector_id: Some ( "granite_guardian_text_chat" . into( ) ) ,
1047+ detection_type: "risk" . into( ) ,
1048+ detection: "Yes" . into( ) ,
1049+ score: 0.974 ,
1050+ metadata: [ (
1051+ "confidence" . into( ) ,
1052+ serde_json:: Value :: String ( "High" . into( ) )
1053+ ) ]
1054+ . into( ) ,
1055+ ..Default :: default ( )
1056+ } , ] ,
1057+ } ] ,
1058+ } )
1059+ ) ;
1060+
1061+ Ok ( ( ) )
1062+ }
1063+
8941064// Validates that requests with output detector configured returns propagated errors
8951065// from detector, chunker and completions server when applicable
8961066#[ test( tokio:: test) ]
0 commit comments