feat(runtime): MCP lifecycle hardening

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
Jobdori 2026-04-04 00:46:37 +09:00
parent 13015f6428
commit 740ae61e73

View file

@ -124,11 +124,11 @@ pub enum McpPhaseResult {
Failure { Failure {
phase: McpLifecyclePhase, phase: McpLifecyclePhase,
error: McpErrorSurface, error: McpErrorSurface,
recoverable: bool,
}, },
Timeout { Timeout {
phase: McpLifecyclePhase, phase: McpLifecyclePhase,
waited: Duration, waited: Duration,
error: McpErrorSurface,
}, },
} }
@ -200,6 +200,15 @@ impl McpLifecycleState {
fn record_result(&mut self, result: McpPhaseResult) { fn record_result(&mut self, result: McpPhaseResult) {
self.phase_results.push(result); self.phase_results.push(result);
} }
fn can_resume_after_error(&self) -> bool {
match self.phase_results.last() {
Some(McpPhaseResult::Failure { error, .. } | McpPhaseResult::Timeout { error, .. }) => {
error.recoverable
}
_ => false,
}
}
} }
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@ -286,34 +295,42 @@ impl McpLifecycleValidator {
let started = Instant::now(); let started = Instant::now();
if let Some(current_phase) = self.state.current_phase() { if let Some(current_phase) = self.state.current_phase() {
if !Self::validate_phase_transition(current_phase, phase) { if current_phase == McpLifecyclePhase::ErrorSurfacing
return self.record_failure( && phase == McpLifecyclePhase::Ready
phase, && !self.state.can_resume_after_error()
McpErrorSurface::new( {
phase, return self.record_failure(McpErrorSurface::new(
None,
format!("invalid MCP lifecycle transition from {current_phase} to {phase}"),
BTreeMap::from([
("from".to_string(), current_phase.to_string()),
("to".to_string(), phase.to_string()),
]),
false,
),
false,
);
}
} else if phase != McpLifecyclePhase::ConfigLoad {
return self.record_failure(
phase,
McpErrorSurface::new(
phase, phase,
None, None,
format!("invalid initial MCP lifecycle phase {phase}"), "cannot return to ready after a non-recoverable MCP lifecycle failure",
BTreeMap::from([("phase".to_string(), phase.to_string())]), BTreeMap::from([
("from".to_string(), current_phase.to_string()),
("to".to_string(), phase.to_string()),
]),
false, false,
), ));
}
if !Self::validate_phase_transition(current_phase, phase) {
return self.record_failure(McpErrorSurface::new(
phase,
None,
format!("invalid MCP lifecycle transition from {current_phase} to {phase}"),
BTreeMap::from([
("from".to_string(), current_phase.to_string()),
("to".to_string(), phase.to_string()),
]),
false,
));
}
} else if phase != McpLifecyclePhase::ConfigLoad {
return self.record_failure(McpErrorSurface::new(
phase,
None,
format!("invalid initial MCP lifecycle phase {phase}"),
BTreeMap::from([("phase".to_string(), phase.to_string())]),
false, false,
); ));
} }
self.state.record_phase(phase); self.state.record_phase(phase);
@ -325,19 +342,11 @@ impl McpLifecycleValidator {
result result
} }
pub fn record_failure( pub fn record_failure(&mut self, error: McpErrorSurface) -> McpPhaseResult {
&mut self, let phase = error.phase;
phase: McpLifecyclePhase,
error: McpErrorSurface,
recoverable: bool,
) -> McpPhaseResult {
self.state.record_error(error.clone()); self.state.record_error(error.clone());
self.state.record_phase(McpLifecyclePhase::ErrorSurfacing); self.state.record_phase(McpLifecyclePhase::ErrorSurfacing);
let result = McpPhaseResult::Failure { let result = McpPhaseResult::Failure { phase, error };
phase,
error,
recoverable,
};
self.state.record_result(result.clone()); self.state.record_result(result.clone());
result result
} }
@ -360,9 +369,13 @@ impl McpLifecycleValidator {
context, context,
true, true,
); );
self.state.record_error(error); self.state.record_error(error.clone());
self.state.record_phase(McpLifecyclePhase::ErrorSurfacing); self.state.record_phase(McpLifecyclePhase::ErrorSurfacing);
let result = McpPhaseResult::Timeout { phase, waited }; let result = McpPhaseResult::Timeout {
phase,
waited,
error,
};
self.state.record_result(result.clone()); self.state.record_result(result.clone());
result result
} }
@ -545,13 +558,9 @@ mod tests {
// then // then
match result { match result {
McpPhaseResult::Failure { McpPhaseResult::Failure { phase, error } => {
phase,
error,
recoverable,
} => {
assert_eq!(phase, McpLifecyclePhase::Ready); assert_eq!(phase, McpLifecyclePhase::Ready);
assert!(!recoverable); assert!(!error.recoverable);
assert_eq!(error.phase, McpLifecyclePhase::Ready); assert_eq!(error.phase, McpLifecyclePhase::Ready);
assert_eq!( assert_eq!(
error.context.get("from").map(String::as_str), error.context.get("from").map(String::as_str),
@ -581,27 +590,25 @@ mod tests {
// when / then // when / then
for phase in McpLifecyclePhase::all() { for phase in McpLifecyclePhase::all() {
let result = validator.record_failure( let result = validator.record_failure(McpErrorSurface::new(
phase, phase,
McpErrorSurface::new( Some("alpha".to_string()),
phase, format!("failure at {phase}"),
Some("alpha".to_string()), BTreeMap::from([("server".to_string(), "alpha".to_string())]),
format!("failure at {phase}"),
BTreeMap::from([("server".to_string(), "alpha".to_string())]),
phase == McpLifecyclePhase::ResourceDiscovery,
),
phase == McpLifecyclePhase::ResourceDiscovery, phase == McpLifecyclePhase::ResourceDiscovery,
); ));
match result { match result {
McpPhaseResult::Failure { McpPhaseResult::Failure {
phase: failed_phase, phase: failed_phase,
error, error,
recoverable,
} => { } => {
assert_eq!(failed_phase, phase); assert_eq!(failed_phase, phase);
assert_eq!(error.phase, phase); assert_eq!(error.phase, phase);
assert_eq!(recoverable, phase == McpLifecyclePhase::ResourceDiscovery); assert_eq!(
error.recoverable,
phase == McpLifecyclePhase::ResourceDiscovery
);
} }
other => panic!("expected failure result, got {other:?}"), other => panic!("expected failure result, got {other:?}"),
} }
@ -628,9 +635,12 @@ mod tests {
McpPhaseResult::Timeout { McpPhaseResult::Timeout {
phase, phase,
waited: actual, waited: actual,
error,
} => { } => {
assert_eq!(phase, McpLifecyclePhase::SpawnConnect); assert_eq!(phase, McpLifecyclePhase::SpawnConnect);
assert_eq!(actual, waited); assert_eq!(actual, waited);
assert!(error.recoverable);
assert_eq!(error.server_name.as_deref(), Some("alpha"));
} }
other => panic!("expected timeout result, got {other:?}"), other => panic!("expected timeout result, got {other:?}"),
} }
@ -707,17 +717,13 @@ mod tests {
let result = validator.run_phase(phase); let result = validator.run_phase(phase);
assert!(matches!(result, McpPhaseResult::Success { .. })); assert!(matches!(result, McpPhaseResult::Success { .. }));
} }
let _ = validator.record_failure( let _ = validator.record_failure(McpErrorSurface::new(
McpLifecyclePhase::ResourceDiscovery, McpLifecyclePhase::ResourceDiscovery,
McpErrorSurface::new( Some("alpha".to_string()),
McpLifecyclePhase::ResourceDiscovery, "resource listing failed",
Some("alpha".to_string()), BTreeMap::from([("reason".to_string(), "timeout".to_string())]),
"resource listing failed",
BTreeMap::from([("reason".to_string(), "timeout".to_string())]),
true,
),
true, true,
); ));
// when // when
let shutdown = validator.run_phase(McpLifecyclePhase::Shutdown); let shutdown = validator.run_phase(McpLifecyclePhase::Shutdown);
@ -758,4 +764,79 @@ mod tests {
let trait_object: &dyn std::error::Error = &error; let trait_object: &dyn std::error::Error = &error;
assert_eq!(trait_object.to_string(), rendered); assert_eq!(trait_object.to_string(), rendered);
} }
#[test]
fn given_nonrecoverable_failure_when_returning_to_ready_then_validator_rejects_resume() {
// given
let mut validator = McpLifecycleValidator::new();
for phase in [
McpLifecyclePhase::ConfigLoad,
McpLifecyclePhase::ServerRegistration,
McpLifecyclePhase::SpawnConnect,
McpLifecyclePhase::InitializeHandshake,
McpLifecyclePhase::ToolDiscovery,
McpLifecyclePhase::Ready,
] {
let result = validator.run_phase(phase);
assert!(matches!(result, McpPhaseResult::Success { .. }));
}
let _ = validator.record_failure(McpErrorSurface::new(
McpLifecyclePhase::Invocation,
Some("alpha".to_string()),
"tool call corrupted the session",
BTreeMap::from([("reason".to_string(), "invalid frame".to_string())]),
false,
));
// when
let result = validator.run_phase(McpLifecyclePhase::Ready);
// then
match result {
McpPhaseResult::Failure { phase, error } => {
assert_eq!(phase, McpLifecyclePhase::Ready);
assert!(!error.recoverable);
assert!(error.message.contains("non-recoverable"));
}
other => panic!("expected failure result, got {other:?}"),
}
assert_eq!(
validator.state().current_phase(),
Some(McpLifecyclePhase::ErrorSurfacing)
);
}
#[test]
fn given_recoverable_failure_when_returning_to_ready_then_validator_allows_resume() {
// given
let mut validator = McpLifecycleValidator::new();
for phase in [
McpLifecyclePhase::ConfigLoad,
McpLifecyclePhase::ServerRegistration,
McpLifecyclePhase::SpawnConnect,
McpLifecyclePhase::InitializeHandshake,
McpLifecyclePhase::ToolDiscovery,
McpLifecyclePhase::Ready,
] {
let result = validator.run_phase(phase);
assert!(matches!(result, McpPhaseResult::Success { .. }));
}
let _ = validator.record_failure(McpErrorSurface::new(
McpLifecyclePhase::Invocation,
Some("alpha".to_string()),
"tool call failed but can be retried",
BTreeMap::from([("reason".to_string(), "upstream timeout".to_string())]),
true,
));
// when
let result = validator.run_phase(McpLifecyclePhase::Ready);
// then
assert!(matches!(result, McpPhaseResult::Success { .. }));
assert_eq!(
validator.state().current_phase(),
Some(McpLifecyclePhase::Ready)
);
}
} }