Skip to content

Commit 37b4ddb

Browse files
authored
fix: add compatibility handling for non-standard notifications in async_rw (#247)
- Introduced `is_standard_notification` function to check for standard MCP notifications. - Added `try_parse_with_compatibility` function to handle parsing messages with compatibility for non-standard notifications. - Updated the decoder implementation to utilize the new compatibility handling. - Added unit tests for standard notification checks and compatibility function to ensure correct behavior.
1 parent 47518b3 commit 37b4ddb

File tree

1 file changed

+132
-5
lines changed

1 file changed

+132
-5
lines changed

crates/rmcp/src/transport/async_rw.rs

Lines changed: 132 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,68 @@ fn without_carriage_return(s: &[u8]) -> &[u8] {
168168
}
169169
}
170170

171+
/// Check if a notification method is a standard MCP notification
172+
/// should update this when MCP spec is updated about new notifications
173+
fn is_standard_notification(method: &str) -> bool {
174+
matches!(
175+
method,
176+
"notifications/cancelled"
177+
| "notifications/initialized"
178+
| "notifications/message"
179+
| "notifications/progress"
180+
| "notifications/prompts/list_changed"
181+
| "notifications/resources/list_changed"
182+
| "notifications/resources/updated"
183+
| "notifications/roots/list_changed"
184+
| "notifications/tools/list_changed"
185+
)
186+
}
187+
188+
/// Try to parse a message with compatibility handling for non-standard notifications
189+
fn try_parse_with_compatibility<T: serde::de::DeserializeOwned>(
190+
line: &[u8],
191+
context: &str,
192+
) -> Result<Option<T>, JsonRpcMessageCodecError> {
193+
if let Ok(line_str) = std::str::from_utf8(line) {
194+
match serde_json::from_slice(line) {
195+
Ok(item) => Ok(Some(item)),
196+
Err(e) => {
197+
// Check if this is a non-standard notification that should be ignored
198+
if line_str.contains("\"method\":\"notifications/") {
199+
// Extract the method name to check if it's standard
200+
if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(line_str) {
201+
if let Some(method) = json_value.get("method").and_then(|m| m.as_str()) {
202+
if method.starts_with("notifications/")
203+
&& !is_standard_notification(method)
204+
{
205+
tracing::debug!(
206+
"Ignoring non-standard notification {} {}: {}",
207+
method,
208+
context,
209+
line_str
210+
);
211+
return Ok(None); // Skip this message
212+
}
213+
}
214+
}
215+
}
216+
217+
tracing::debug!(
218+
"Failed to parse message {}: {} | Error: {}",
219+
context,
220+
line_str,
221+
e
222+
);
223+
Err(JsonRpcMessageCodecError::Serde(e))
224+
}
225+
}
226+
} else {
227+
serde_json::from_slice(line)
228+
.map(Some)
229+
.map_err(JsonRpcMessageCodecError::Serde)
230+
}
231+
}
232+
171233
#[derive(Debug, Error)]
172234
pub enum JsonRpcMessageCodecError {
173235
#[error("max line length exceeded")]
@@ -234,8 +296,12 @@ impl<T: DeserializeOwned> Decoder for JsonRpcMessageCodec<T> {
234296
let line = buf.split_to(newline_index + 1);
235297
let line = &line[..line.len() - 1];
236298
let line = without_carriage_return(line);
237-
let item =
238-
serde_json::from_slice(line).map_err(JsonRpcMessageCodecError::Serde)?;
299+
300+
// Use compatibility handling function
301+
let item = match try_parse_with_compatibility(line, "decode")? {
302+
Some(item) => item,
303+
None => return Ok(None), // Skip non-standard message
304+
};
239305
return Ok(Some(item));
240306
}
241307
(false, None) if buf.len() > self.max_length => {
@@ -266,8 +332,12 @@ impl<T: DeserializeOwned> Decoder for JsonRpcMessageCodec<T> {
266332
} else {
267333
let line = buf.split_to(buf.len());
268334
let line = without_carriage_return(&line);
269-
let item =
270-
serde_json::from_slice(line).map_err(JsonRpcMessageCodecError::Serde)?;
335+
336+
// Use compatibility handling function
337+
let item = match try_parse_with_compatibility(line, "decode_eof")? {
338+
Some(item) => item,
339+
None => return Ok(None), // Skip non-standard message
340+
};
271341
Some(item)
272342
}
273343
}
@@ -319,7 +389,7 @@ mod test {
319389
{"jsonrpc":"2.0","method":"subtract","params":[23,42],"id":8}
320390
{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":9}
321391
{"jsonrpc":"2.0","method":"subtract","params":[23,42],"id":10}
322-
392+
323393
"#;
324394

325395
let mut cursor = BufReader::new(data.as_bytes());
@@ -379,4 +449,61 @@ mod test {
379449
// Make sure there are no extra lines
380450
assert!(lines.next().is_none());
381451
}
452+
453+
#[test]
454+
fn test_standard_notification_check() {
455+
// Test that all standard notifications are recognized
456+
assert!(is_standard_notification("notifications/cancelled"));
457+
assert!(is_standard_notification("notifications/initialized"));
458+
assert!(is_standard_notification("notifications/progress"));
459+
assert!(is_standard_notification(
460+
"notifications/resources/list_changed"
461+
));
462+
assert!(is_standard_notification("notifications/resources/updated"));
463+
assert!(is_standard_notification(
464+
"notifications/prompts/list_changed"
465+
));
466+
assert!(is_standard_notification("notifications/tools/list_changed"));
467+
assert!(is_standard_notification("notifications/message"));
468+
assert!(is_standard_notification("notifications/roots/list_changed"));
469+
470+
// Test that non-standard notifications are not recognized
471+
assert!(!is_standard_notification("notifications/stderr"));
472+
assert!(!is_standard_notification("notifications/custom"));
473+
assert!(!is_standard_notification("notifications/debug"));
474+
assert!(!is_standard_notification("some/other/method"));
475+
}
476+
477+
#[test]
478+
fn test_compatibility_function() {
479+
// Test the compatibility function directly
480+
let stderr_message =
481+
r#"{"method":"notifications/stderr","params":{"content":"stderr message"}}"#;
482+
let custom_message = r#"{"method":"notifications/custom","params":{"data":"custom"}}"#;
483+
let standard_message =
484+
r#"{"method":"notifications/message","params":{"level":"info","data":"standard"}}"#;
485+
let progress_message = r#"{"method":"notifications/progress","params":{"progressToken":"token","progress":50}}"#;
486+
487+
// Test with valid JSON - all should parse successfully
488+
let result1 =
489+
try_parse_with_compatibility::<serde_json::Value>(stderr_message.as_bytes(), "test");
490+
let result2 =
491+
try_parse_with_compatibility::<serde_json::Value>(custom_message.as_bytes(), "test");
492+
let result3 =
493+
try_parse_with_compatibility::<serde_json::Value>(standard_message.as_bytes(), "test");
494+
let result4 =
495+
try_parse_with_compatibility::<serde_json::Value>(progress_message.as_bytes(), "test");
496+
497+
// All should parse successfully since they're valid JSON
498+
assert!(result1.is_ok());
499+
assert!(result2.is_ok());
500+
assert!(result3.is_ok());
501+
assert!(result4.is_ok());
502+
503+
// Standard notifications should return Some(value)
504+
assert!(result3.unwrap().is_some());
505+
assert!(result4.unwrap().is_some());
506+
507+
println!("Standard notifications are preserved, non-standard are handled gracefully");
508+
}
382509
}

0 commit comments

Comments
 (0)