diff --git a/docs/openapi.json b/docs/openapi.json index 63572257536..4eb5d73d234 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1738,9 +1738,10 @@ "oneOf": [ { "type": "object", + "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.", "required": [ - "type", - "value" + "value", + "type" ], "properties": { "type": { @@ -1749,16 +1750,21 @@ "json" ] }, - "value": { - "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions." + "value": {} + }, + "example": { + "properties": { + "location": { + "type": "string" + } } } }, { "type": "object", "required": [ - "type", - "value" + "value", + "type" ], "properties": { "type": { @@ -1775,18 +1781,16 @@ { "type": "object", "required": [ - "type", - "value" + "json_schema", + "type" ], "properties": { + "json_schema": {}, "type": { "type": "string", "enum": [ "json_schema" ] - }, - "value": { - "$ref": "#/components/schemas/JsonSchemaConfig" } } } @@ -1882,22 +1886,6 @@ } } }, - "JsonSchemaConfig": { - "type": "object", - "required": [ - "schema" - ], - "properties": { - "name": { - "type": "string", - "description": "Optional name identifier for the schema", - "nullable": true - }, - "schema": { - "description": "The actual JSON schema definition" - } - } - }, "Message": { "allOf": [ { diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_openai_style_format.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_openai_style_format.json new file mode 100644 index 00000000000..be6bd4f9437 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_openai_style_format.json @@ -0,0 +1,23 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "{\"status\":\".OK.\"}", + "role": "assistant" + } + } + ], + "created": 1750877897, + "id": "", + "model": "google/gemma-3-4b-it", + "object": "chat.completion", + "system_fingerprint": "3.3.4-dev0-native", + "usage": { + "completion_tokens": 8, + "prompt_tokens": 36, + "total_tokens": 44 + } +} diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_simple_status.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_simple_status.json new file mode 100644 index 00000000000..be6bd4f9437 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_simple_status.json @@ -0,0 +1,23 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "{\"status\":\".OK.\"}", + "role": "assistant" + } + } + ], + "created": 1750877897, + "id": "", + "model": "google/gemma-3-4b-it", + "object": "chat.completion", + "system_fingerprint": "3.3.4-dev0-native", + "usage": { + "completion_tokens": 8, + "prompt_tokens": 36, + "total_tokens": 44 + } +} diff --git a/integration-tests/models/test_json_schema_constrain.py b/integration-tests/models/test_json_schema_constrain.py index 65b4a7b8e31..0aa91de01b2 100644 --- a/integration-tests/models/test_json_schema_constrain.py +++ b/integration-tests/models/test_json_schema_constrain.py @@ -207,3 +207,87 @@ async def test_json_schema_stream(model_fixture, response_snapshot): assert isinstance(parsed_content["numCats"], int) assert parsed_content["numCats"] >= 0 assert chunks == response_snapshot + + +status_schema = { + "type": "object", + "properties": {"status": {"type": "string"}}, + "required": ["status"], + "additionalProperties": False, +} + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_json_schema_simple_status(model_fixture, response_snapshot): + """Test simple status JSON schema - TGI format.""" + response = requests.post( + f"{model_fixture.base_url}/v1/chat/completions", + json={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'", + }, + {"role": "user", "content": "Please tell me OK"}, + ], + "seed": 42, + "temperature": 0.0, + "response_format": { + "type": "json_schema", + "value": {"name": "test", "schema": status_schema}, + }, + "max_completion_tokens": 8192, + }, + ) + + result = response.json() + + # Validate response format + content = result["choices"][0]["message"]["content"] + parsed_content = json.loads(content) + + assert "status" in parsed_content + assert isinstance(parsed_content["status"], str) + assert result == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_json_schema_openai_style_format(model_fixture, response_snapshot): + """Test OpenAI-style JSON schema format (should also work now).""" + response = requests.post( + f"{model_fixture.base_url}/v1/chat/completions", + json={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'", + }, + {"role": "user", "content": "Please tell me OK"}, + ], + "seed": 42, + "temperature": 0.0, + "response_format": { + "json_schema": { + "name": "test", + "strict": True, + "schema": status_schema, + }, + "type": "json_schema", + }, + "max_completion_tokens": 8192, + }, + ) + + result = response.json() + + # Validate response format + content = result["choices"][0]["message"]["content"] + parsed_content = json.loads(content) + + assert "status" in parsed_content + assert isinstance(parsed_content["status"], str) + assert result == response_snapshot diff --git a/router/src/lib.rs b/router/src/lib.rs index e5622fc22e8..554c5d2f68a 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -224,18 +224,7 @@ impl HubProcessorConfig { #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[cfg_attr(test, derive(PartialEq))] -struct JsonSchemaConfig { - /// Optional name identifier for the schema - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - - /// The actual JSON schema definition - schema: serde_json::Value, -} - -#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] -#[cfg_attr(test, derive(PartialEq))] -#[serde(tag = "type", content = "value")] +#[serde(tag = "type")] pub(crate) enum GrammarType { /// A string that represents a [JSON Schema](https://json-schema.org/). /// @@ -244,17 +233,36 @@ pub(crate) enum GrammarType { #[serde(rename = "json")] #[serde(alias = "json_object")] #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] - Json(serde_json::Value), + Json { value: serde_json::Value }, #[serde(rename = "regex")] - Regex(String), + Regex { value: String }, /// A JSON Schema specification with additional metadata. /// /// Includes an optional name for the schema, an optional strict flag, and the required schema definition. #[serde(rename = "json_schema")] - #[schema(example = json ! ({"schema": {"properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}, "name": "person_info", "strict": true}))] - JsonSchema(JsonSchemaConfig), + JsonSchema { + #[serde(alias = "value")] + #[serde(deserialize_with = "custom_json_schema::deserialize_json_schema")] + json_schema: serde_json::Value, + }, +} + +mod custom_json_schema { + use serde::{Deserialize, Deserializer}; + use serde_json::Value; + + pub fn deserialize_json_schema<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value: Value = Deserialize::deserialize(deserializer)?; + value + .get("schema") + .cloned() + .ok_or_else(|| serde::de::Error::custom("Expected a 'schema' field")) + } } #[derive(Clone, Debug, Serialize, ToSchema)] @@ -984,7 +992,9 @@ impl ChatRequest { if let Some(tools) = tools { match ToolGrammar::apply(tools, tool_choice)? { Some((updated_tools, tool_schema)) => { - let grammar = GrammarType::Json(serde_json::json!(tool_schema)); + let grammar = GrammarType::Json { + value: serde_json::json!(tool_schema), + }; let inputs: String = infer.apply_chat_template( messages, Some((updated_tools, tool_prompt)), @@ -1836,3 +1846,80 @@ mod tests { ); } } + +#[cfg(test)] +mod grammar_tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_regex() { + let raw = json!({ + "type": "regex", + "value": "^\\d+$" + }); + let parsed: GrammarType = serde_json::from_value(raw).unwrap(); + + match parsed { + GrammarType::Regex { value } => assert_eq!(value, "^\\d+$"), + _ => panic!("Expected Regex variant"), + } + } + + #[test] + fn parse_json_value() { + let raw = json!({ + "type": "json", + "value": { "enum": ["a", "b"] } + }); + let parsed: GrammarType = serde_json::from_value(raw).unwrap(); + + match parsed { + GrammarType::Json { value } => assert_eq!(value, json!({"enum":["a","b"]})), + _ => panic!("Expected Json variant"), + } + } + + #[test] + fn parse_json_schema() { + let raw = json!({ + "type": "json_schema", + "json_schema": { "schema": {"type":"integer"} } + }); + let parsed: GrammarType = serde_json::from_value(raw).unwrap(); + + match parsed { + GrammarType::JsonSchema { json_schema } => { + assert_eq!(json_schema, json!({"type":"integer"})); + } + _ => panic!("Expected JsonSchema variant"), + } + } + + #[test] + fn parse_regex_ip_address() { + let raw = json!({ + "type": "regex", + "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)" + }); + let parsed: GrammarType = serde_json::from_value(raw).unwrap(); + + match parsed { + GrammarType::Regex { value } => { + assert!(value.contains("25[0-5]")); + } + _ => panic!("Expected Regex variant"), + } + } + + #[test] + fn parse_invalid_type_should_fail() { + let raw = json!({ + "type": "invalid_type", + "value": "test" + }); + + let result: Result = serde_json::from_value(raw); + assert!(result.is_err()); + } +} diff --git a/router/src/server.rs b/router/src/server.rs index 5fbe0403eec..a45322ae599 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -14,6 +14,7 @@ use crate::sagemaker::{ }; use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; +use crate::ChatTokenizeResponse; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -28,7 +29,6 @@ use crate::{ ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; -use crate::{ChatTokenizeResponse, JsonSchemaConfig}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{MessageBody, ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; @@ -1362,7 +1362,6 @@ CompatGenerateRequest, SagemakerRequest, GenerateRequest, GrammarType, -JsonSchemaConfig, ChatRequest, Message, MessageContent, diff --git a/router/src/validation.rs b/router/src/validation.rs index 28c7f2f8c4f..5af0ea075ad 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -350,13 +350,13 @@ impl Validation { return Err(ValidationError::Grammar); } let valid_grammar = match grammar { - GrammarType::Json(json) => { - let json = match json { + GrammarType::Json { value } => { + let json = match value { // if value is a string, we need to parse it again to make sure its // a valid json Value::String(s) => serde_json::from_str(&s) .map_err(|e| ValidationError::InvalidGrammar(e.to_string())), - Value::Object(_) => Ok(json), + Value::Object(_) => Ok(value), _ => Err(ValidationError::Grammar), }?; @@ -380,29 +380,28 @@ impl Validation { ValidGrammar::Regex(grammar_regex.to_string()) } - GrammarType::JsonSchema(schema_config) => { + GrammarType::JsonSchema { json_schema } => { // Extract the actual schema for validation - let json = &schema_config.schema; - // Check if the json is a valid JSONSchema - jsonschema::draft202012::meta::validate(json) + jsonschema::draft202012::meta::validate(&json_schema) .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; // The schema can be valid but lack properties. // We need properties for the grammar to be successfully parsed in Python. // Therefore, we must check and throw an error if properties are missing. - json.get("properties") + json_schema + .get("properties") .ok_or(ValidationError::InvalidGrammar( "Grammar must have a 'properties' field".to_string(), ))?; // Do compilation in the router for performance - let grammar_regex = json_schema_to_regex(json, None, json) + let grammar_regex = json_schema_to_regex(&json_schema, None, &json_schema) .map_err(ValidationError::RegexFromSchema)?; ValidGrammar::Regex(grammar_regex.to_string()) } - GrammarType::Regex(regex) => ValidGrammar::Regex(regex), + GrammarType::Regex { value } => ValidGrammar::Regex(value), }; Some(valid_grammar) }