Compare commits
1 commit
main
...
gaebal/roa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4e97c6976 |
21 changed files with 99 additions and 7025 deletions
30
ROADMAP.md
30
ROADMAP.md
|
|
@ -268,28 +268,14 @@ Acceptance:
|
||||||
|
|
||||||
## Immediate Backlog (from current real pain)
|
## Immediate Backlog (from current real pain)
|
||||||
|
|
||||||
Priority order: P0 = blocks CI/green state, P1 = blocks integration wiring, P2 = clawability hardening, P3 = swarm-efficiency improvements.
|
1. Worker readiness handshake + trust resolution
|
||||||
|
2. Prompt misdelivery detection and recovery
|
||||||
**P0 — Fix first (CI reliability)**
|
3. Canonical lane event schema in clawhip
|
||||||
1. Isolate `render_diff_report` tests into tmpdir — flaky under `cargo test --workspace`; reads real working-tree state; breaks CI during active worktree ops
|
4. Failure taxonomy + blocker normalization
|
||||||
|
5. Stale-branch detection before workspace tests
|
||||||
**P1 — Next (integration wiring, unblocks verification)**
|
6. MCP structured degraded-startup reporting
|
||||||
2. Add cross-module integration tests — every Phase 1-2 module has unit tests but no integration test connects adjacent modules; wiring gaps are invisible to CI without these
|
7. Structured task packet format
|
||||||
3. Wire lane-completion emitter — `LaneContext::completed` is a passive bool; nothing sets it automatically; need a runtime path from push+green+session-done to policy engine lane-closeout
|
8. Lane board / machine-readable status API
|
||||||
4. Wire `SummaryCompressor` into the lane event pipeline — exported but called nowhere; `LaneEvent` stream never fed through compressor
|
|
||||||
|
|
||||||
**P2 — Clawability hardening (original backlog)**
|
|
||||||
5. Worker readiness handshake + trust resolution
|
|
||||||
6. Prompt misdelivery detection and recovery
|
|
||||||
7. Canonical lane event schema in clawhip
|
|
||||||
8. Failure taxonomy + blocker normalization
|
|
||||||
9. Stale-branch detection before workspace tests
|
|
||||||
10. MCP structured degraded-startup reporting
|
|
||||||
11. Structured task packet format
|
|
||||||
12. Lane board / machine-readable status API
|
|
||||||
|
|
||||||
**P3 — Swarm efficiency**
|
|
||||||
13. Swarm branch-lock protocol — detect same-module/same-branch collision before parallel workers drift into duplicate implementation
|
|
||||||
|
|
||||||
## Suggested Session Split
|
## Suggested Session Split
|
||||||
|
|
||||||
|
|
|
||||||
1
rust/Cargo.lock
generated
1
rust/Cargo.lock
generated
|
|
@ -1208,7 +1208,6 @@ dependencies = [
|
||||||
"pulldown-cmark",
|
"pulldown-cmark",
|
||||||
"runtime",
|
"runtime",
|
||||||
"rustyline",
|
"rustyline",
|
||||||
"serde",
|
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"syntect",
|
"syntect",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
|
||||||
|
|
@ -1,152 +0,0 @@
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum GreenLevel {
|
|
||||||
TargetedTests,
|
|
||||||
Package,
|
|
||||||
Workspace,
|
|
||||||
MergeReady,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GreenLevel {
|
|
||||||
#[must_use]
|
|
||||||
pub fn as_str(self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::TargetedTests => "targeted_tests",
|
|
||||||
Self::Package => "package",
|
|
||||||
Self::Workspace => "workspace",
|
|
||||||
Self::MergeReady => "merge_ready",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for GreenLevel {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{}", self.as_str())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub struct GreenContract {
|
|
||||||
pub required_level: GreenLevel,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GreenContract {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(required_level: GreenLevel) -> Self {
|
|
||||||
Self { required_level }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn evaluate(self, observed_level: Option<GreenLevel>) -> GreenContractOutcome {
|
|
||||||
match observed_level {
|
|
||||||
Some(level) if level >= self.required_level => GreenContractOutcome::Satisfied {
|
|
||||||
required_level: self.required_level,
|
|
||||||
observed_level: level,
|
|
||||||
},
|
|
||||||
_ => GreenContractOutcome::Unsatisfied {
|
|
||||||
required_level: self.required_level,
|
|
||||||
observed_level,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn is_satisfied_by(self, observed_level: GreenLevel) -> bool {
|
|
||||||
observed_level >= self.required_level
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(tag = "outcome", rename_all = "snake_case")]
|
|
||||||
pub enum GreenContractOutcome {
|
|
||||||
Satisfied {
|
|
||||||
required_level: GreenLevel,
|
|
||||||
observed_level: GreenLevel,
|
|
||||||
},
|
|
||||||
Unsatisfied {
|
|
||||||
required_level: GreenLevel,
|
|
||||||
observed_level: Option<GreenLevel>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GreenContractOutcome {
|
|
||||||
#[must_use]
|
|
||||||
pub fn is_satisfied(&self) -> bool {
|
|
||||||
matches!(self, Self::Satisfied { .. })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_matching_level_when_evaluating_contract_then_it_is_satisfied() {
|
|
||||||
// given
|
|
||||||
let contract = GreenContract::new(GreenLevel::Package);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let outcome = contract.evaluate(Some(GreenLevel::Package));
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
outcome,
|
|
||||||
GreenContractOutcome::Satisfied {
|
|
||||||
required_level: GreenLevel::Package,
|
|
||||||
observed_level: GreenLevel::Package,
|
|
||||||
}
|
|
||||||
);
|
|
||||||
assert!(outcome.is_satisfied());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_higher_level_when_checking_requirement_then_it_still_satisfies_contract() {
|
|
||||||
// given
|
|
||||||
let contract = GreenContract::new(GreenLevel::TargetedTests);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let is_satisfied = contract.is_satisfied_by(GreenLevel::Workspace);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(is_satisfied);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_lower_level_when_evaluating_contract_then_it_is_unsatisfied() {
|
|
||||||
// given
|
|
||||||
let contract = GreenContract::new(GreenLevel::Workspace);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let outcome = contract.evaluate(Some(GreenLevel::Package));
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
outcome,
|
|
||||||
GreenContractOutcome::Unsatisfied {
|
|
||||||
required_level: GreenLevel::Workspace,
|
|
||||||
observed_level: Some(GreenLevel::Package),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
assert!(!outcome.is_satisfied());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_no_green_level_when_evaluating_contract_then_contract_is_unsatisfied() {
|
|
||||||
// given
|
|
||||||
let contract = GreenContract::new(GreenLevel::MergeReady);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let outcome = contract.evaluate(None);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
outcome,
|
|
||||||
GreenContractOutcome::Unsatisfied {
|
|
||||||
required_level: GreenLevel::MergeReady,
|
|
||||||
observed_level: None,
|
|
||||||
}
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -5,35 +5,24 @@ mod compact;
|
||||||
mod config;
|
mod config;
|
||||||
mod conversation;
|
mod conversation;
|
||||||
mod file_ops;
|
mod file_ops;
|
||||||
pub mod green_contract;
|
|
||||||
mod hooks;
|
mod hooks;
|
||||||
mod json;
|
mod json;
|
||||||
pub mod lsp_client;
|
pub mod lsp_client;
|
||||||
mod mcp;
|
mod mcp;
|
||||||
mod mcp_client;
|
mod mcp_client;
|
||||||
pub mod mcp_lifecycle_hardened;
|
|
||||||
mod mcp_stdio;
|
mod mcp_stdio;
|
||||||
pub mod mcp_tool_bridge;
|
pub mod mcp_tool_bridge;
|
||||||
mod oauth;
|
mod oauth;
|
||||||
pub mod permission_enforcer;
|
pub mod permission_enforcer;
|
||||||
mod policy_engine;
|
|
||||||
pub mod recovery_recipes;
|
|
||||||
mod permissions;
|
mod permissions;
|
||||||
pub mod plugin_lifecycle;
|
|
||||||
mod prompt;
|
mod prompt;
|
||||||
mod remote;
|
mod remote;
|
||||||
pub mod session_control;
|
|
||||||
pub mod sandbox;
|
pub mod sandbox;
|
||||||
mod session;
|
mod session;
|
||||||
mod sse;
|
mod sse;
|
||||||
pub mod stale_branch;
|
|
||||||
pub mod summary_compression;
|
|
||||||
pub mod task_registry;
|
pub mod task_registry;
|
||||||
pub mod task_packet;
|
|
||||||
pub mod team_cron_registry;
|
pub mod team_cron_registry;
|
||||||
pub mod trust_resolver;
|
|
||||||
mod usage;
|
mod usage;
|
||||||
pub mod worker_boot;
|
|
||||||
|
|
||||||
pub use bash::{execute_bash, BashCommandInput, BashCommandOutput};
|
pub use bash::{execute_bash, BashCommandInput, BashCommandOutput};
|
||||||
pub use bootstrap::{BootstrapPhase, BootstrapPlan};
|
pub use bootstrap::{BootstrapPhase, BootstrapPlan};
|
||||||
|
|
@ -70,18 +59,13 @@ pub use mcp_client::{
|
||||||
McpClientAuth, McpClientBootstrap, McpClientTransport, McpManagedProxyTransport,
|
McpClientAuth, McpClientBootstrap, McpClientTransport, McpManagedProxyTransport,
|
||||||
McpRemoteTransport, McpSdkTransport, McpStdioTransport,
|
McpRemoteTransport, McpSdkTransport, McpStdioTransport,
|
||||||
};
|
};
|
||||||
pub use mcp_lifecycle_hardened::{
|
|
||||||
McpDegradedReport, McpErrorSurface, McpFailedServer, McpLifecyclePhase, McpLifecycleState,
|
|
||||||
McpLifecycleValidator, McpPhaseResult,
|
|
||||||
};
|
|
||||||
pub use mcp_stdio::{
|
pub use mcp_stdio::{
|
||||||
spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
||||||
ManagedMcpTool, McpDiscoveryFailure, McpInitializeClientInfo, McpInitializeParams,
|
ManagedMcpTool, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult,
|
||||||
McpInitializeResult, McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult,
|
McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, McpListToolsParams,
|
||||||
McpListToolsParams, McpListToolsResult, McpReadResourceParams, McpReadResourceResult,
|
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpResource,
|
||||||
McpResource, McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess,
|
McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, McpTool,
|
||||||
McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult, McpToolDiscoveryReport,
|
McpToolCallContent, McpToolCallParams, McpToolCallResult, UnsupportedMcpServer,
|
||||||
UnsupportedMcpServer,
|
|
||||||
};
|
};
|
||||||
pub use oauth::{
|
pub use oauth::{
|
||||||
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||||
|
|
@ -90,22 +74,10 @@ pub use oauth::{
|
||||||
OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||||
PkceChallengeMethod, PkceCodePair,
|
PkceChallengeMethod, PkceCodePair,
|
||||||
};
|
};
|
||||||
pub use policy_engine::{
|
|
||||||
evaluate, DiffScope, GreenLevel, LaneBlocker, LaneContext, PolicyAction, PolicyCondition,
|
|
||||||
PolicyEngine, PolicyRule, ReviewStatus,
|
|
||||||
};
|
|
||||||
pub use permissions::{
|
pub use permissions::{
|
||||||
PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy,
|
PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy,
|
||||||
PermissionPromptDecision, PermissionPrompter, PermissionRequest,
|
PermissionPromptDecision, PermissionPrompter, PermissionRequest,
|
||||||
};
|
};
|
||||||
pub use plugin_lifecycle::{
|
|
||||||
DegradedMode, DiscoveryResult, PluginHealthcheck, PluginLifecycle, PluginLifecycleEvent,
|
|
||||||
PluginState, ResourceInfo, ServerHealth, ServerStatus, ToolInfo,
|
|
||||||
};
|
|
||||||
pub use recovery_recipes::{
|
|
||||||
attempt_recovery, recipe_for, EscalationPolicy, FailureScenario, RecoveryContext,
|
|
||||||
RecoveryEvent, RecoveryRecipe, RecoveryResult, RecoveryStep,
|
|
||||||
};
|
|
||||||
pub use prompt::{
|
pub use prompt::{
|
||||||
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
|
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
|
||||||
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||||
|
|
@ -125,24 +97,10 @@ pub use session::{
|
||||||
ContentBlock, ConversationMessage, MessageRole, Session, SessionCompaction, SessionError,
|
ContentBlock, ConversationMessage, MessageRole, Session, SessionCompaction, SessionError,
|
||||||
SessionFork,
|
SessionFork,
|
||||||
};
|
};
|
||||||
pub use stale_branch::{
|
|
||||||
apply_policy, check_freshness, BranchFreshness, StaleBranchAction, StaleBranchEvent,
|
|
||||||
StaleBranchPolicy,
|
|
||||||
};
|
|
||||||
pub use sse::{IncrementalSseParser, SseEvent};
|
pub use sse::{IncrementalSseParser, SseEvent};
|
||||||
pub use task_packet::{
|
|
||||||
validate_packet, AcceptanceTest, BranchPolicy, CommitPolicy,
|
|
||||||
RepoConfig, ReportingContract, TaskPacket, TaskPacketValidationError, TaskScope,
|
|
||||||
ValidatedPacket,
|
|
||||||
};
|
|
||||||
pub use usage::{
|
pub use usage::{
|
||||||
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
||||||
};
|
};
|
||||||
pub use trust_resolver::{TrustConfig, TrustDecision, TrustEvent, TrustPolicy, TrustResolver};
|
|
||||||
pub use worker_boot::{
|
|
||||||
Worker, WorkerEvent, WorkerEventKind, WorkerFailure, WorkerFailureKind, WorkerReadySnapshot,
|
|
||||||
WorkerRegistry, WorkerStatus,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
|
|
||||||
|
|
@ -1,761 +0,0 @@
|
||||||
use std::collections::{BTreeMap, BTreeSet};
|
|
||||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
fn now_secs() -> u64 {
|
|
||||||
SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.unwrap_or_default()
|
|
||||||
.as_secs()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum McpLifecyclePhase {
|
|
||||||
ConfigLoad,
|
|
||||||
ServerRegistration,
|
|
||||||
SpawnConnect,
|
|
||||||
InitializeHandshake,
|
|
||||||
ToolDiscovery,
|
|
||||||
ResourceDiscovery,
|
|
||||||
Ready,
|
|
||||||
Invocation,
|
|
||||||
ErrorSurfacing,
|
|
||||||
Shutdown,
|
|
||||||
Cleanup,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpLifecyclePhase {
|
|
||||||
#[must_use]
|
|
||||||
pub fn all() -> [Self; 11] {
|
|
||||||
[
|
|
||||||
Self::ConfigLoad,
|
|
||||||
Self::ServerRegistration,
|
|
||||||
Self::SpawnConnect,
|
|
||||||
Self::InitializeHandshake,
|
|
||||||
Self::ToolDiscovery,
|
|
||||||
Self::ResourceDiscovery,
|
|
||||||
Self::Ready,
|
|
||||||
Self::Invocation,
|
|
||||||
Self::ErrorSurfacing,
|
|
||||||
Self::Shutdown,
|
|
||||||
Self::Cleanup,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for McpLifecyclePhase {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::ConfigLoad => write!(f, "config_load"),
|
|
||||||
Self::ServerRegistration => write!(f, "server_registration"),
|
|
||||||
Self::SpawnConnect => write!(f, "spawn_connect"),
|
|
||||||
Self::InitializeHandshake => write!(f, "initialize_handshake"),
|
|
||||||
Self::ToolDiscovery => write!(f, "tool_discovery"),
|
|
||||||
Self::ResourceDiscovery => write!(f, "resource_discovery"),
|
|
||||||
Self::Ready => write!(f, "ready"),
|
|
||||||
Self::Invocation => write!(f, "invocation"),
|
|
||||||
Self::ErrorSurfacing => write!(f, "error_surfacing"),
|
|
||||||
Self::Shutdown => write!(f, "shutdown"),
|
|
||||||
Self::Cleanup => write!(f, "cleanup"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub struct McpErrorSurface {
|
|
||||||
pub phase: McpLifecyclePhase,
|
|
||||||
pub server_name: Option<String>,
|
|
||||||
pub message: String,
|
|
||||||
pub context: BTreeMap<String, String>,
|
|
||||||
pub recoverable: bool,
|
|
||||||
pub timestamp: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpErrorSurface {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(
|
|
||||||
phase: McpLifecyclePhase,
|
|
||||||
server_name: Option<String>,
|
|
||||||
message: impl Into<String>,
|
|
||||||
context: BTreeMap<String, String>,
|
|
||||||
recoverable: bool,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
phase,
|
|
||||||
server_name,
|
|
||||||
message: message.into(),
|
|
||||||
context,
|
|
||||||
recoverable,
|
|
||||||
timestamp: now_secs(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for McpErrorSurface {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(
|
|
||||||
f,
|
|
||||||
"MCP lifecycle error during {}: {}",
|
|
||||||
self.phase, self.message
|
|
||||||
)?;
|
|
||||||
if let Some(server_name) = &self.server_name {
|
|
||||||
write!(f, " (server: {server_name})")?;
|
|
||||||
}
|
|
||||||
if !self.context.is_empty() {
|
|
||||||
write!(f, " with context {:?}", self.context)?;
|
|
||||||
}
|
|
||||||
if self.recoverable {
|
|
||||||
write!(f, " [recoverable]")?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for McpErrorSurface {}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum McpPhaseResult {
|
|
||||||
Success {
|
|
||||||
phase: McpLifecyclePhase,
|
|
||||||
duration: Duration,
|
|
||||||
},
|
|
||||||
Failure {
|
|
||||||
phase: McpLifecyclePhase,
|
|
||||||
error: McpErrorSurface,
|
|
||||||
recoverable: bool,
|
|
||||||
},
|
|
||||||
Timeout {
|
|
||||||
phase: McpLifecyclePhase,
|
|
||||||
waited: Duration,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpPhaseResult {
|
|
||||||
#[must_use]
|
|
||||||
pub fn phase(&self) -> McpLifecyclePhase {
|
|
||||||
match self {
|
|
||||||
Self::Success { phase, .. }
|
|
||||||
| Self::Failure { phase, .. }
|
|
||||||
| Self::Timeout { phase, .. } => *phase,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
|
||||||
pub struct McpLifecycleState {
|
|
||||||
current_phase: Option<McpLifecyclePhase>,
|
|
||||||
phase_errors: BTreeMap<McpLifecyclePhase, Vec<McpErrorSurface>>,
|
|
||||||
phase_timestamps: BTreeMap<McpLifecyclePhase, u64>,
|
|
||||||
phase_results: Vec<McpPhaseResult>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpLifecycleState {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn current_phase(&self) -> Option<McpLifecyclePhase> {
|
|
||||||
self.current_phase
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn errors_for_phase(&self, phase: McpLifecyclePhase) -> &[McpErrorSurface] {
|
|
||||||
self.phase_errors
|
|
||||||
.get(&phase)
|
|
||||||
.map(Vec::as_slice)
|
|
||||||
.unwrap_or(&[])
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn results(&self) -> &[McpPhaseResult] {
|
|
||||||
&self.phase_results
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn phase_timestamps(&self) -> &BTreeMap<McpLifecyclePhase, u64> {
|
|
||||||
&self.phase_timestamps
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn phase_timestamp(&self, phase: McpLifecyclePhase) -> Option<u64> {
|
|
||||||
self.phase_timestamps.get(&phase).copied()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn record_phase(&mut self, phase: McpLifecyclePhase) {
|
|
||||||
self.current_phase = Some(phase);
|
|
||||||
self.phase_timestamps.insert(phase, now_secs());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn record_error(&mut self, error: McpErrorSurface) {
|
|
||||||
self.phase_errors
|
|
||||||
.entry(error.phase)
|
|
||||||
.or_default()
|
|
||||||
.push(error);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn record_result(&mut self, result: McpPhaseResult) {
|
|
||||||
self.phase_results.push(result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub struct McpFailedServer {
|
|
||||||
pub server_name: String,
|
|
||||||
pub phase: McpLifecyclePhase,
|
|
||||||
pub error: McpErrorSurface,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub struct McpDegradedReport {
|
|
||||||
pub working_servers: Vec<String>,
|
|
||||||
pub failed_servers: Vec<McpFailedServer>,
|
|
||||||
pub available_tools: Vec<String>,
|
|
||||||
pub missing_tools: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpDegradedReport {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(
|
|
||||||
working_servers: Vec<String>,
|
|
||||||
failed_servers: Vec<McpFailedServer>,
|
|
||||||
available_tools: Vec<String>,
|
|
||||||
expected_tools: Vec<String>,
|
|
||||||
) -> Self {
|
|
||||||
let working_servers = dedupe_sorted(working_servers);
|
|
||||||
let available_tools = dedupe_sorted(available_tools);
|
|
||||||
let available_tool_set: BTreeSet<_> = available_tools.iter().cloned().collect();
|
|
||||||
let expected_tools = dedupe_sorted(expected_tools);
|
|
||||||
let missing_tools = expected_tools
|
|
||||||
.into_iter()
|
|
||||||
.filter(|tool| !available_tool_set.contains(tool))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
working_servers,
|
|
||||||
failed_servers,
|
|
||||||
available_tools,
|
|
||||||
missing_tools,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
|
||||||
pub struct McpLifecycleValidator {
|
|
||||||
state: McpLifecycleState,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpLifecycleValidator {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn state(&self) -> &McpLifecycleState {
|
|
||||||
&self.state
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn validate_phase_transition(from: McpLifecyclePhase, to: McpLifecyclePhase) -> bool {
|
|
||||||
match (from, to) {
|
|
||||||
(McpLifecyclePhase::ConfigLoad, McpLifecyclePhase::ServerRegistration)
|
|
||||||
| (McpLifecyclePhase::ServerRegistration, McpLifecyclePhase::SpawnConnect)
|
|
||||||
| (McpLifecyclePhase::SpawnConnect, McpLifecyclePhase::InitializeHandshake)
|
|
||||||
| (McpLifecyclePhase::InitializeHandshake, McpLifecyclePhase::ToolDiscovery)
|
|
||||||
| (McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::ResourceDiscovery)
|
|
||||||
| (McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::Ready)
|
|
||||||
| (McpLifecyclePhase::ResourceDiscovery, McpLifecyclePhase::Ready)
|
|
||||||
| (McpLifecyclePhase::Ready, McpLifecyclePhase::Invocation)
|
|
||||||
| (McpLifecyclePhase::Invocation, McpLifecyclePhase::Ready)
|
|
||||||
| (McpLifecyclePhase::ErrorSurfacing, McpLifecyclePhase::Ready)
|
|
||||||
| (McpLifecyclePhase::ErrorSurfacing, McpLifecyclePhase::Shutdown)
|
|
||||||
| (McpLifecyclePhase::Shutdown, McpLifecyclePhase::Cleanup) => true,
|
|
||||||
(_, McpLifecyclePhase::Shutdown) => from != McpLifecyclePhase::Cleanup,
|
|
||||||
(_, McpLifecyclePhase::ErrorSurfacing) => {
|
|
||||||
from != McpLifecyclePhase::Cleanup && from != McpLifecyclePhase::Shutdown
|
|
||||||
}
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn run_phase(&mut self, phase: McpLifecyclePhase) -> McpPhaseResult {
|
|
||||||
let started = Instant::now();
|
|
||||||
|
|
||||||
if let Some(current_phase) = self.state.current_phase() {
|
|
||||||
if !Self::validate_phase_transition(current_phase, phase) {
|
|
||||||
return self.record_failure(
|
|
||||||
phase,
|
|
||||||
McpErrorSurface::new(
|
|
||||||
phase,
|
|
||||||
None,
|
|
||||||
format!("invalid MCP lifecycle transition from {current_phase} to {phase}"),
|
|
||||||
BTreeMap::from([
|
|
||||||
("from".to_string(), current_phase.to_string()),
|
|
||||||
("to".to_string(), phase.to_string()),
|
|
||||||
]),
|
|
||||||
false,
|
|
||||||
),
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else if phase != McpLifecyclePhase::ConfigLoad {
|
|
||||||
return self.record_failure(
|
|
||||||
phase,
|
|
||||||
McpErrorSurface::new(
|
|
||||||
phase,
|
|
||||||
None,
|
|
||||||
format!("invalid initial MCP lifecycle phase {phase}"),
|
|
||||||
BTreeMap::from([("phase".to_string(), phase.to_string())]),
|
|
||||||
false,
|
|
||||||
),
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.state.record_phase(phase);
|
|
||||||
let result = McpPhaseResult::Success {
|
|
||||||
phase,
|
|
||||||
duration: started.elapsed(),
|
|
||||||
};
|
|
||||||
self.state.record_result(result.clone());
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn record_failure(
|
|
||||||
&mut self,
|
|
||||||
phase: McpLifecyclePhase,
|
|
||||||
error: McpErrorSurface,
|
|
||||||
recoverable: bool,
|
|
||||||
) -> McpPhaseResult {
|
|
||||||
self.state.record_error(error.clone());
|
|
||||||
self.state.record_phase(McpLifecyclePhase::ErrorSurfacing);
|
|
||||||
let result = McpPhaseResult::Failure {
|
|
||||||
phase,
|
|
||||||
error,
|
|
||||||
recoverable,
|
|
||||||
};
|
|
||||||
self.state.record_result(result.clone());
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn record_timeout(
|
|
||||||
&mut self,
|
|
||||||
phase: McpLifecyclePhase,
|
|
||||||
waited: Duration,
|
|
||||||
server_name: Option<String>,
|
|
||||||
mut context: BTreeMap<String, String>,
|
|
||||||
) -> McpPhaseResult {
|
|
||||||
context.insert("waited_ms".to_string(), waited.as_millis().to_string());
|
|
||||||
let error = McpErrorSurface::new(
|
|
||||||
phase,
|
|
||||||
server_name,
|
|
||||||
format!(
|
|
||||||
"MCP lifecycle phase {phase} timed out after {} ms",
|
|
||||||
waited.as_millis()
|
|
||||||
),
|
|
||||||
context,
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
self.state.record_error(error);
|
|
||||||
self.state.record_phase(McpLifecyclePhase::ErrorSurfacing);
|
|
||||||
let result = McpPhaseResult::Timeout { phase, waited };
|
|
||||||
self.state.record_result(result.clone());
|
|
||||||
result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dedupe_sorted(mut values: Vec<String>) -> Vec<String> {
|
|
||||||
values.sort();
|
|
||||||
values.dedup();
|
|
||||||
values
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn phase_display_matches_serde_name() {
|
|
||||||
// given
|
|
||||||
let phases = McpLifecyclePhase::all();
|
|
||||||
|
|
||||||
// when
|
|
||||||
let serialized = phases
|
|
||||||
.into_iter()
|
|
||||||
.map(|phase| {
|
|
||||||
(
|
|
||||||
phase.to_string(),
|
|
||||||
serde_json::to_value(phase).expect("serialize phase"),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
// then
|
|
||||||
for (display, json_value) in serialized {
|
|
||||||
assert_eq!(json_value, json!(display));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_startup_path_when_running_to_cleanup_then_each_control_transition_succeeds() {
|
|
||||||
// given
|
|
||||||
let mut validator = McpLifecycleValidator::new();
|
|
||||||
let phases = [
|
|
||||||
McpLifecyclePhase::ConfigLoad,
|
|
||||||
McpLifecyclePhase::ServerRegistration,
|
|
||||||
McpLifecyclePhase::SpawnConnect,
|
|
||||||
McpLifecyclePhase::InitializeHandshake,
|
|
||||||
McpLifecyclePhase::ToolDiscovery,
|
|
||||||
McpLifecyclePhase::ResourceDiscovery,
|
|
||||||
McpLifecyclePhase::Ready,
|
|
||||||
McpLifecyclePhase::Invocation,
|
|
||||||
McpLifecyclePhase::Ready,
|
|
||||||
McpLifecyclePhase::Shutdown,
|
|
||||||
McpLifecyclePhase::Cleanup,
|
|
||||||
];
|
|
||||||
|
|
||||||
// when
|
|
||||||
let results = phases
|
|
||||||
.into_iter()
|
|
||||||
.map(|phase| validator.run_phase(phase))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(results
|
|
||||||
.iter()
|
|
||||||
.all(|result| matches!(result, McpPhaseResult::Success { .. })));
|
|
||||||
assert_eq!(
|
|
||||||
validator.state().current_phase(),
|
|
||||||
Some(McpLifecyclePhase::Cleanup)
|
|
||||||
);
|
|
||||||
for phase in [
|
|
||||||
McpLifecyclePhase::ConfigLoad,
|
|
||||||
McpLifecyclePhase::ServerRegistration,
|
|
||||||
McpLifecyclePhase::SpawnConnect,
|
|
||||||
McpLifecyclePhase::InitializeHandshake,
|
|
||||||
McpLifecyclePhase::ToolDiscovery,
|
|
||||||
McpLifecyclePhase::ResourceDiscovery,
|
|
||||||
McpLifecyclePhase::Ready,
|
|
||||||
McpLifecyclePhase::Invocation,
|
|
||||||
McpLifecyclePhase::Shutdown,
|
|
||||||
McpLifecyclePhase::Cleanup,
|
|
||||||
] {
|
|
||||||
assert!(validator.state().phase_timestamp(phase).is_some());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_tool_discovery_when_resource_discovery_is_skipped_then_ready_is_still_allowed() {
|
|
||||||
// given
|
|
||||||
let mut validator = McpLifecycleValidator::new();
|
|
||||||
for phase in [
|
|
||||||
McpLifecyclePhase::ConfigLoad,
|
|
||||||
McpLifecyclePhase::ServerRegistration,
|
|
||||||
McpLifecyclePhase::SpawnConnect,
|
|
||||||
McpLifecyclePhase::InitializeHandshake,
|
|
||||||
McpLifecyclePhase::ToolDiscovery,
|
|
||||||
] {
|
|
||||||
let result = validator.run_phase(phase);
|
|
||||||
assert!(matches!(result, McpPhaseResult::Success { .. }));
|
|
||||||
}
|
|
||||||
|
|
||||||
// when
|
|
||||||
let result = validator.run_phase(McpLifecyclePhase::Ready);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(matches!(result, McpPhaseResult::Success { .. }));
|
|
||||||
assert_eq!(
|
|
||||||
validator.state().current_phase(),
|
|
||||||
Some(McpLifecyclePhase::Ready)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn validates_expected_phase_transitions() {
|
|
||||||
// given
|
|
||||||
let valid_transitions = [
|
|
||||||
(
|
|
||||||
McpLifecyclePhase::ConfigLoad,
|
|
||||||
McpLifecyclePhase::ServerRegistration,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
McpLifecyclePhase::ServerRegistration,
|
|
||||||
McpLifecyclePhase::SpawnConnect,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
McpLifecyclePhase::SpawnConnect,
|
|
||||||
McpLifecyclePhase::InitializeHandshake,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
McpLifecyclePhase::InitializeHandshake,
|
|
||||||
McpLifecyclePhase::ToolDiscovery,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
McpLifecyclePhase::ToolDiscovery,
|
|
||||||
McpLifecyclePhase::ResourceDiscovery,
|
|
||||||
),
|
|
||||||
(McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::Ready),
|
|
||||||
(
|
|
||||||
McpLifecyclePhase::ResourceDiscovery,
|
|
||||||
McpLifecyclePhase::Ready,
|
|
||||||
),
|
|
||||||
(McpLifecyclePhase::Ready, McpLifecyclePhase::Invocation),
|
|
||||||
(McpLifecyclePhase::Invocation, McpLifecyclePhase::Ready),
|
|
||||||
(McpLifecyclePhase::Ready, McpLifecyclePhase::Shutdown),
|
|
||||||
(
|
|
||||||
McpLifecyclePhase::Invocation,
|
|
||||||
McpLifecyclePhase::ErrorSurfacing,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
McpLifecyclePhase::ErrorSurfacing,
|
|
||||||
McpLifecyclePhase::Shutdown,
|
|
||||||
),
|
|
||||||
(McpLifecyclePhase::Shutdown, McpLifecyclePhase::Cleanup),
|
|
||||||
];
|
|
||||||
|
|
||||||
// when / then
|
|
||||||
for (from, to) in valid_transitions {
|
|
||||||
assert!(McpLifecycleValidator::validate_phase_transition(from, to));
|
|
||||||
}
|
|
||||||
assert!(!McpLifecycleValidator::validate_phase_transition(
|
|
||||||
McpLifecyclePhase::Ready,
|
|
||||||
McpLifecyclePhase::ConfigLoad,
|
|
||||||
));
|
|
||||||
assert!(!McpLifecycleValidator::validate_phase_transition(
|
|
||||||
McpLifecyclePhase::Cleanup,
|
|
||||||
McpLifecyclePhase::Ready,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_invalid_transition_when_running_phase_then_structured_failure_is_recorded() {
|
|
||||||
// given
|
|
||||||
let mut validator = McpLifecycleValidator::new();
|
|
||||||
let _ = validator.run_phase(McpLifecyclePhase::ConfigLoad);
|
|
||||||
let _ = validator.run_phase(McpLifecyclePhase::ServerRegistration);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let result = validator.run_phase(McpLifecyclePhase::Ready);
|
|
||||||
|
|
||||||
// then
|
|
||||||
match result {
|
|
||||||
McpPhaseResult::Failure {
|
|
||||||
phase,
|
|
||||||
error,
|
|
||||||
recoverable,
|
|
||||||
} => {
|
|
||||||
assert_eq!(phase, McpLifecyclePhase::Ready);
|
|
||||||
assert!(!recoverable);
|
|
||||||
assert_eq!(error.phase, McpLifecyclePhase::Ready);
|
|
||||||
assert_eq!(
|
|
||||||
error.context.get("from").map(String::as_str),
|
|
||||||
Some("server_registration")
|
|
||||||
);
|
|
||||||
assert_eq!(error.context.get("to").map(String::as_str), Some("ready"));
|
|
||||||
}
|
|
||||||
other => panic!("expected failure result, got {other:?}"),
|
|
||||||
}
|
|
||||||
assert_eq!(
|
|
||||||
validator.state().current_phase(),
|
|
||||||
Some(McpLifecyclePhase::ErrorSurfacing)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
validator
|
|
||||||
.state()
|
|
||||||
.errors_for_phase(McpLifecyclePhase::Ready)
|
|
||||||
.len(),
|
|
||||||
1
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_each_phase_when_failure_is_recorded_then_error_is_tracked_per_phase() {
|
|
||||||
// given
|
|
||||||
let mut validator = McpLifecycleValidator::new();
|
|
||||||
|
|
||||||
// when / then
|
|
||||||
for phase in McpLifecyclePhase::all() {
|
|
||||||
let result = validator.record_failure(
|
|
||||||
phase,
|
|
||||||
McpErrorSurface::new(
|
|
||||||
phase,
|
|
||||||
Some("alpha".to_string()),
|
|
||||||
format!("failure at {phase}"),
|
|
||||||
BTreeMap::from([("server".to_string(), "alpha".to_string())]),
|
|
||||||
phase == McpLifecyclePhase::ResourceDiscovery,
|
|
||||||
),
|
|
||||||
phase == McpLifecyclePhase::ResourceDiscovery,
|
|
||||||
);
|
|
||||||
|
|
||||||
match result {
|
|
||||||
McpPhaseResult::Failure {
|
|
||||||
phase: failed_phase,
|
|
||||||
error,
|
|
||||||
recoverable,
|
|
||||||
} => {
|
|
||||||
assert_eq!(failed_phase, phase);
|
|
||||||
assert_eq!(error.phase, phase);
|
|
||||||
assert_eq!(recoverable, phase == McpLifecyclePhase::ResourceDiscovery);
|
|
||||||
}
|
|
||||||
other => panic!("expected failure result, got {other:?}"),
|
|
||||||
}
|
|
||||||
assert_eq!(validator.state().errors_for_phase(phase).len(), 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_spawn_connect_timeout_when_recorded_then_waited_duration_is_preserved() {
|
|
||||||
// given
|
|
||||||
let mut validator = McpLifecycleValidator::new();
|
|
||||||
let waited = Duration::from_millis(250);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let result = validator.record_timeout(
|
|
||||||
McpLifecyclePhase::SpawnConnect,
|
|
||||||
waited,
|
|
||||||
Some("alpha".to_string()),
|
|
||||||
BTreeMap::from([("attempt".to_string(), "1".to_string())]),
|
|
||||||
);
|
|
||||||
|
|
||||||
// then
|
|
||||||
match result {
|
|
||||||
McpPhaseResult::Timeout {
|
|
||||||
phase,
|
|
||||||
waited: actual,
|
|
||||||
} => {
|
|
||||||
assert_eq!(phase, McpLifecyclePhase::SpawnConnect);
|
|
||||||
assert_eq!(actual, waited);
|
|
||||||
}
|
|
||||||
other => panic!("expected timeout result, got {other:?}"),
|
|
||||||
}
|
|
||||||
let errors = validator
|
|
||||||
.state()
|
|
||||||
.errors_for_phase(McpLifecyclePhase::SpawnConnect);
|
|
||||||
assert_eq!(errors.len(), 1);
|
|
||||||
assert_eq!(
|
|
||||||
errors[0].context.get("waited_ms").map(String::as_str),
|
|
||||||
Some("250")
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
validator.state().current_phase(),
|
|
||||||
Some(McpLifecyclePhase::ErrorSurfacing)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_partial_server_health_when_building_degraded_report_then_missing_tools_are_reported() {
|
|
||||||
// given
|
|
||||||
let failed = vec![McpFailedServer {
|
|
||||||
server_name: "broken".to_string(),
|
|
||||||
phase: McpLifecyclePhase::InitializeHandshake,
|
|
||||||
error: McpErrorSurface::new(
|
|
||||||
McpLifecyclePhase::InitializeHandshake,
|
|
||||||
Some("broken".to_string()),
|
|
||||||
"initialize failed",
|
|
||||||
BTreeMap::from([("reason".to_string(), "broken pipe".to_string())]),
|
|
||||||
false,
|
|
||||||
),
|
|
||||||
}];
|
|
||||||
|
|
||||||
// when
|
|
||||||
let report = McpDegradedReport::new(
|
|
||||||
vec!["alpha".to_string(), "beta".to_string(), "alpha".to_string()],
|
|
||||||
failed,
|
|
||||||
vec![
|
|
||||||
"alpha.echo".to_string(),
|
|
||||||
"beta.search".to_string(),
|
|
||||||
"alpha.echo".to_string(),
|
|
||||||
],
|
|
||||||
vec![
|
|
||||||
"alpha.echo".to_string(),
|
|
||||||
"beta.search".to_string(),
|
|
||||||
"broken.fetch".to_string(),
|
|
||||||
],
|
|
||||||
);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
report.working_servers,
|
|
||||||
vec!["alpha".to_string(), "beta".to_string()]
|
|
||||||
);
|
|
||||||
assert_eq!(report.failed_servers.len(), 1);
|
|
||||||
assert_eq!(report.failed_servers[0].server_name, "broken");
|
|
||||||
assert_eq!(
|
|
||||||
report.available_tools,
|
|
||||||
vec!["alpha.echo".to_string(), "beta.search".to_string()]
|
|
||||||
);
|
|
||||||
assert_eq!(report.missing_tools, vec!["broken.fetch".to_string()]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn given_failure_during_resource_discovery_when_shutting_down_then_cleanup_still_succeeds() {
|
|
||||||
// given
|
|
||||||
let mut validator = McpLifecycleValidator::new();
|
|
||||||
for phase in [
|
|
||||||
McpLifecyclePhase::ConfigLoad,
|
|
||||||
McpLifecyclePhase::ServerRegistration,
|
|
||||||
McpLifecyclePhase::SpawnConnect,
|
|
||||||
McpLifecyclePhase::InitializeHandshake,
|
|
||||||
McpLifecyclePhase::ToolDiscovery,
|
|
||||||
] {
|
|
||||||
let result = validator.run_phase(phase);
|
|
||||||
assert!(matches!(result, McpPhaseResult::Success { .. }));
|
|
||||||
}
|
|
||||||
let _ = validator.record_failure(
|
|
||||||
McpLifecyclePhase::ResourceDiscovery,
|
|
||||||
McpErrorSurface::new(
|
|
||||||
McpLifecyclePhase::ResourceDiscovery,
|
|
||||||
Some("alpha".to_string()),
|
|
||||||
"resource listing failed",
|
|
||||||
BTreeMap::from([("reason".to_string(), "timeout".to_string())]),
|
|
||||||
true,
|
|
||||||
),
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let shutdown = validator.run_phase(McpLifecyclePhase::Shutdown);
|
|
||||||
let cleanup = validator.run_phase(McpLifecyclePhase::Cleanup);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(matches!(shutdown, McpPhaseResult::Success { .. }));
|
|
||||||
assert!(matches!(cleanup, McpPhaseResult::Success { .. }));
|
|
||||||
assert_eq!(
|
|
||||||
validator.state().current_phase(),
|
|
||||||
Some(McpLifecyclePhase::Cleanup)
|
|
||||||
);
|
|
||||||
assert!(validator
|
|
||||||
.state()
|
|
||||||
.phase_timestamp(McpLifecyclePhase::ErrorSurfacing)
|
|
||||||
.is_some());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn error_surface_display_includes_phase_server_and_recoverable_flag() {
|
|
||||||
// given
|
|
||||||
let error = McpErrorSurface::new(
|
|
||||||
McpLifecyclePhase::SpawnConnect,
|
|
||||||
Some("alpha".to_string()),
|
|
||||||
"process exited early",
|
|
||||||
BTreeMap::from([("exit_code".to_string(), "1".to_string())]),
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let rendered = error.to_string();
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(rendered.contains("spawn_connect"));
|
|
||||||
assert!(rendered.contains("process exited early"));
|
|
||||||
assert!(rendered.contains("server: alpha"));
|
|
||||||
assert!(rendered.contains("recoverable"));
|
|
||||||
let trait_object: &dyn std::error::Error = &error;
|
|
||||||
assert_eq!(trait_object.to_string(), rendered);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -230,19 +230,6 @@ pub struct UnsupportedMcpServer {
|
||||||
pub reason: String,
|
pub reason: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct McpDiscoveryFailure {
|
|
||||||
pub server_name: String,
|
|
||||||
pub error: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
pub struct McpToolDiscoveryReport {
|
|
||||||
pub tools: Vec<ManagedMcpTool>,
|
|
||||||
pub failed_servers: Vec<McpDiscoveryFailure>,
|
|
||||||
pub unsupported_servers: Vec<UnsupportedMcpServer>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum McpServerManagerError {
|
pub enum McpServerManagerError {
|
||||||
Io(io::Error),
|
Io(io::Error),
|
||||||
|
|
@ -410,11 +397,6 @@ impl McpServerManager {
|
||||||
&self.unsupported_servers
|
&self.unsupported_servers
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn server_names(&self) -> Vec<String> {
|
|
||||||
self.servers.keys().cloned().collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn discover_tools(&mut self) -> Result<Vec<ManagedMcpTool>, McpServerManagerError> {
|
pub async fn discover_tools(&mut self) -> Result<Vec<ManagedMcpTool>, McpServerManagerError> {
|
||||||
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
|
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
|
||||||
let mut discovered_tools = Vec::new();
|
let mut discovered_tools = Vec::new();
|
||||||
|
|
@ -438,43 +420,6 @@ impl McpServerManager {
|
||||||
Ok(discovered_tools)
|
Ok(discovered_tools)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn discover_tools_best_effort(&mut self) -> McpToolDiscoveryReport {
|
|
||||||
let server_names = self.server_names();
|
|
||||||
let mut discovered_tools = Vec::new();
|
|
||||||
let mut failed_servers = Vec::new();
|
|
||||||
|
|
||||||
for server_name in server_names {
|
|
||||||
match self.discover_tools_for_server(&server_name).await {
|
|
||||||
Ok(server_tools) => {
|
|
||||||
self.clear_routes_for_server(&server_name);
|
|
||||||
for tool in server_tools {
|
|
||||||
self.tool_index.insert(
|
|
||||||
tool.qualified_name.clone(),
|
|
||||||
ToolRoute {
|
|
||||||
server_name: tool.server_name.clone(),
|
|
||||||
raw_name: tool.raw_name.clone(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
discovered_tools.push(tool);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(error) => {
|
|
||||||
self.clear_routes_for_server(&server_name);
|
|
||||||
failed_servers.push(McpDiscoveryFailure {
|
|
||||||
server_name,
|
|
||||||
error: error.to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
McpToolDiscoveryReport {
|
|
||||||
tools: discovered_tools,
|
|
||||||
failed_servers,
|
|
||||||
unsupported_servers: self.unsupported_servers.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn call_tool(
|
pub async fn call_tool(
|
||||||
&mut self,
|
&mut self,
|
||||||
qualified_tool_name: &str,
|
qualified_tool_name: &str,
|
||||||
|
|
@ -527,53 +472,6 @@ impl McpServerManager {
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_resources(
|
|
||||||
&mut self,
|
|
||||||
server_name: &str,
|
|
||||||
) -> Result<McpListResourcesResult, McpServerManagerError> {
|
|
||||||
let mut attempts = 0;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
match self.list_resources_once(server_name).await {
|
|
||||||
Ok(resources) => return Ok(resources),
|
|
||||||
Err(error) if attempts == 0 && Self::is_retryable_error(&error) => {
|
|
||||||
self.reset_server(server_name).await?;
|
|
||||||
attempts += 1;
|
|
||||||
}
|
|
||||||
Err(error) => {
|
|
||||||
if Self::should_reset_server(&error) {
|
|
||||||
self.reset_server(server_name).await?;
|
|
||||||
}
|
|
||||||
return Err(error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn read_resource(
|
|
||||||
&mut self,
|
|
||||||
server_name: &str,
|
|
||||||
uri: &str,
|
|
||||||
) -> Result<McpReadResourceResult, McpServerManagerError> {
|
|
||||||
let mut attempts = 0;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
match self.read_resource_once(server_name, uri).await {
|
|
||||||
Ok(resource) => return Ok(resource),
|
|
||||||
Err(error) if attempts == 0 && Self::is_retryable_error(&error) => {
|
|
||||||
self.reset_server(server_name).await?;
|
|
||||||
attempts += 1;
|
|
||||||
}
|
|
||||||
Err(error) => {
|
|
||||||
if Self::should_reset_server(&error) {
|
|
||||||
self.reset_server(server_name).await?;
|
|
||||||
}
|
|
||||||
return Err(error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn shutdown(&mut self) -> Result<(), McpServerManagerError> {
|
pub async fn shutdown(&mut self) -> Result<(), McpServerManagerError> {
|
||||||
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
|
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
|
||||||
for server_name in server_names {
|
for server_name in server_names {
|
||||||
|
|
@ -725,118 +623,6 @@ impl McpServerManager {
|
||||||
Ok(discovered_tools)
|
Ok(discovered_tools)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list_resources_once(
|
|
||||||
&mut self,
|
|
||||||
server_name: &str,
|
|
||||||
) -> Result<McpListResourcesResult, McpServerManagerError> {
|
|
||||||
self.ensure_server_ready(server_name).await?;
|
|
||||||
|
|
||||||
let mut resources = Vec::new();
|
|
||||||
let mut cursor = None;
|
|
||||||
loop {
|
|
||||||
let request_id = self.take_request_id();
|
|
||||||
let response = {
|
|
||||||
let server = self.server_mut(server_name)?;
|
|
||||||
let process = server.process.as_mut().ok_or_else(|| {
|
|
||||||
McpServerManagerError::InvalidResponse {
|
|
||||||
server_name: server_name.to_string(),
|
|
||||||
method: "resources/list",
|
|
||||||
details: "server process missing after initialization".to_string(),
|
|
||||||
}
|
|
||||||
})?;
|
|
||||||
Self::run_process_request(
|
|
||||||
server_name,
|
|
||||||
"resources/list",
|
|
||||||
MCP_LIST_TOOLS_TIMEOUT_MS,
|
|
||||||
process.list_resources(
|
|
||||||
request_id,
|
|
||||||
Some(McpListResourcesParams {
|
|
||||||
cursor: cursor.clone(),
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(error) = response.error {
|
|
||||||
return Err(McpServerManagerError::JsonRpc {
|
|
||||||
server_name: server_name.to_string(),
|
|
||||||
method: "resources/list",
|
|
||||||
error,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = response
|
|
||||||
.result
|
|
||||||
.ok_or_else(|| McpServerManagerError::InvalidResponse {
|
|
||||||
server_name: server_name.to_string(),
|
|
||||||
method: "resources/list",
|
|
||||||
details: "missing result payload".to_string(),
|
|
||||||
})?;
|
|
||||||
|
|
||||||
resources.extend(result.resources);
|
|
||||||
|
|
||||||
match result.next_cursor {
|
|
||||||
Some(next_cursor) => cursor = Some(next_cursor),
|
|
||||||
None => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(McpListResourcesResult {
|
|
||||||
resources,
|
|
||||||
next_cursor: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn read_resource_once(
|
|
||||||
&mut self,
|
|
||||||
server_name: &str,
|
|
||||||
uri: &str,
|
|
||||||
) -> Result<McpReadResourceResult, McpServerManagerError> {
|
|
||||||
self.ensure_server_ready(server_name).await?;
|
|
||||||
|
|
||||||
let request_id = self.take_request_id();
|
|
||||||
let response =
|
|
||||||
{
|
|
||||||
let server = self.server_mut(server_name)?;
|
|
||||||
let process = server.process.as_mut().ok_or_else(|| {
|
|
||||||
McpServerManagerError::InvalidResponse {
|
|
||||||
server_name: server_name.to_string(),
|
|
||||||
method: "resources/read",
|
|
||||||
details: "server process missing after initialization".to_string(),
|
|
||||||
}
|
|
||||||
})?;
|
|
||||||
Self::run_process_request(
|
|
||||||
server_name,
|
|
||||||
"resources/read",
|
|
||||||
MCP_LIST_TOOLS_TIMEOUT_MS,
|
|
||||||
process.read_resource(
|
|
||||||
request_id,
|
|
||||||
McpReadResourceParams {
|
|
||||||
uri: uri.to_string(),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(error) = response.error {
|
|
||||||
return Err(McpServerManagerError::JsonRpc {
|
|
||||||
server_name: server_name.to_string(),
|
|
||||||
method: "resources/read",
|
|
||||||
error,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
response
|
|
||||||
.result
|
|
||||||
.ok_or_else(|| McpServerManagerError::InvalidResponse {
|
|
||||||
server_name: server_name.to_string(),
|
|
||||||
method: "resources/read",
|
|
||||||
details: "missing result payload".to_string(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn reset_server(&mut self, server_name: &str) -> Result<(), McpServerManagerError> {
|
async fn reset_server(&mut self, server_name: &str) -> Result<(), McpServerManagerError> {
|
||||||
let mut process = {
|
let mut process = {
|
||||||
let server = self.server_mut(server_name)?;
|
let server = self.server_mut(server_name)?;
|
||||||
|
|
@ -2467,103 +2253,6 @@ mod tests {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn manager_lists_and_reads_resources_from_stdio_servers() {
|
|
||||||
let runtime = Builder::new_current_thread()
|
|
||||||
.enable_all()
|
|
||||||
.build()
|
|
||||||
.expect("runtime");
|
|
||||||
runtime.block_on(async {
|
|
||||||
let script_path = write_mcp_server_script();
|
|
||||||
let root = script_path.parent().expect("script parent");
|
|
||||||
let log_path = root.join("resources.log");
|
|
||||||
let servers = BTreeMap::from([(
|
|
||||||
"alpha".to_string(),
|
|
||||||
manager_server_config(&script_path, "alpha", &log_path),
|
|
||||||
)]);
|
|
||||||
let mut manager = McpServerManager::from_servers(&servers);
|
|
||||||
|
|
||||||
let listed = manager
|
|
||||||
.list_resources("alpha")
|
|
||||||
.await
|
|
||||||
.expect("list resources");
|
|
||||||
assert_eq!(listed.resources.len(), 1);
|
|
||||||
assert_eq!(listed.resources[0].uri, "file://guide.txt");
|
|
||||||
|
|
||||||
let read = manager
|
|
||||||
.read_resource("alpha", "file://guide.txt")
|
|
||||||
.await
|
|
||||||
.expect("read resource");
|
|
||||||
assert_eq!(read.contents.len(), 1);
|
|
||||||
assert_eq!(
|
|
||||||
read.contents[0].text.as_deref(),
|
|
||||||
Some("contents for file://guide.txt")
|
|
||||||
);
|
|
||||||
|
|
||||||
manager.shutdown().await.expect("shutdown");
|
|
||||||
cleanup_script(&script_path);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn manager_discovery_report_keeps_healthy_servers_when_one_server_fails() {
|
|
||||||
let runtime = Builder::new_current_thread()
|
|
||||||
.enable_all()
|
|
||||||
.build()
|
|
||||||
.expect("runtime");
|
|
||||||
runtime.block_on(async {
|
|
||||||
let script_path = write_manager_mcp_server_script();
|
|
||||||
let root = script_path.parent().expect("script parent");
|
|
||||||
let alpha_log = root.join("alpha.log");
|
|
||||||
let servers = BTreeMap::from([
|
|
||||||
(
|
|
||||||
"alpha".to_string(),
|
|
||||||
manager_server_config(&script_path, "alpha", &alpha_log),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"broken".to_string(),
|
|
||||||
ScopedMcpServerConfig {
|
|
||||||
scope: ConfigSource::Local,
|
|
||||||
config: McpServerConfig::Stdio(McpStdioServerConfig {
|
|
||||||
command: "python3".to_string(),
|
|
||||||
args: vec!["-c".to_string(), "import sys; sys.exit(0)".to_string()],
|
|
||||||
env: BTreeMap::new(),
|
|
||||||
tool_call_timeout_ms: None,
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]);
|
|
||||||
let mut manager = McpServerManager::from_servers(&servers);
|
|
||||||
|
|
||||||
let report = manager.discover_tools_best_effort().await;
|
|
||||||
|
|
||||||
assert_eq!(report.tools.len(), 1);
|
|
||||||
assert_eq!(
|
|
||||||
report.tools[0].qualified_name,
|
|
||||||
mcp_tool_name("alpha", "echo")
|
|
||||||
);
|
|
||||||
assert_eq!(report.failed_servers.len(), 1);
|
|
||||||
assert_eq!(report.failed_servers[0].server_name, "broken");
|
|
||||||
assert!(report.failed_servers[0].error.contains("initialize"));
|
|
||||||
|
|
||||||
let response = manager
|
|
||||||
.call_tool(&mcp_tool_name("alpha", "echo"), Some(json!({"text": "ok"})))
|
|
||||||
.await
|
|
||||||
.expect("healthy server should remain callable");
|
|
||||||
assert_eq!(
|
|
||||||
response
|
|
||||||
.result
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|result| result.structured_content.as_ref())
|
|
||||||
.and_then(|value| value.get("echoed")),
|
|
||||||
Some(&json!("ok"))
|
|
||||||
);
|
|
||||||
|
|
||||||
manager.shutdown().await.expect("shutdown");
|
|
||||||
cleanup_script(&script_path);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn manager_records_unsupported_non_stdio_servers_without_panicking() {
|
fn manager_records_unsupported_non_stdio_servers_without_panicking() {
|
||||||
let servers = BTreeMap::from([
|
let servers = BTreeMap::from([
|
||||||
|
|
|
||||||
|
|
@ -184,10 +184,7 @@ impl McpToolRegistry {
|
||||||
let mut manager = manager
|
let mut manager = manager
|
||||||
.lock()
|
.lock()
|
||||||
.map_err(|_| "mcp server manager lock poisoned".to_string())?;
|
.map_err(|_| "mcp server manager lock poisoned".to_string())?;
|
||||||
manager
|
manager.discover_tools().await.map_err(|error| error.to_string())?;
|
||||||
.discover_tools()
|
|
||||||
.await
|
|
||||||
.map_err(|error| error.to_string())?;
|
|
||||||
let response = manager
|
let response = manager
|
||||||
.call_tool(&qualified_tool_name, arguments)
|
.call_tool(&qualified_tool_name, arguments)
|
||||||
.await
|
.await
|
||||||
|
|
@ -830,9 +827,7 @@ mod tests {
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
registry
|
registry
|
||||||
.set_manager(Arc::new(Mutex::new(McpServerManager::from_servers(
|
.set_manager(Arc::new(Mutex::new(McpServerManager::from_servers(&servers))))
|
||||||
&servers,
|
|
||||||
))))
|
|
||||||
.expect("manager should only be set once");
|
.expect("manager should only be set once");
|
||||||
|
|
||||||
let result = registry
|
let result = registry
|
||||||
|
|
|
||||||
|
|
@ -1,532 +0,0 @@
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use crate::config::RuntimePluginConfig;
|
|
||||||
use crate::mcp_tool_bridge::{McpResourceInfo, McpToolInfo};
|
|
||||||
|
|
||||||
fn now_secs() -> u64 {
|
|
||||||
SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.unwrap_or_default()
|
|
||||||
.as_secs()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type ToolInfo = McpToolInfo;
|
|
||||||
pub type ResourceInfo = McpResourceInfo;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum ServerStatus {
|
|
||||||
Healthy,
|
|
||||||
Degraded,
|
|
||||||
Failed,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for ServerStatus {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::Healthy => write!(f, "healthy"),
|
|
||||||
Self::Degraded => write!(f, "degraded"),
|
|
||||||
Self::Failed => write!(f, "failed"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub struct ServerHealth {
|
|
||||||
pub server_name: String,
|
|
||||||
pub status: ServerStatus,
|
|
||||||
pub capabilities: Vec<String>,
|
|
||||||
pub last_error: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case", tag = "state")]
|
|
||||||
pub enum PluginState {
|
|
||||||
Unconfigured,
|
|
||||||
Validated,
|
|
||||||
Starting,
|
|
||||||
Healthy,
|
|
||||||
Degraded {
|
|
||||||
healthy_servers: Vec<String>,
|
|
||||||
failed_servers: Vec<ServerHealth>,
|
|
||||||
},
|
|
||||||
Failed {
|
|
||||||
reason: String,
|
|
||||||
},
|
|
||||||
ShuttingDown,
|
|
||||||
Stopped,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PluginState {
|
|
||||||
#[must_use]
|
|
||||||
pub fn from_servers(servers: &[ServerHealth]) -> Self {
|
|
||||||
if servers.is_empty() {
|
|
||||||
return Self::Failed {
|
|
||||||
reason: "no servers available".to_string(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let healthy_servers = servers
|
|
||||||
.iter()
|
|
||||||
.filter(|server| server.status != ServerStatus::Failed)
|
|
||||||
.map(|server| server.server_name.clone())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let failed_servers = servers
|
|
||||||
.iter()
|
|
||||||
.filter(|server| server.status == ServerStatus::Failed)
|
|
||||||
.cloned()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let has_degraded_server = servers
|
|
||||||
.iter()
|
|
||||||
.any(|server| server.status == ServerStatus::Degraded);
|
|
||||||
|
|
||||||
if failed_servers.is_empty() && !has_degraded_server {
|
|
||||||
Self::Healthy
|
|
||||||
} else if healthy_servers.is_empty() {
|
|
||||||
Self::Failed {
|
|
||||||
reason: format!("all {} servers failed", failed_servers.len()),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Self::Degraded {
|
|
||||||
healthy_servers,
|
|
||||||
failed_servers,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for PluginState {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::Unconfigured => write!(f, "unconfigured"),
|
|
||||||
Self::Validated => write!(f, "validated"),
|
|
||||||
Self::Starting => write!(f, "starting"),
|
|
||||||
Self::Healthy => write!(f, "healthy"),
|
|
||||||
Self::Degraded { .. } => write!(f, "degraded"),
|
|
||||||
Self::Failed { .. } => write!(f, "failed"),
|
|
||||||
Self::ShuttingDown => write!(f, "shutting_down"),
|
|
||||||
Self::Stopped => write!(f, "stopped"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub struct PluginHealthcheck {
|
|
||||||
pub plugin_name: String,
|
|
||||||
pub state: PluginState,
|
|
||||||
pub servers: Vec<ServerHealth>,
|
|
||||||
pub last_check: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PluginHealthcheck {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(plugin_name: impl Into<String>, servers: Vec<ServerHealth>) -> Self {
|
|
||||||
let state = PluginState::from_servers(&servers);
|
|
||||||
Self {
|
|
||||||
plugin_name: plugin_name.into(),
|
|
||||||
state,
|
|
||||||
servers,
|
|
||||||
last_check: now_secs(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn degraded_mode(&self, discovery: &DiscoveryResult) -> Option<DegradedMode> {
|
|
||||||
match &self.state {
|
|
||||||
PluginState::Degraded {
|
|
||||||
healthy_servers,
|
|
||||||
failed_servers,
|
|
||||||
} => Some(DegradedMode {
|
|
||||||
available_tools: discovery
|
|
||||||
.tools
|
|
||||||
.iter()
|
|
||||||
.map(|tool| tool.name.clone())
|
|
||||||
.collect(),
|
|
||||||
unavailable_tools: failed_servers
|
|
||||||
.iter()
|
|
||||||
.flat_map(|server| server.capabilities.iter().cloned())
|
|
||||||
.collect(),
|
|
||||||
reason: format!(
|
|
||||||
"{} servers healthy, {} servers failed",
|
|
||||||
healthy_servers.len(),
|
|
||||||
failed_servers.len()
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct DiscoveryResult {
|
|
||||||
pub tools: Vec<ToolInfo>,
|
|
||||||
pub resources: Vec<ResourceInfo>,
|
|
||||||
pub partial: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub struct DegradedMode {
|
|
||||||
pub available_tools: Vec<String>,
|
|
||||||
pub unavailable_tools: Vec<String>,
|
|
||||||
pub reason: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DegradedMode {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(
|
|
||||||
available_tools: Vec<String>,
|
|
||||||
unavailable_tools: Vec<String>,
|
|
||||||
reason: impl Into<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
available_tools,
|
|
||||||
unavailable_tools,
|
|
||||||
reason: reason.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum PluginLifecycleEvent {
|
|
||||||
ConfigValidated,
|
|
||||||
StartupHealthy,
|
|
||||||
StartupDegraded,
|
|
||||||
StartupFailed,
|
|
||||||
Shutdown,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for PluginLifecycleEvent {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::ConfigValidated => write!(f, "config_validated"),
|
|
||||||
Self::StartupHealthy => write!(f, "startup_healthy"),
|
|
||||||
Self::StartupDegraded => write!(f, "startup_degraded"),
|
|
||||||
Self::StartupFailed => write!(f, "startup_failed"),
|
|
||||||
Self::Shutdown => write!(f, "shutdown"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait PluginLifecycle {
|
|
||||||
fn validate_config(&self, config: &RuntimePluginConfig) -> Result<(), String>;
|
|
||||||
fn healthcheck(&self) -> PluginHealthcheck;
|
|
||||||
fn discover(&self) -> DiscoveryResult;
|
|
||||||
fn shutdown(&mut self) -> Result<(), String>;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct MockPluginLifecycle {
|
|
||||||
plugin_name: String,
|
|
||||||
valid_config: bool,
|
|
||||||
healthcheck: PluginHealthcheck,
|
|
||||||
discovery: DiscoveryResult,
|
|
||||||
shutdown_error: Option<String>,
|
|
||||||
shutdown_called: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MockPluginLifecycle {
|
|
||||||
fn new(
|
|
||||||
plugin_name: &str,
|
|
||||||
valid_config: bool,
|
|
||||||
servers: Vec<ServerHealth>,
|
|
||||||
discovery: DiscoveryResult,
|
|
||||||
shutdown_error: Option<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
plugin_name: plugin_name.to_string(),
|
|
||||||
valid_config,
|
|
||||||
healthcheck: PluginHealthcheck::new(plugin_name, servers),
|
|
||||||
discovery,
|
|
||||||
shutdown_error,
|
|
||||||
shutdown_called: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PluginLifecycle for MockPluginLifecycle {
|
|
||||||
fn validate_config(&self, _config: &RuntimePluginConfig) -> Result<(), String> {
|
|
||||||
if self.valid_config {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(format!(
|
|
||||||
"plugin `{}` failed configuration validation",
|
|
||||||
self.plugin_name
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn healthcheck(&self) -> PluginHealthcheck {
|
|
||||||
if self.shutdown_called {
|
|
||||||
PluginHealthcheck {
|
|
||||||
plugin_name: self.plugin_name.clone(),
|
|
||||||
state: PluginState::Stopped,
|
|
||||||
servers: self.healthcheck.servers.clone(),
|
|
||||||
last_check: now_secs(),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.healthcheck.clone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn discover(&self) -> DiscoveryResult {
|
|
||||||
self.discovery.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn shutdown(&mut self) -> Result<(), String> {
|
|
||||||
if let Some(error) = &self.shutdown_error {
|
|
||||||
return Err(error.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
self.shutdown_called = true;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn healthy_server(name: &str, capabilities: &[&str]) -> ServerHealth {
|
|
||||||
ServerHealth {
|
|
||||||
server_name: name.to_string(),
|
|
||||||
status: ServerStatus::Healthy,
|
|
||||||
capabilities: capabilities
|
|
||||||
.iter()
|
|
||||||
.map(|capability| capability.to_string())
|
|
||||||
.collect(),
|
|
||||||
last_error: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn failed_server(name: &str, capabilities: &[&str], error: &str) -> ServerHealth {
|
|
||||||
ServerHealth {
|
|
||||||
server_name: name.to_string(),
|
|
||||||
status: ServerStatus::Failed,
|
|
||||||
capabilities: capabilities
|
|
||||||
.iter()
|
|
||||||
.map(|capability| capability.to_string())
|
|
||||||
.collect(),
|
|
||||||
last_error: Some(error.to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn degraded_server(name: &str, capabilities: &[&str], error: &str) -> ServerHealth {
|
|
||||||
ServerHealth {
|
|
||||||
server_name: name.to_string(),
|
|
||||||
status: ServerStatus::Degraded,
|
|
||||||
capabilities: capabilities
|
|
||||||
.iter()
|
|
||||||
.map(|capability| capability.to_string())
|
|
||||||
.collect(),
|
|
||||||
last_error: Some(error.to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn tool(name: &str) -> ToolInfo {
|
|
||||||
ToolInfo {
|
|
||||||
name: name.to_string(),
|
|
||||||
description: Some(format!("{name} tool")),
|
|
||||||
input_schema: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resource(name: &str, uri: &str) -> ResourceInfo {
|
|
||||||
ResourceInfo {
|
|
||||||
uri: uri.to_string(),
|
|
||||||
name: name.to_string(),
|
|
||||||
description: Some(format!("{name} resource")),
|
|
||||||
mime_type: Some("application/json".to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn full_lifecycle_happy_path() {
|
|
||||||
// given
|
|
||||||
let mut lifecycle = MockPluginLifecycle::new(
|
|
||||||
"healthy-plugin",
|
|
||||||
true,
|
|
||||||
vec![
|
|
||||||
healthy_server("alpha", &["search", "read"]),
|
|
||||||
healthy_server("beta", &["write"]),
|
|
||||||
],
|
|
||||||
DiscoveryResult {
|
|
||||||
tools: vec![tool("search"), tool("read"), tool("write")],
|
|
||||||
resources: vec![resource("docs", "file:///docs")],
|
|
||||||
partial: false,
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
let config = RuntimePluginConfig::default();
|
|
||||||
|
|
||||||
// when
|
|
||||||
let validation = lifecycle.validate_config(&config);
|
|
||||||
let healthcheck = lifecycle.healthcheck();
|
|
||||||
let discovery = lifecycle.discover();
|
|
||||||
let shutdown = lifecycle.shutdown();
|
|
||||||
let post_shutdown = lifecycle.healthcheck();
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(validation, Ok(()));
|
|
||||||
assert_eq!(healthcheck.state, PluginState::Healthy);
|
|
||||||
assert_eq!(healthcheck.plugin_name, "healthy-plugin");
|
|
||||||
assert_eq!(discovery.tools.len(), 3);
|
|
||||||
assert_eq!(discovery.resources.len(), 1);
|
|
||||||
assert!(!discovery.partial);
|
|
||||||
assert_eq!(shutdown, Ok(()));
|
|
||||||
assert_eq!(post_shutdown.state, PluginState::Stopped);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn degraded_startup_when_one_of_three_servers_fails() {
|
|
||||||
// given
|
|
||||||
let lifecycle = MockPluginLifecycle::new(
|
|
||||||
"degraded-plugin",
|
|
||||||
true,
|
|
||||||
vec![
|
|
||||||
healthy_server("alpha", &["search"]),
|
|
||||||
failed_server("beta", &["write"], "connection refused"),
|
|
||||||
healthy_server("gamma", &["read"]),
|
|
||||||
],
|
|
||||||
DiscoveryResult {
|
|
||||||
tools: vec![tool("search"), tool("read")],
|
|
||||||
resources: vec![resource("alpha-docs", "file:///alpha")],
|
|
||||||
partial: true,
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let healthcheck = lifecycle.healthcheck();
|
|
||||||
let discovery = lifecycle.discover();
|
|
||||||
let degraded_mode = healthcheck
|
|
||||||
.degraded_mode(&discovery)
|
|
||||||
.expect("degraded startup should expose degraded mode");
|
|
||||||
|
|
||||||
// then
|
|
||||||
match healthcheck.state {
|
|
||||||
PluginState::Degraded {
|
|
||||||
healthy_servers,
|
|
||||||
failed_servers,
|
|
||||||
} => {
|
|
||||||
assert_eq!(
|
|
||||||
healthy_servers,
|
|
||||||
vec!["alpha".to_string(), "gamma".to_string()]
|
|
||||||
);
|
|
||||||
assert_eq!(failed_servers.len(), 1);
|
|
||||||
assert_eq!(failed_servers[0].server_name, "beta");
|
|
||||||
assert_eq!(
|
|
||||||
failed_servers[0].last_error.as_deref(),
|
|
||||||
Some("connection refused")
|
|
||||||
);
|
|
||||||
}
|
|
||||||
other => panic!("expected degraded state, got {other:?}"),
|
|
||||||
}
|
|
||||||
assert!(discovery.partial);
|
|
||||||
assert_eq!(
|
|
||||||
degraded_mode.available_tools,
|
|
||||||
vec!["search".to_string(), "read".to_string()]
|
|
||||||
);
|
|
||||||
assert_eq!(degraded_mode.unavailable_tools, vec!["write".to_string()]);
|
|
||||||
assert_eq!(degraded_mode.reason, "2 servers healthy, 1 servers failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn degraded_server_status_keeps_server_usable() {
|
|
||||||
// given
|
|
||||||
let lifecycle = MockPluginLifecycle::new(
|
|
||||||
"soft-degraded-plugin",
|
|
||||||
true,
|
|
||||||
vec![
|
|
||||||
healthy_server("alpha", &["search"]),
|
|
||||||
degraded_server("beta", &["write"], "high latency"),
|
|
||||||
],
|
|
||||||
DiscoveryResult {
|
|
||||||
tools: vec![tool("search"), tool("write")],
|
|
||||||
resources: Vec::new(),
|
|
||||||
partial: true,
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let healthcheck = lifecycle.healthcheck();
|
|
||||||
|
|
||||||
// then
|
|
||||||
match healthcheck.state {
|
|
||||||
PluginState::Degraded {
|
|
||||||
healthy_servers,
|
|
||||||
failed_servers,
|
|
||||||
} => {
|
|
||||||
assert_eq!(
|
|
||||||
healthy_servers,
|
|
||||||
vec!["alpha".to_string(), "beta".to_string()]
|
|
||||||
);
|
|
||||||
assert!(failed_servers.is_empty());
|
|
||||||
}
|
|
||||||
other => panic!("expected degraded state, got {other:?}"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn complete_failure_when_all_servers_fail() {
|
|
||||||
// given
|
|
||||||
let lifecycle = MockPluginLifecycle::new(
|
|
||||||
"failed-plugin",
|
|
||||||
true,
|
|
||||||
vec![
|
|
||||||
failed_server("alpha", &["search"], "timeout"),
|
|
||||||
failed_server("beta", &["read"], "handshake failed"),
|
|
||||||
],
|
|
||||||
DiscoveryResult {
|
|
||||||
tools: Vec::new(),
|
|
||||||
resources: Vec::new(),
|
|
||||||
partial: false,
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let healthcheck = lifecycle.healthcheck();
|
|
||||||
let discovery = lifecycle.discover();
|
|
||||||
|
|
||||||
// then
|
|
||||||
match &healthcheck.state {
|
|
||||||
PluginState::Failed { reason } => {
|
|
||||||
assert_eq!(reason, "all 2 servers failed");
|
|
||||||
}
|
|
||||||
other => panic!("expected failed state, got {other:?}"),
|
|
||||||
}
|
|
||||||
assert!(!discovery.partial);
|
|
||||||
assert!(discovery.tools.is_empty());
|
|
||||||
assert!(discovery.resources.is_empty());
|
|
||||||
assert!(healthcheck.degraded_mode(&discovery).is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn graceful_shutdown() {
|
|
||||||
// given
|
|
||||||
let mut lifecycle = MockPluginLifecycle::new(
|
|
||||||
"shutdown-plugin",
|
|
||||||
true,
|
|
||||||
vec![healthy_server("alpha", &["search"])],
|
|
||||||
DiscoveryResult {
|
|
||||||
tools: vec![tool("search")],
|
|
||||||
resources: Vec::new(),
|
|
||||||
partial: false,
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let shutdown = lifecycle.shutdown();
|
|
||||||
let post_shutdown = lifecycle.healthcheck();
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(shutdown, Ok(()));
|
|
||||||
assert_eq!(PluginLifecycleEvent::Shutdown.to_string(), "shutdown");
|
|
||||||
assert_eq!(post_shutdown.state, PluginState::Stopped);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,458 +0,0 @@
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
pub type GreenLevel = u8;
|
|
||||||
|
|
||||||
const STALE_BRANCH_THRESHOLD: Duration = Duration::from_secs(60 * 60);
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct PolicyRule {
|
|
||||||
pub name: String,
|
|
||||||
pub condition: PolicyCondition,
|
|
||||||
pub action: PolicyAction,
|
|
||||||
pub priority: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PolicyRule {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(
|
|
||||||
name: impl Into<String>,
|
|
||||||
condition: PolicyCondition,
|
|
||||||
action: PolicyAction,
|
|
||||||
priority: u32,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
name: name.into(),
|
|
||||||
condition,
|
|
||||||
action,
|
|
||||||
priority,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn matches(&self, context: &LaneContext) -> bool {
|
|
||||||
self.condition.matches(context)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum PolicyCondition {
|
|
||||||
And(Vec<PolicyCondition>),
|
|
||||||
Or(Vec<PolicyCondition>),
|
|
||||||
GreenAt { level: GreenLevel },
|
|
||||||
StaleBranch,
|
|
||||||
StartupBlocked,
|
|
||||||
LaneCompleted,
|
|
||||||
ReviewPassed,
|
|
||||||
ScopedDiff,
|
|
||||||
TimedOut { duration: Duration },
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PolicyCondition {
|
|
||||||
#[must_use]
|
|
||||||
pub fn matches(&self, context: &LaneContext) -> bool {
|
|
||||||
match self {
|
|
||||||
Self::And(conditions) => conditions
|
|
||||||
.iter()
|
|
||||||
.all(|condition| condition.matches(context)),
|
|
||||||
Self::Or(conditions) => conditions
|
|
||||||
.iter()
|
|
||||||
.any(|condition| condition.matches(context)),
|
|
||||||
Self::GreenAt { level } => context.green_level >= *level,
|
|
||||||
Self::StaleBranch => context.branch_freshness >= STALE_BRANCH_THRESHOLD,
|
|
||||||
Self::StartupBlocked => context.blocker == LaneBlocker::Startup,
|
|
||||||
Self::LaneCompleted => context.completed,
|
|
||||||
Self::ReviewPassed => context.review_status == ReviewStatus::Approved,
|
|
||||||
Self::ScopedDiff => context.diff_scope == DiffScope::Scoped,
|
|
||||||
Self::TimedOut { duration } => context.branch_freshness >= *duration,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum PolicyAction {
|
|
||||||
MergeToDev,
|
|
||||||
MergeForward,
|
|
||||||
RecoverOnce,
|
|
||||||
Escalate { reason: String },
|
|
||||||
CloseoutLane,
|
|
||||||
CleanupSession,
|
|
||||||
Notify { channel: String },
|
|
||||||
Block { reason: String },
|
|
||||||
Chain(Vec<PolicyAction>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PolicyAction {
|
|
||||||
fn flatten_into(&self, actions: &mut Vec<PolicyAction>) {
|
|
||||||
match self {
|
|
||||||
Self::Chain(chained) => {
|
|
||||||
for action in chained {
|
|
||||||
action.flatten_into(actions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => actions.push(self.clone()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub enum LaneBlocker {
|
|
||||||
None,
|
|
||||||
Startup,
|
|
||||||
External,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub enum ReviewStatus {
|
|
||||||
Pending,
|
|
||||||
Approved,
|
|
||||||
Rejected,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub enum DiffScope {
|
|
||||||
Full,
|
|
||||||
Scoped,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct LaneContext {
|
|
||||||
pub lane_id: String,
|
|
||||||
pub green_level: GreenLevel,
|
|
||||||
pub branch_freshness: Duration,
|
|
||||||
pub blocker: LaneBlocker,
|
|
||||||
pub review_status: ReviewStatus,
|
|
||||||
pub diff_scope: DiffScope,
|
|
||||||
pub completed: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LaneContext {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(
|
|
||||||
lane_id: impl Into<String>,
|
|
||||||
green_level: GreenLevel,
|
|
||||||
branch_freshness: Duration,
|
|
||||||
blocker: LaneBlocker,
|
|
||||||
review_status: ReviewStatus,
|
|
||||||
diff_scope: DiffScope,
|
|
||||||
completed: bool,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
lane_id: lane_id.into(),
|
|
||||||
green_level,
|
|
||||||
branch_freshness,
|
|
||||||
blocker,
|
|
||||||
review_status,
|
|
||||||
diff_scope,
|
|
||||||
completed,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct PolicyEngine {
|
|
||||||
rules: Vec<PolicyRule>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PolicyEngine {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(mut rules: Vec<PolicyRule>) -> Self {
|
|
||||||
rules.sort_by_key(|rule| rule.priority);
|
|
||||||
Self { rules }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn rules(&self) -> &[PolicyRule] {
|
|
||||||
&self.rules
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn evaluate(&self, context: &LaneContext) -> Vec<PolicyAction> {
|
|
||||||
evaluate(self, context)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn evaluate(engine: &PolicyEngine, context: &LaneContext) -> Vec<PolicyAction> {
|
|
||||||
let mut actions = Vec::new();
|
|
||||||
for rule in &engine.rules {
|
|
||||||
if rule.matches(context) {
|
|
||||||
rule.action.flatten_into(&mut actions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
actions
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use super::{
|
|
||||||
evaluate, DiffScope, LaneBlocker, LaneContext, PolicyAction, PolicyCondition, PolicyEngine,
|
|
||||||
PolicyRule, ReviewStatus, STALE_BRANCH_THRESHOLD,
|
|
||||||
};
|
|
||||||
|
|
||||||
fn default_context() -> LaneContext {
|
|
||||||
LaneContext::new(
|
|
||||||
"lane-7",
|
|
||||||
0,
|
|
||||||
Duration::from_secs(0),
|
|
||||||
LaneBlocker::None,
|
|
||||||
ReviewStatus::Pending,
|
|
||||||
DiffScope::Full,
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn merge_to_dev_rule_fires_for_green_scoped_reviewed_lane() {
|
|
||||||
// given
|
|
||||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
|
||||||
"merge-to-dev",
|
|
||||||
PolicyCondition::And(vec![
|
|
||||||
PolicyCondition::GreenAt { level: 2 },
|
|
||||||
PolicyCondition::ScopedDiff,
|
|
||||||
PolicyCondition::ReviewPassed,
|
|
||||||
]),
|
|
||||||
PolicyAction::MergeToDev,
|
|
||||||
20,
|
|
||||||
)]);
|
|
||||||
let context = LaneContext::new(
|
|
||||||
"lane-7",
|
|
||||||
3,
|
|
||||||
Duration::from_secs(5),
|
|
||||||
LaneBlocker::None,
|
|
||||||
ReviewStatus::Approved,
|
|
||||||
DiffScope::Scoped,
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let actions = engine.evaluate(&context);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(actions, vec![PolicyAction::MergeToDev]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn stale_branch_rule_fires_at_threshold() {
|
|
||||||
// given
|
|
||||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
|
||||||
"merge-forward",
|
|
||||||
PolicyCondition::StaleBranch,
|
|
||||||
PolicyAction::MergeForward,
|
|
||||||
10,
|
|
||||||
)]);
|
|
||||||
let context = LaneContext::new(
|
|
||||||
"lane-7",
|
|
||||||
1,
|
|
||||||
STALE_BRANCH_THRESHOLD,
|
|
||||||
LaneBlocker::None,
|
|
||||||
ReviewStatus::Pending,
|
|
||||||
DiffScope::Full,
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let actions = engine.evaluate(&context);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(actions, vec![PolicyAction::MergeForward]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn startup_blocked_rule_recovers_then_escalates() {
|
|
||||||
// given
|
|
||||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
|
||||||
"startup-recovery",
|
|
||||||
PolicyCondition::StartupBlocked,
|
|
||||||
PolicyAction::Chain(vec![
|
|
||||||
PolicyAction::RecoverOnce,
|
|
||||||
PolicyAction::Escalate {
|
|
||||||
reason: "startup remained blocked".to_string(),
|
|
||||||
},
|
|
||||||
]),
|
|
||||||
15,
|
|
||||||
)]);
|
|
||||||
let context = LaneContext::new(
|
|
||||||
"lane-7",
|
|
||||||
0,
|
|
||||||
Duration::from_secs(0),
|
|
||||||
LaneBlocker::Startup,
|
|
||||||
ReviewStatus::Pending,
|
|
||||||
DiffScope::Full,
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let actions = engine.evaluate(&context);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
actions,
|
|
||||||
vec![
|
|
||||||
PolicyAction::RecoverOnce,
|
|
||||||
PolicyAction::Escalate {
|
|
||||||
reason: "startup remained blocked".to_string(),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn completed_lane_rule_closes_out_and_cleans_up() {
|
|
||||||
// given
|
|
||||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
|
||||||
"lane-closeout",
|
|
||||||
PolicyCondition::LaneCompleted,
|
|
||||||
PolicyAction::Chain(vec![
|
|
||||||
PolicyAction::CloseoutLane,
|
|
||||||
PolicyAction::CleanupSession,
|
|
||||||
]),
|
|
||||||
30,
|
|
||||||
)]);
|
|
||||||
let context = LaneContext::new(
|
|
||||||
"lane-7",
|
|
||||||
0,
|
|
||||||
Duration::from_secs(0),
|
|
||||||
LaneBlocker::None,
|
|
||||||
ReviewStatus::Pending,
|
|
||||||
DiffScope::Full,
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let actions = engine.evaluate(&context);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
actions,
|
|
||||||
vec![PolicyAction::CloseoutLane, PolicyAction::CleanupSession]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn matching_rules_are_returned_in_priority_order_with_stable_ties() {
|
|
||||||
// given
|
|
||||||
let engine = PolicyEngine::new(vec![
|
|
||||||
PolicyRule::new(
|
|
||||||
"late-cleanup",
|
|
||||||
PolicyCondition::And(vec![]),
|
|
||||||
PolicyAction::CleanupSession,
|
|
||||||
30,
|
|
||||||
),
|
|
||||||
PolicyRule::new(
|
|
||||||
"first-notify",
|
|
||||||
PolicyCondition::And(vec![]),
|
|
||||||
PolicyAction::Notify {
|
|
||||||
channel: "ops".to_string(),
|
|
||||||
},
|
|
||||||
10,
|
|
||||||
),
|
|
||||||
PolicyRule::new(
|
|
||||||
"second-notify",
|
|
||||||
PolicyCondition::And(vec![]),
|
|
||||||
PolicyAction::Notify {
|
|
||||||
channel: "review".to_string(),
|
|
||||||
},
|
|
||||||
10,
|
|
||||||
),
|
|
||||||
PolicyRule::new(
|
|
||||||
"merge",
|
|
||||||
PolicyCondition::And(vec![]),
|
|
||||||
PolicyAction::MergeToDev,
|
|
||||||
20,
|
|
||||||
),
|
|
||||||
]);
|
|
||||||
let context = default_context();
|
|
||||||
|
|
||||||
// when
|
|
||||||
let actions = evaluate(&engine, &context);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
actions,
|
|
||||||
vec![
|
|
||||||
PolicyAction::Notify {
|
|
||||||
channel: "ops".to_string(),
|
|
||||||
},
|
|
||||||
PolicyAction::Notify {
|
|
||||||
channel: "review".to_string(),
|
|
||||||
},
|
|
||||||
PolicyAction::MergeToDev,
|
|
||||||
PolicyAction::CleanupSession,
|
|
||||||
]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn combinators_handle_empty_cases_and_nested_chains() {
|
|
||||||
// given
|
|
||||||
let engine = PolicyEngine::new(vec![
|
|
||||||
PolicyRule::new(
|
|
||||||
"empty-and",
|
|
||||||
PolicyCondition::And(vec![]),
|
|
||||||
PolicyAction::Notify {
|
|
||||||
channel: "orchestrator".to_string(),
|
|
||||||
},
|
|
||||||
5,
|
|
||||||
),
|
|
||||||
PolicyRule::new(
|
|
||||||
"empty-or",
|
|
||||||
PolicyCondition::Or(vec![]),
|
|
||||||
PolicyAction::Block {
|
|
||||||
reason: "should not fire".to_string(),
|
|
||||||
},
|
|
||||||
10,
|
|
||||||
),
|
|
||||||
PolicyRule::new(
|
|
||||||
"nested",
|
|
||||||
PolicyCondition::Or(vec![
|
|
||||||
PolicyCondition::StartupBlocked,
|
|
||||||
PolicyCondition::And(vec![
|
|
||||||
PolicyCondition::GreenAt { level: 2 },
|
|
||||||
PolicyCondition::TimedOut {
|
|
||||||
duration: Duration::from_secs(5),
|
|
||||||
},
|
|
||||||
]),
|
|
||||||
]),
|
|
||||||
PolicyAction::Chain(vec![
|
|
||||||
PolicyAction::Notify {
|
|
||||||
channel: "alerts".to_string(),
|
|
||||||
},
|
|
||||||
PolicyAction::Chain(vec![
|
|
||||||
PolicyAction::MergeForward,
|
|
||||||
PolicyAction::CleanupSession,
|
|
||||||
]),
|
|
||||||
]),
|
|
||||||
15,
|
|
||||||
),
|
|
||||||
]);
|
|
||||||
let context = LaneContext::new(
|
|
||||||
"lane-7",
|
|
||||||
2,
|
|
||||||
Duration::from_secs(10),
|
|
||||||
LaneBlocker::External,
|
|
||||||
ReviewStatus::Pending,
|
|
||||||
DiffScope::Full,
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let actions = engine.evaluate(&context);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
actions,
|
|
||||||
vec![
|
|
||||||
PolicyAction::Notify {
|
|
||||||
channel: "orchestrator".to_string(),
|
|
||||||
},
|
|
||||||
PolicyAction::Notify {
|
|
||||||
channel: "alerts".to_string(),
|
|
||||||
},
|
|
||||||
PolicyAction::MergeForward,
|
|
||||||
PolicyAction::CleanupSession,
|
|
||||||
]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,554 +0,0 @@
|
||||||
//! Recovery recipes for common failure scenarios.
|
|
||||||
//!
|
|
||||||
//! Encodes known automatic recoveries for the six failure scenarios
|
|
||||||
//! listed in ROADMAP item 8, and enforces one automatic recovery
|
|
||||||
//! attempt before escalation. Each attempt is emitted as a structured
|
|
||||||
//! recovery event.
|
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
/// The six failure scenarios that have known recovery recipes.
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum FailureScenario {
|
|
||||||
TrustPromptUnresolved,
|
|
||||||
PromptMisdelivery,
|
|
||||||
StaleBranch,
|
|
||||||
CompileRedCrossCrate,
|
|
||||||
McpHandshakeFailure,
|
|
||||||
PartialPluginStartup,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FailureScenario {
|
|
||||||
/// Returns all known failure scenarios.
|
|
||||||
#[must_use]
|
|
||||||
pub fn all() -> &'static [FailureScenario] {
|
|
||||||
&[
|
|
||||||
Self::TrustPromptUnresolved,
|
|
||||||
Self::PromptMisdelivery,
|
|
||||||
Self::StaleBranch,
|
|
||||||
Self::CompileRedCrossCrate,
|
|
||||||
Self::McpHandshakeFailure,
|
|
||||||
Self::PartialPluginStartup,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for FailureScenario {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::TrustPromptUnresolved => write!(f, "trust_prompt_unresolved"),
|
|
||||||
Self::PromptMisdelivery => write!(f, "prompt_misdelivery"),
|
|
||||||
Self::StaleBranch => write!(f, "stale_branch"),
|
|
||||||
Self::CompileRedCrossCrate => write!(f, "compile_red_cross_crate"),
|
|
||||||
Self::McpHandshakeFailure => write!(f, "mcp_handshake_failure"),
|
|
||||||
Self::PartialPluginStartup => write!(f, "partial_plugin_startup"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Individual step that can be executed as part of a recovery recipe.
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum RecoveryStep {
|
|
||||||
AcceptTrustPrompt,
|
|
||||||
RedirectPromptToAgent,
|
|
||||||
RebaseBranch,
|
|
||||||
CleanBuild,
|
|
||||||
RetryMcpHandshake { timeout: u64 },
|
|
||||||
RestartPlugin { name: String },
|
|
||||||
EscalateToHuman { reason: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Policy governing what happens when automatic recovery is exhausted.
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum EscalationPolicy {
|
|
||||||
AlertHuman,
|
|
||||||
LogAndContinue,
|
|
||||||
Abort,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A recovery recipe encodes the sequence of steps to attempt for a
|
|
||||||
/// given failure scenario, along with the maximum number of automatic
|
|
||||||
/// attempts and the escalation policy.
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub struct RecoveryRecipe {
|
|
||||||
pub scenario: FailureScenario,
|
|
||||||
pub steps: Vec<RecoveryStep>,
|
|
||||||
pub max_attempts: u32,
|
|
||||||
pub escalation_policy: EscalationPolicy,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Outcome of a recovery attempt.
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum RecoveryResult {
|
|
||||||
Recovered {
|
|
||||||
steps_taken: u32,
|
|
||||||
},
|
|
||||||
PartialRecovery {
|
|
||||||
recovered: Vec<RecoveryStep>,
|
|
||||||
remaining: Vec<RecoveryStep>,
|
|
||||||
},
|
|
||||||
EscalationRequired {
|
|
||||||
reason: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Structured event emitted during recovery.
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum RecoveryEvent {
|
|
||||||
RecoveryAttempted {
|
|
||||||
scenario: FailureScenario,
|
|
||||||
recipe: RecoveryRecipe,
|
|
||||||
result: RecoveryResult,
|
|
||||||
},
|
|
||||||
RecoverySucceeded,
|
|
||||||
RecoveryFailed,
|
|
||||||
Escalated,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Minimal context for tracking recovery state and emitting events.
|
|
||||||
///
|
|
||||||
/// Holds per-scenario attempt counts, a structured event log, and an
|
|
||||||
/// optional simulation knob for controlling step outcomes during tests.
|
|
||||||
#[derive(Debug, Clone, Default)]
|
|
||||||
pub struct RecoveryContext {
|
|
||||||
attempts: HashMap<FailureScenario, u32>,
|
|
||||||
events: Vec<RecoveryEvent>,
|
|
||||||
/// Optional step index at which simulated execution fails.
|
|
||||||
/// `None` means all steps succeed.
|
|
||||||
fail_at_step: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RecoveryContext {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Configure a step index at which simulated execution will fail.
|
|
||||||
#[must_use]
|
|
||||||
pub fn with_fail_at_step(mut self, index: usize) -> Self {
|
|
||||||
self.fail_at_step = Some(index);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the structured event log populated during recovery.
|
|
||||||
#[must_use]
|
|
||||||
pub fn events(&self) -> &[RecoveryEvent] {
|
|
||||||
&self.events
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the number of recovery attempts made for a scenario.
|
|
||||||
#[must_use]
|
|
||||||
pub fn attempt_count(&self, scenario: &FailureScenario) -> u32 {
|
|
||||||
self.attempts.get(scenario).copied().unwrap_or(0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the known recovery recipe for the given failure scenario.
|
|
||||||
#[must_use]
|
|
||||||
pub fn recipe_for(scenario: &FailureScenario) -> RecoveryRecipe {
|
|
||||||
match scenario {
|
|
||||||
FailureScenario::TrustPromptUnresolved => RecoveryRecipe {
|
|
||||||
scenario: *scenario,
|
|
||||||
steps: vec![RecoveryStep::AcceptTrustPrompt],
|
|
||||||
max_attempts: 1,
|
|
||||||
escalation_policy: EscalationPolicy::AlertHuman,
|
|
||||||
},
|
|
||||||
FailureScenario::PromptMisdelivery => RecoveryRecipe {
|
|
||||||
scenario: *scenario,
|
|
||||||
steps: vec![RecoveryStep::RedirectPromptToAgent],
|
|
||||||
max_attempts: 1,
|
|
||||||
escalation_policy: EscalationPolicy::AlertHuman,
|
|
||||||
},
|
|
||||||
FailureScenario::StaleBranch => RecoveryRecipe {
|
|
||||||
scenario: *scenario,
|
|
||||||
steps: vec![RecoveryStep::RebaseBranch, RecoveryStep::CleanBuild],
|
|
||||||
max_attempts: 1,
|
|
||||||
escalation_policy: EscalationPolicy::AlertHuman,
|
|
||||||
},
|
|
||||||
FailureScenario::CompileRedCrossCrate => RecoveryRecipe {
|
|
||||||
scenario: *scenario,
|
|
||||||
steps: vec![RecoveryStep::CleanBuild],
|
|
||||||
max_attempts: 1,
|
|
||||||
escalation_policy: EscalationPolicy::AlertHuman,
|
|
||||||
},
|
|
||||||
FailureScenario::McpHandshakeFailure => RecoveryRecipe {
|
|
||||||
scenario: *scenario,
|
|
||||||
steps: vec![RecoveryStep::RetryMcpHandshake { timeout: 5000 }],
|
|
||||||
max_attempts: 1,
|
|
||||||
escalation_policy: EscalationPolicy::Abort,
|
|
||||||
},
|
|
||||||
FailureScenario::PartialPluginStartup => RecoveryRecipe {
|
|
||||||
scenario: *scenario,
|
|
||||||
steps: vec![
|
|
||||||
RecoveryStep::RestartPlugin {
|
|
||||||
name: "stalled".to_string(),
|
|
||||||
},
|
|
||||||
RecoveryStep::RetryMcpHandshake { timeout: 3000 },
|
|
||||||
],
|
|
||||||
max_attempts: 1,
|
|
||||||
escalation_policy: EscalationPolicy::LogAndContinue,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Attempts automatic recovery for the given failure scenario.
|
|
||||||
///
|
|
||||||
/// Looks up the recipe, enforces the one-attempt-before-escalation
|
|
||||||
/// policy, simulates step execution (controlled by the context), and
|
|
||||||
/// emits structured [`RecoveryEvent`]s for every attempt.
|
|
||||||
pub fn attempt_recovery(scenario: &FailureScenario, ctx: &mut RecoveryContext) -> RecoveryResult {
|
|
||||||
let recipe = recipe_for(scenario);
|
|
||||||
let attempt_count = ctx.attempts.entry(*scenario).or_insert(0);
|
|
||||||
|
|
||||||
// Enforce one automatic recovery attempt before escalation.
|
|
||||||
if *attempt_count >= recipe.max_attempts {
|
|
||||||
let result = RecoveryResult::EscalationRequired {
|
|
||||||
reason: format!(
|
|
||||||
"max recovery attempts ({}) exceeded for {}",
|
|
||||||
recipe.max_attempts, scenario
|
|
||||||
),
|
|
||||||
};
|
|
||||||
ctx.events.push(RecoveryEvent::RecoveryAttempted {
|
|
||||||
scenario: *scenario,
|
|
||||||
recipe,
|
|
||||||
result: result.clone(),
|
|
||||||
});
|
|
||||||
ctx.events.push(RecoveryEvent::Escalated);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
*attempt_count += 1;
|
|
||||||
|
|
||||||
// Execute steps, honoring the optional fail_at_step simulation.
|
|
||||||
let fail_index = ctx.fail_at_step;
|
|
||||||
let mut executed = Vec::new();
|
|
||||||
let mut failed = false;
|
|
||||||
|
|
||||||
for (i, step) in recipe.steps.iter().enumerate() {
|
|
||||||
if fail_index == Some(i) {
|
|
||||||
failed = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
executed.push(step.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = if failed {
|
|
||||||
let remaining: Vec<RecoveryStep> = recipe.steps[executed.len()..].to_vec();
|
|
||||||
if executed.is_empty() {
|
|
||||||
RecoveryResult::EscalationRequired {
|
|
||||||
reason: format!("recovery failed at first step for {}", scenario),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
RecoveryResult::PartialRecovery {
|
|
||||||
recovered: executed,
|
|
||||||
remaining,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
RecoveryResult::Recovered {
|
|
||||||
steps_taken: recipe.steps.len() as u32,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Emit the attempt as structured event data.
|
|
||||||
ctx.events.push(RecoveryEvent::RecoveryAttempted {
|
|
||||||
scenario: *scenario,
|
|
||||||
recipe,
|
|
||||||
result: result.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
match &result {
|
|
||||||
RecoveryResult::Recovered { .. } => {
|
|
||||||
ctx.events.push(RecoveryEvent::RecoverySucceeded);
|
|
||||||
}
|
|
||||||
RecoveryResult::PartialRecovery { .. } => {
|
|
||||||
ctx.events.push(RecoveryEvent::RecoveryFailed);
|
|
||||||
}
|
|
||||||
RecoveryResult::EscalationRequired { .. } => {
|
|
||||||
ctx.events.push(RecoveryEvent::Escalated);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn each_scenario_has_a_matching_recipe() {
|
|
||||||
// given
|
|
||||||
let scenarios = FailureScenario::all();
|
|
||||||
|
|
||||||
// when / then
|
|
||||||
for scenario in scenarios {
|
|
||||||
let recipe = recipe_for(scenario);
|
|
||||||
assert_eq!(
|
|
||||||
recipe.scenario, *scenario,
|
|
||||||
"recipe scenario should match requested scenario"
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
!recipe.steps.is_empty(),
|
|
||||||
"recipe for {} should have at least one step",
|
|
||||||
scenario
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
recipe.max_attempts >= 1,
|
|
||||||
"recipe for {} should allow at least one attempt",
|
|
||||||
scenario
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn successful_recovery_returns_recovered_and_emits_events() {
|
|
||||||
// given
|
|
||||||
let mut ctx = RecoveryContext::new();
|
|
||||||
let scenario = FailureScenario::TrustPromptUnresolved;
|
|
||||||
|
|
||||||
// when
|
|
||||||
let result = attempt_recovery(&scenario, &mut ctx);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(result, RecoveryResult::Recovered { steps_taken: 1 });
|
|
||||||
assert_eq!(ctx.events().len(), 2);
|
|
||||||
assert!(matches!(
|
|
||||||
&ctx.events()[0],
|
|
||||||
RecoveryEvent::RecoveryAttempted {
|
|
||||||
scenario: s,
|
|
||||||
result: r,
|
|
||||||
..
|
|
||||||
} if *s == FailureScenario::TrustPromptUnresolved
|
|
||||||
&& matches!(r, RecoveryResult::Recovered { steps_taken: 1 })
|
|
||||||
));
|
|
||||||
assert_eq!(ctx.events()[1], RecoveryEvent::RecoverySucceeded);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn escalation_after_max_attempts_exceeded() {
|
|
||||||
// given
|
|
||||||
let mut ctx = RecoveryContext::new();
|
|
||||||
let scenario = FailureScenario::PromptMisdelivery;
|
|
||||||
|
|
||||||
// when — first attempt succeeds
|
|
||||||
let first = attempt_recovery(&scenario, &mut ctx);
|
|
||||||
assert!(matches!(first, RecoveryResult::Recovered { .. }));
|
|
||||||
|
|
||||||
// when — second attempt should escalate
|
|
||||||
let second = attempt_recovery(&scenario, &mut ctx);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(
|
|
||||||
matches!(
|
|
||||||
&second,
|
|
||||||
RecoveryResult::EscalationRequired { reason }
|
|
||||||
if reason.contains("max recovery attempts")
|
|
||||||
),
|
|
||||||
"second attempt should require escalation, got: {second:?}"
|
|
||||||
);
|
|
||||||
assert_eq!(ctx.attempt_count(&scenario), 1);
|
|
||||||
assert!(ctx
|
|
||||||
.events()
|
|
||||||
.iter()
|
|
||||||
.any(|e| matches!(e, RecoveryEvent::Escalated)));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn partial_recovery_when_step_fails_midway() {
|
|
||||||
// given — PartialPluginStartup has two steps; fail at step index 1
|
|
||||||
let mut ctx = RecoveryContext::new().with_fail_at_step(1);
|
|
||||||
let scenario = FailureScenario::PartialPluginStartup;
|
|
||||||
|
|
||||||
// when
|
|
||||||
let result = attempt_recovery(&scenario, &mut ctx);
|
|
||||||
|
|
||||||
// then
|
|
||||||
match &result {
|
|
||||||
RecoveryResult::PartialRecovery {
|
|
||||||
recovered,
|
|
||||||
remaining,
|
|
||||||
} => {
|
|
||||||
assert_eq!(recovered.len(), 1, "one step should have succeeded");
|
|
||||||
assert_eq!(remaining.len(), 1, "one step should remain");
|
|
||||||
assert!(matches!(recovered[0], RecoveryStep::RestartPlugin { .. }));
|
|
||||||
assert!(matches!(
|
|
||||||
remaining[0],
|
|
||||||
RecoveryStep::RetryMcpHandshake { .. }
|
|
||||||
));
|
|
||||||
}
|
|
||||||
other => panic!("expected PartialRecovery, got {other:?}"),
|
|
||||||
}
|
|
||||||
assert!(ctx
|
|
||||||
.events()
|
|
||||||
.iter()
|
|
||||||
.any(|e| matches!(e, RecoveryEvent::RecoveryFailed)));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn first_step_failure_escalates_immediately() {
|
|
||||||
// given — fail at step index 0
|
|
||||||
let mut ctx = RecoveryContext::new().with_fail_at_step(0);
|
|
||||||
let scenario = FailureScenario::CompileRedCrossCrate;
|
|
||||||
|
|
||||||
// when
|
|
||||||
let result = attempt_recovery(&scenario, &mut ctx);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(
|
|
||||||
matches!(
|
|
||||||
&result,
|
|
||||||
RecoveryResult::EscalationRequired { reason }
|
|
||||||
if reason.contains("failed at first step")
|
|
||||||
),
|
|
||||||
"zero-step failure should escalate, got: {result:?}"
|
|
||||||
);
|
|
||||||
assert!(ctx
|
|
||||||
.events()
|
|
||||||
.iter()
|
|
||||||
.any(|e| matches!(e, RecoveryEvent::Escalated)));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn emitted_events_include_structured_attempt_data() {
|
|
||||||
// given
|
|
||||||
let mut ctx = RecoveryContext::new();
|
|
||||||
let scenario = FailureScenario::McpHandshakeFailure;
|
|
||||||
|
|
||||||
// when
|
|
||||||
let _ = attempt_recovery(&scenario, &mut ctx);
|
|
||||||
|
|
||||||
// then — verify the RecoveryAttempted event carries full context
|
|
||||||
let attempted = ctx
|
|
||||||
.events()
|
|
||||||
.iter()
|
|
||||||
.find(|e| matches!(e, RecoveryEvent::RecoveryAttempted { .. }))
|
|
||||||
.expect("should have emitted RecoveryAttempted event");
|
|
||||||
|
|
||||||
match attempted {
|
|
||||||
RecoveryEvent::RecoveryAttempted {
|
|
||||||
scenario: s,
|
|
||||||
recipe,
|
|
||||||
result,
|
|
||||||
} => {
|
|
||||||
assert_eq!(*s, scenario);
|
|
||||||
assert_eq!(recipe.scenario, scenario);
|
|
||||||
assert!(!recipe.steps.is_empty());
|
|
||||||
assert!(matches!(result, RecoveryResult::Recovered { .. }));
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the event is serializable as structured JSON
|
|
||||||
let json = serde_json::to_string(&ctx.events()[0])
|
|
||||||
.expect("recovery event should be serializable to JSON");
|
|
||||||
assert!(
|
|
||||||
json.contains("mcp_handshake_failure"),
|
|
||||||
"serialized event should contain scenario name"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn recovery_context_tracks_attempts_per_scenario() {
|
|
||||||
// given
|
|
||||||
let mut ctx = RecoveryContext::new();
|
|
||||||
|
|
||||||
// when
|
|
||||||
assert_eq!(ctx.attempt_count(&FailureScenario::StaleBranch), 0);
|
|
||||||
attempt_recovery(&FailureScenario::StaleBranch, &mut ctx);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(ctx.attempt_count(&FailureScenario::StaleBranch), 1);
|
|
||||||
assert_eq!(ctx.attempt_count(&FailureScenario::PromptMisdelivery), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn stale_branch_recipe_has_rebase_then_clean_build() {
|
|
||||||
// given
|
|
||||||
let recipe = recipe_for(&FailureScenario::StaleBranch);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(recipe.steps.len(), 2);
|
|
||||||
assert_eq!(recipe.steps[0], RecoveryStep::RebaseBranch);
|
|
||||||
assert_eq!(recipe.steps[1], RecoveryStep::CleanBuild);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn partial_plugin_startup_recipe_has_restart_then_handshake() {
|
|
||||||
// given
|
|
||||||
let recipe = recipe_for(&FailureScenario::PartialPluginStartup);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(recipe.steps.len(), 2);
|
|
||||||
assert!(matches!(
|
|
||||||
recipe.steps[0],
|
|
||||||
RecoveryStep::RestartPlugin { .. }
|
|
||||||
));
|
|
||||||
assert!(matches!(
|
|
||||||
recipe.steps[1],
|
|
||||||
RecoveryStep::RetryMcpHandshake { timeout: 3000 }
|
|
||||||
));
|
|
||||||
assert_eq!(recipe.escalation_policy, EscalationPolicy::LogAndContinue);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn failure_scenario_display_all_variants() {
|
|
||||||
// given
|
|
||||||
let cases = [
|
|
||||||
(
|
|
||||||
FailureScenario::TrustPromptUnresolved,
|
|
||||||
"trust_prompt_unresolved",
|
|
||||||
),
|
|
||||||
(FailureScenario::PromptMisdelivery, "prompt_misdelivery"),
|
|
||||||
(FailureScenario::StaleBranch, "stale_branch"),
|
|
||||||
(
|
|
||||||
FailureScenario::CompileRedCrossCrate,
|
|
||||||
"compile_red_cross_crate",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
FailureScenario::McpHandshakeFailure,
|
|
||||||
"mcp_handshake_failure",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
FailureScenario::PartialPluginStartup,
|
|
||||||
"partial_plugin_startup",
|
|
||||||
),
|
|
||||||
];
|
|
||||||
|
|
||||||
// when / then
|
|
||||||
for (scenario, expected) in &cases {
|
|
||||||
assert_eq!(scenario.to_string(), *expected);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn multi_step_success_reports_correct_steps_taken() {
|
|
||||||
// given — StaleBranch has 2 steps, no simulated failure
|
|
||||||
let mut ctx = RecoveryContext::new();
|
|
||||||
let scenario = FailureScenario::StaleBranch;
|
|
||||||
|
|
||||||
// when
|
|
||||||
let result = attempt_recovery(&scenario, &mut ctx);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(result, RecoveryResult::Recovered { steps_taken: 2 });
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn mcp_handshake_recipe_uses_abort_escalation_policy() {
|
|
||||||
// given
|
|
||||||
let recipe = recipe_for(&FailureScenario::McpHandshakeFailure);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(recipe.escalation_policy, EscalationPolicy::Abort);
|
|
||||||
assert_eq!(recipe.max_attempts, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,461 +0,0 @@
|
||||||
use std::env;
|
|
||||||
use std::fmt::{Display, Formatter};
|
|
||||||
use std::fs;
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use std::time::UNIX_EPOCH;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use crate::session::{Session, SessionError};
|
|
||||||
use crate::worker_boot::{Worker, WorkerReadySnapshot, WorkerRegistry, WorkerStatus};
|
|
||||||
|
|
||||||
pub const PRIMARY_SESSION_EXTENSION: &str = "jsonl";
|
|
||||||
pub const LEGACY_SESSION_EXTENSION: &str = "json";
|
|
||||||
pub const LATEST_SESSION_REFERENCE: &str = "latest";
|
|
||||||
|
|
||||||
const SESSION_REFERENCE_ALIASES: &[&str] = &[LATEST_SESSION_REFERENCE, "last", "recent"];
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct SessionHandle {
|
|
||||||
pub id: String,
|
|
||||||
pub path: PathBuf,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct ManagedSessionSummary {
|
|
||||||
pub id: String,
|
|
||||||
pub path: PathBuf,
|
|
||||||
pub modified_epoch_millis: u128,
|
|
||||||
pub message_count: usize,
|
|
||||||
pub parent_session_id: Option<String>,
|
|
||||||
pub branch_name: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct LoadedManagedSession {
|
|
||||||
pub handle: SessionHandle,
|
|
||||||
pub session: Session,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct ForkedManagedSession {
|
|
||||||
pub parent_session_id: String,
|
|
||||||
pub handle: SessionHandle,
|
|
||||||
pub session: Session,
|
|
||||||
pub branch_name: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum SessionControlError {
|
|
||||||
Io(std::io::Error),
|
|
||||||
Session(SessionError),
|
|
||||||
Format(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for SessionControlError {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::Io(error) => write!(f, "{error}"),
|
|
||||||
Self::Session(error) => write!(f, "{error}"),
|
|
||||||
Self::Format(error) => write!(f, "{error}"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for SessionControlError {}
|
|
||||||
|
|
||||||
impl From<std::io::Error> for SessionControlError {
|
|
||||||
fn from(value: std::io::Error) -> Self {
|
|
||||||
Self::Io(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<SessionError> for SessionControlError {
|
|
||||||
fn from(value: SessionError) -> Self {
|
|
||||||
Self::Session(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn sessions_dir() -> Result<PathBuf, SessionControlError> {
|
|
||||||
managed_sessions_dir_for(env::current_dir()?)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn managed_sessions_dir_for(
|
|
||||||
base_dir: impl AsRef<Path>,
|
|
||||||
) -> Result<PathBuf, SessionControlError> {
|
|
||||||
let path = base_dir.as_ref().join(".claw").join("sessions");
|
|
||||||
fs::create_dir_all(&path)?;
|
|
||||||
Ok(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create_managed_session_handle(
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<SessionHandle, SessionControlError> {
|
|
||||||
create_managed_session_handle_for(env::current_dir()?, session_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create_managed_session_handle_for(
|
|
||||||
base_dir: impl AsRef<Path>,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<SessionHandle, SessionControlError> {
|
|
||||||
let id = session_id.to_string();
|
|
||||||
let path =
|
|
||||||
managed_sessions_dir_for(base_dir)?.join(format!("{id}.{PRIMARY_SESSION_EXTENSION}"));
|
|
||||||
Ok(SessionHandle { id, path })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn resolve_session_reference(reference: &str) -> Result<SessionHandle, SessionControlError> {
|
|
||||||
resolve_session_reference_for(env::current_dir()?, reference)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn resolve_session_reference_for(
|
|
||||||
base_dir: impl AsRef<Path>,
|
|
||||||
reference: &str,
|
|
||||||
) -> Result<SessionHandle, SessionControlError> {
|
|
||||||
let base_dir = base_dir.as_ref();
|
|
||||||
if is_session_reference_alias(reference) {
|
|
||||||
let latest = latest_managed_session_for(base_dir)?;
|
|
||||||
return Ok(SessionHandle {
|
|
||||||
id: latest.id,
|
|
||||||
path: latest.path,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let direct = PathBuf::from(reference);
|
|
||||||
let candidate = if direct.is_absolute() {
|
|
||||||
direct.clone()
|
|
||||||
} else {
|
|
||||||
base_dir.join(&direct)
|
|
||||||
};
|
|
||||||
let looks_like_path = direct.extension().is_some() || direct.components().count() > 1;
|
|
||||||
let path = if candidate.exists() {
|
|
||||||
candidate
|
|
||||||
} else if looks_like_path {
|
|
||||||
return Err(SessionControlError::Format(
|
|
||||||
format_missing_session_reference(reference),
|
|
||||||
));
|
|
||||||
} else {
|
|
||||||
resolve_managed_session_path_for(base_dir, reference)?
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(SessionHandle {
|
|
||||||
id: session_id_from_path(&path).unwrap_or_else(|| reference.to_string()),
|
|
||||||
path,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn resolve_managed_session_path(session_id: &str) -> Result<PathBuf, SessionControlError> {
|
|
||||||
resolve_managed_session_path_for(env::current_dir()?, session_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn resolve_managed_session_path_for(
|
|
||||||
base_dir: impl AsRef<Path>,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<PathBuf, SessionControlError> {
|
|
||||||
let directory = managed_sessions_dir_for(base_dir)?;
|
|
||||||
for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] {
|
|
||||||
let path = directory.join(format!("{session_id}.{extension}"));
|
|
||||||
if path.exists() {
|
|
||||||
return Ok(path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(SessionControlError::Format(
|
|
||||||
format_missing_session_reference(session_id),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn is_managed_session_file(path: &Path) -> bool {
|
|
||||||
path.extension()
|
|
||||||
.and_then(|ext| ext.to_str())
|
|
||||||
.is_some_and(|extension| {
|
|
||||||
extension == PRIMARY_SESSION_EXTENSION || extension == LEGACY_SESSION_EXTENSION
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, SessionControlError> {
|
|
||||||
list_managed_sessions_for(env::current_dir()?)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn list_managed_sessions_for(
|
|
||||||
base_dir: impl AsRef<Path>,
|
|
||||||
) -> Result<Vec<ManagedSessionSummary>, SessionControlError> {
|
|
||||||
let mut sessions = Vec::new();
|
|
||||||
for entry in fs::read_dir(managed_sessions_dir_for(base_dir)?)? {
|
|
||||||
let entry = entry?;
|
|
||||||
let path = entry.path();
|
|
||||||
if !is_managed_session_file(&path) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let metadata = entry.metadata()?;
|
|
||||||
let modified_epoch_millis = metadata
|
|
||||||
.modified()
|
|
||||||
.ok()
|
|
||||||
.and_then(|time| time.duration_since(UNIX_EPOCH).ok())
|
|
||||||
.map(|duration| duration.as_millis())
|
|
||||||
.unwrap_or_default();
|
|
||||||
let (id, message_count, parent_session_id, branch_name) =
|
|
||||||
match Session::load_from_path(&path) {
|
|
||||||
Ok(session) => {
|
|
||||||
let parent_session_id = session
|
|
||||||
.fork
|
|
||||||
.as_ref()
|
|
||||||
.map(|fork| fork.parent_session_id.clone());
|
|
||||||
let branch_name = session
|
|
||||||
.fork
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|fork| fork.branch_name.clone());
|
|
||||||
(
|
|
||||||
session.session_id,
|
|
||||||
session.messages.len(),
|
|
||||||
parent_session_id,
|
|
||||||
branch_name,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
Err(_) => (
|
|
||||||
path.file_stem()
|
|
||||||
.and_then(|value| value.to_str())
|
|
||||||
.unwrap_or("unknown")
|
|
||||||
.to_string(),
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
sessions.push(ManagedSessionSummary {
|
|
||||||
id,
|
|
||||||
path,
|
|
||||||
modified_epoch_millis,
|
|
||||||
message_count,
|
|
||||||
parent_session_id,
|
|
||||||
branch_name,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
sessions.sort_by(|left, right| {
|
|
||||||
right
|
|
||||||
.modified_epoch_millis
|
|
||||||
.cmp(&left.modified_epoch_millis)
|
|
||||||
.then_with(|| right.id.cmp(&left.id))
|
|
||||||
});
|
|
||||||
Ok(sessions)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn latest_managed_session() -> Result<ManagedSessionSummary, SessionControlError> {
|
|
||||||
latest_managed_session_for(env::current_dir()?)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn latest_managed_session_for(
|
|
||||||
base_dir: impl AsRef<Path>,
|
|
||||||
) -> Result<ManagedSessionSummary, SessionControlError> {
|
|
||||||
list_managed_sessions_for(base_dir)?
|
|
||||||
.into_iter()
|
|
||||||
.next()
|
|
||||||
.ok_or_else(|| SessionControlError::Format(format_no_managed_sessions()))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_managed_session(reference: &str) -> Result<LoadedManagedSession, SessionControlError> {
|
|
||||||
load_managed_session_for(env::current_dir()?, reference)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_managed_session_for(
|
|
||||||
base_dir: impl AsRef<Path>,
|
|
||||||
reference: &str,
|
|
||||||
) -> Result<LoadedManagedSession, SessionControlError> {
|
|
||||||
let handle = resolve_session_reference_for(base_dir, reference)?;
|
|
||||||
let session = Session::load_from_path(&handle.path)?;
|
|
||||||
Ok(LoadedManagedSession {
|
|
||||||
handle: SessionHandle {
|
|
||||||
id: session.session_id.clone(),
|
|
||||||
path: handle.path,
|
|
||||||
},
|
|
||||||
session,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fork_managed_session(
|
|
||||||
session: &Session,
|
|
||||||
branch_name: Option<String>,
|
|
||||||
) -> Result<ForkedManagedSession, SessionControlError> {
|
|
||||||
fork_managed_session_for(env::current_dir()?, session, branch_name)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fork_managed_session_for(
|
|
||||||
base_dir: impl AsRef<Path>,
|
|
||||||
session: &Session,
|
|
||||||
branch_name: Option<String>,
|
|
||||||
) -> Result<ForkedManagedSession, SessionControlError> {
|
|
||||||
let parent_session_id = session.session_id.clone();
|
|
||||||
let forked = session.fork(branch_name);
|
|
||||||
let handle = create_managed_session_handle_for(base_dir, &forked.session_id)?;
|
|
||||||
let branch_name = forked
|
|
||||||
.fork
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|fork| fork.branch_name.clone());
|
|
||||||
let forked = forked.with_persistence_path(handle.path.clone());
|
|
||||||
forked.save_to_path(&handle.path)?;
|
|
||||||
Ok(ForkedManagedSession {
|
|
||||||
parent_session_id,
|
|
||||||
handle,
|
|
||||||
session: forked,
|
|
||||||
branch_name,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn is_session_reference_alias(reference: &str) -> bool {
|
|
||||||
SESSION_REFERENCE_ALIASES
|
|
||||||
.iter()
|
|
||||||
.any(|alias| reference.eq_ignore_ascii_case(alias))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn session_id_from_path(path: &Path) -> Option<String> {
|
|
||||||
path.file_name()
|
|
||||||
.and_then(|value| value.to_str())
|
|
||||||
.and_then(|name| {
|
|
||||||
name.strip_suffix(&format!(".{PRIMARY_SESSION_EXTENSION}"))
|
|
||||||
.or_else(|| name.strip_suffix(&format!(".{LEGACY_SESSION_EXTENSION}")))
|
|
||||||
})
|
|
||||||
.map(ToOwned::to_owned)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_missing_session_reference(reference: &str) -> String {
|
|
||||||
format!(
|
|
||||||
"session not found: {reference}\nHint: managed sessions live in .claw/sessions/. Try `{LATEST_SESSION_REFERENCE}` for the most recent session or `/session list` in the REPL."
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_no_managed_sessions() -> String {
|
|
||||||
format!(
|
|
||||||
"no managed sessions found in .claw/sessions/\nStart `claw` to create a session, then rerun with `--resume {LATEST_SESSION_REFERENCE}`."
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::{
|
|
||||||
create_managed_session_handle_for, fork_managed_session_for, is_session_reference_alias,
|
|
||||||
list_managed_sessions_for, load_managed_session_for, resolve_session_reference_for,
|
|
||||||
ManagedSessionSummary, LATEST_SESSION_REFERENCE,
|
|
||||||
};
|
|
||||||
use crate::session::Session;
|
|
||||||
use std::fs;
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
|
||||||
|
|
||||||
fn temp_dir() -> PathBuf {
|
|
||||||
let nanos = SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.expect("time should be after epoch")
|
|
||||||
.as_nanos();
|
|
||||||
std::env::temp_dir().join(format!("runtime-session-control-{nanos}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn persist_session(root: &Path, text: &str) -> Session {
|
|
||||||
let mut session = Session::new();
|
|
||||||
session
|
|
||||||
.push_user_text(text)
|
|
||||||
.expect("session message should save");
|
|
||||||
let handle = create_managed_session_handle_for(root, &session.session_id)
|
|
||||||
.expect("managed session handle should build");
|
|
||||||
let session = session.with_persistence_path(handle.path.clone());
|
|
||||||
session
|
|
||||||
.save_to_path(&handle.path)
|
|
||||||
.expect("session should persist");
|
|
||||||
session
|
|
||||||
}
|
|
||||||
|
|
||||||
fn wait_for_next_millisecond() {
|
|
||||||
let start = SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.expect("time should be after epoch")
|
|
||||||
.as_millis();
|
|
||||||
while SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.expect("time should be after epoch")
|
|
||||||
.as_millis()
|
|
||||||
<= start
|
|
||||||
{}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn summary_by_id<'a>(
|
|
||||||
summaries: &'a [ManagedSessionSummary],
|
|
||||||
id: &str,
|
|
||||||
) -> &'a ManagedSessionSummary {
|
|
||||||
summaries
|
|
||||||
.iter()
|
|
||||||
.find(|summary| summary.id == id)
|
|
||||||
.expect("session summary should exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn creates_and_lists_managed_sessions() {
|
|
||||||
// given
|
|
||||||
let root = temp_dir();
|
|
||||||
fs::create_dir_all(&root).expect("root dir should exist");
|
|
||||||
let older = persist_session(&root, "older session");
|
|
||||||
wait_for_next_millisecond();
|
|
||||||
let newer = persist_session(&root, "newer session");
|
|
||||||
|
|
||||||
// when
|
|
||||||
let sessions = list_managed_sessions_for(&root).expect("managed sessions should list");
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(sessions.len(), 2);
|
|
||||||
assert_eq!(sessions[0].id, newer.session_id);
|
|
||||||
assert_eq!(summary_by_id(&sessions, &older.session_id).message_count, 1);
|
|
||||||
assert_eq!(summary_by_id(&sessions, &newer.session_id).message_count, 1);
|
|
||||||
fs::remove_dir_all(root).expect("temp dir should clean up");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn resolves_latest_alias_and_loads_session_from_workspace_root() {
|
|
||||||
// given
|
|
||||||
let root = temp_dir();
|
|
||||||
fs::create_dir_all(&root).expect("root dir should exist");
|
|
||||||
let older = persist_session(&root, "older session");
|
|
||||||
wait_for_next_millisecond();
|
|
||||||
let newer = persist_session(&root, "newer session");
|
|
||||||
|
|
||||||
// when
|
|
||||||
let handle = resolve_session_reference_for(&root, LATEST_SESSION_REFERENCE)
|
|
||||||
.expect("latest alias should resolve");
|
|
||||||
let loaded = load_managed_session_for(&root, "recent")
|
|
||||||
.expect("recent alias should load the latest session");
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(handle.id, newer.session_id);
|
|
||||||
assert_eq!(loaded.handle.id, newer.session_id);
|
|
||||||
assert_eq!(loaded.session.messages.len(), 1);
|
|
||||||
assert_ne!(loaded.handle.id, older.session_id);
|
|
||||||
assert!(is_session_reference_alias("last"));
|
|
||||||
fs::remove_dir_all(root).expect("temp dir should clean up");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn forks_session_into_managed_storage_with_lineage() {
|
|
||||||
// given
|
|
||||||
let root = temp_dir();
|
|
||||||
fs::create_dir_all(&root).expect("root dir should exist");
|
|
||||||
let source = persist_session(&root, "parent session");
|
|
||||||
|
|
||||||
// when
|
|
||||||
let forked = fork_managed_session_for(&root, &source, Some("incident-review".to_string()))
|
|
||||||
.expect("session should fork");
|
|
||||||
let sessions = list_managed_sessions_for(&root).expect("managed sessions should list");
|
|
||||||
let summary = summary_by_id(&sessions, &forked.handle.id);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(forked.parent_session_id, source.session_id);
|
|
||||||
assert_eq!(forked.branch_name.as_deref(), Some("incident-review"));
|
|
||||||
assert_eq!(
|
|
||||||
summary.parent_session_id.as_deref(),
|
|
||||||
Some(source.session_id.as_str())
|
|
||||||
);
|
|
||||||
assert_eq!(summary.branch_name.as_deref(), Some("incident-review"));
|
|
||||||
assert_eq!(
|
|
||||||
forked.session.persistence_path(),
|
|
||||||
Some(forked.handle.path.as_path())
|
|
||||||
);
|
|
||||||
fs::remove_dir_all(root).expect("temp dir should clean up");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,389 +0,0 @@
|
||||||
use std::path::Path;
|
|
||||||
use std::process::Command;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum BranchFreshness {
|
|
||||||
Fresh,
|
|
||||||
Stale {
|
|
||||||
commits_behind: usize,
|
|
||||||
missing_fixes: Vec<String>,
|
|
||||||
},
|
|
||||||
Diverged {
|
|
||||||
ahead: usize,
|
|
||||||
behind: usize,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub enum StaleBranchPolicy {
|
|
||||||
AutoRebase,
|
|
||||||
AutoMergeForward,
|
|
||||||
WarnOnly,
|
|
||||||
Block,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum StaleBranchEvent {
|
|
||||||
BranchStaleAgainstMain {
|
|
||||||
branch: String,
|
|
||||||
commits_behind: usize,
|
|
||||||
missing_fixes: Vec<String>,
|
|
||||||
},
|
|
||||||
RebaseAttempted {
|
|
||||||
branch: String,
|
|
||||||
result: String,
|
|
||||||
},
|
|
||||||
MergeForwardAttempted {
|
|
||||||
branch: String,
|
|
||||||
result: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum StaleBranchAction {
|
|
||||||
Noop,
|
|
||||||
Warn { message: String },
|
|
||||||
Block { message: String },
|
|
||||||
Rebase,
|
|
||||||
MergeForward,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn check_freshness(branch: &str, main_ref: &str) -> BranchFreshness {
|
|
||||||
check_freshness_in(branch, main_ref, Path::new("."))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn apply_policy(freshness: &BranchFreshness, policy: StaleBranchPolicy) -> StaleBranchAction {
|
|
||||||
match freshness {
|
|
||||||
BranchFreshness::Fresh => StaleBranchAction::Noop,
|
|
||||||
BranchFreshness::Stale {
|
|
||||||
commits_behind,
|
|
||||||
missing_fixes,
|
|
||||||
} => match policy {
|
|
||||||
StaleBranchPolicy::WarnOnly => StaleBranchAction::Warn {
|
|
||||||
message: format!(
|
|
||||||
"Branch is {commits_behind} commit(s) behind main. Missing fixes: {}",
|
|
||||||
if missing_fixes.is_empty() {
|
|
||||||
"(none)".to_string()
|
|
||||||
} else {
|
|
||||||
missing_fixes.join("; ")
|
|
||||||
}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
StaleBranchPolicy::Block => StaleBranchAction::Block {
|
|
||||||
message: format!(
|
|
||||||
"Branch is {commits_behind} commit(s) behind main and must be updated before proceeding."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
StaleBranchPolicy::AutoRebase => StaleBranchAction::Rebase,
|
|
||||||
StaleBranchPolicy::AutoMergeForward => StaleBranchAction::MergeForward,
|
|
||||||
},
|
|
||||||
BranchFreshness::Diverged { ahead, behind } => match policy {
|
|
||||||
StaleBranchPolicy::WarnOnly => StaleBranchAction::Warn {
|
|
||||||
message: format!(
|
|
||||||
"Branch has diverged: {ahead} commit(s) ahead, {behind} commit(s) behind main."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
StaleBranchPolicy::Block => StaleBranchAction::Block {
|
|
||||||
message: format!(
|
|
||||||
"Branch has diverged ({ahead} ahead, {behind} behind) and must be reconciled before proceeding."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
StaleBranchPolicy::AutoRebase => StaleBranchAction::Rebase,
|
|
||||||
StaleBranchPolicy::AutoMergeForward => StaleBranchAction::MergeForward,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn check_freshness_in(
|
|
||||||
branch: &str,
|
|
||||||
main_ref: &str,
|
|
||||||
repo_path: &Path,
|
|
||||||
) -> BranchFreshness {
|
|
||||||
let behind = rev_list_count(main_ref, branch, repo_path);
|
|
||||||
let ahead = rev_list_count(branch, main_ref, repo_path);
|
|
||||||
|
|
||||||
if behind == 0 {
|
|
||||||
return BranchFreshness::Fresh;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ahead > 0 {
|
|
||||||
return BranchFreshness::Diverged { ahead, behind };
|
|
||||||
}
|
|
||||||
|
|
||||||
let missing_fixes = missing_fix_subjects(main_ref, branch, repo_path);
|
|
||||||
BranchFreshness::Stale {
|
|
||||||
commits_behind: behind,
|
|
||||||
missing_fixes,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rev_list_count(a: &str, b: &str, repo_path: &Path) -> usize {
|
|
||||||
let output = Command::new("git")
|
|
||||||
.args(["rev-list", "--count", &format!("{b}..{a}")])
|
|
||||||
.current_dir(repo_path)
|
|
||||||
.output();
|
|
||||||
match output {
|
|
||||||
Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout)
|
|
||||||
.trim()
|
|
||||||
.parse::<usize>()
|
|
||||||
.unwrap_or(0),
|
|
||||||
_ => 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn missing_fix_subjects(a: &str, b: &str, repo_path: &Path) -> Vec<String> {
|
|
||||||
let output = Command::new("git")
|
|
||||||
.args(["log", "--format=%s", &format!("{b}..{a}")])
|
|
||||||
.current_dir(repo_path)
|
|
||||||
.output();
|
|
||||||
match output {
|
|
||||||
Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout)
|
|
||||||
.lines()
|
|
||||||
.filter(|l| !l.is_empty())
|
|
||||||
.map(String::from)
|
|
||||||
.collect(),
|
|
||||||
_ => Vec::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use std::fs;
|
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
|
||||||
|
|
||||||
fn temp_dir() -> std::path::PathBuf {
|
|
||||||
let nanos = SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.expect("time should be after epoch")
|
|
||||||
.as_nanos();
|
|
||||||
std::env::temp_dir().join(format!("runtime-stale-branch-{nanos}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn init_repo(path: &Path) {
|
|
||||||
fs::create_dir_all(path).expect("create repo dir");
|
|
||||||
run(path, &["init", "--quiet", "-b", "main"]);
|
|
||||||
run(path, &["config", "user.email", "tests@example.com"]);
|
|
||||||
run(path, &["config", "user.name", "Stale Branch Tests"]);
|
|
||||||
fs::write(path.join("init.txt"), "initial\n").expect("write init file");
|
|
||||||
run(path, &["add", "."]);
|
|
||||||
run(path, &["commit", "-m", "initial commit", "--quiet"]);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(cwd: &Path, args: &[&str]) {
|
|
||||||
let status = Command::new("git")
|
|
||||||
.args(args)
|
|
||||||
.current_dir(cwd)
|
|
||||||
.status()
|
|
||||||
.unwrap_or_else(|e| panic!("git {} failed to execute: {e}", args.join(" ")));
|
|
||||||
assert!(
|
|
||||||
status.success(),
|
|
||||||
"git {} exited with {status}",
|
|
||||||
args.join(" ")
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn commit_file(repo: &Path, name: &str, msg: &str) {
|
|
||||||
fs::write(repo.join(name), format!("{msg}\n")).expect("write file");
|
|
||||||
run(repo, &["add", name]);
|
|
||||||
run(repo, &["commit", "-m", msg, "--quiet"]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn fresh_branch_passes() {
|
|
||||||
let root = temp_dir();
|
|
||||||
init_repo(&root);
|
|
||||||
|
|
||||||
// given
|
|
||||||
run(&root, &["checkout", "-b", "topic"]);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let freshness = check_freshness_in("topic", "main", &root);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(freshness, BranchFreshness::Fresh);
|
|
||||||
|
|
||||||
fs::remove_dir_all(&root).expect("cleanup");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn fresh_branch_ahead_of_main_still_fresh() {
|
|
||||||
let root = temp_dir();
|
|
||||||
init_repo(&root);
|
|
||||||
|
|
||||||
// given
|
|
||||||
run(&root, &["checkout", "-b", "topic"]);
|
|
||||||
commit_file(&root, "feature.txt", "add feature");
|
|
||||||
|
|
||||||
// when
|
|
||||||
let freshness = check_freshness_in("topic", "main", &root);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(freshness, BranchFreshness::Fresh);
|
|
||||||
|
|
||||||
fs::remove_dir_all(&root).expect("cleanup");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn stale_branch_detected_with_correct_behind_count_and_missing_fixes() {
|
|
||||||
let root = temp_dir();
|
|
||||||
init_repo(&root);
|
|
||||||
|
|
||||||
// given
|
|
||||||
run(&root, &["checkout", "-b", "topic"]);
|
|
||||||
run(&root, &["checkout", "main"]);
|
|
||||||
commit_file(&root, "fix1.txt", "fix: resolve timeout");
|
|
||||||
commit_file(&root, "fix2.txt", "fix: handle null pointer");
|
|
||||||
|
|
||||||
// when
|
|
||||||
let freshness = check_freshness_in("topic", "main", &root);
|
|
||||||
|
|
||||||
// then
|
|
||||||
match freshness {
|
|
||||||
BranchFreshness::Stale {
|
|
||||||
commits_behind,
|
|
||||||
missing_fixes,
|
|
||||||
} => {
|
|
||||||
assert_eq!(commits_behind, 2);
|
|
||||||
assert_eq!(missing_fixes.len(), 2);
|
|
||||||
assert_eq!(missing_fixes[0], "fix: handle null pointer");
|
|
||||||
assert_eq!(missing_fixes[1], "fix: resolve timeout");
|
|
||||||
}
|
|
||||||
other => panic!("expected Stale, got {other:?}"),
|
|
||||||
}
|
|
||||||
|
|
||||||
fs::remove_dir_all(&root).expect("cleanup");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn diverged_branch_detection() {
|
|
||||||
let root = temp_dir();
|
|
||||||
init_repo(&root);
|
|
||||||
|
|
||||||
// given
|
|
||||||
run(&root, &["checkout", "-b", "topic"]);
|
|
||||||
commit_file(&root, "topic_work.txt", "topic work");
|
|
||||||
run(&root, &["checkout", "main"]);
|
|
||||||
commit_file(&root, "main_fix.txt", "main fix");
|
|
||||||
|
|
||||||
// when
|
|
||||||
let freshness = check_freshness_in("topic", "main", &root);
|
|
||||||
|
|
||||||
// then
|
|
||||||
match freshness {
|
|
||||||
BranchFreshness::Diverged { ahead, behind } => {
|
|
||||||
assert_eq!(ahead, 1);
|
|
||||||
assert_eq!(behind, 1);
|
|
||||||
}
|
|
||||||
other => panic!("expected Diverged, got {other:?}"),
|
|
||||||
}
|
|
||||||
|
|
||||||
fs::remove_dir_all(&root).expect("cleanup");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn policy_noop_for_fresh_branch() {
|
|
||||||
// given
|
|
||||||
let freshness = BranchFreshness::Fresh;
|
|
||||||
|
|
||||||
// when
|
|
||||||
let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(action, StaleBranchAction::Noop);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn policy_warn_for_stale_branch() {
|
|
||||||
// given
|
|
||||||
let freshness = BranchFreshness::Stale {
|
|
||||||
commits_behind: 3,
|
|
||||||
missing_fixes: vec!["fix: timeout".into(), "fix: null ptr".into()],
|
|
||||||
};
|
|
||||||
|
|
||||||
// when
|
|
||||||
let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly);
|
|
||||||
|
|
||||||
// then
|
|
||||||
match action {
|
|
||||||
StaleBranchAction::Warn { message } => {
|
|
||||||
assert!(message.contains("3 commit(s) behind"));
|
|
||||||
assert!(message.contains("fix: timeout"));
|
|
||||||
assert!(message.contains("fix: null ptr"));
|
|
||||||
}
|
|
||||||
other => panic!("expected Warn, got {other:?}"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn policy_block_for_stale_branch() {
|
|
||||||
// given
|
|
||||||
let freshness = BranchFreshness::Stale {
|
|
||||||
commits_behind: 1,
|
|
||||||
missing_fixes: vec!["hotfix".into()],
|
|
||||||
};
|
|
||||||
|
|
||||||
// when
|
|
||||||
let action = apply_policy(&freshness, StaleBranchPolicy::Block);
|
|
||||||
|
|
||||||
// then
|
|
||||||
match action {
|
|
||||||
StaleBranchAction::Block { message } => {
|
|
||||||
assert!(message.contains("1 commit(s) behind"));
|
|
||||||
}
|
|
||||||
other => panic!("expected Block, got {other:?}"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn policy_auto_rebase_for_stale_branch() {
|
|
||||||
// given
|
|
||||||
let freshness = BranchFreshness::Stale {
|
|
||||||
commits_behind: 2,
|
|
||||||
missing_fixes: vec![],
|
|
||||||
};
|
|
||||||
|
|
||||||
// when
|
|
||||||
let action = apply_policy(&freshness, StaleBranchPolicy::AutoRebase);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(action, StaleBranchAction::Rebase);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn policy_auto_merge_forward_for_diverged_branch() {
|
|
||||||
// given
|
|
||||||
let freshness = BranchFreshness::Diverged {
|
|
||||||
ahead: 5,
|
|
||||||
behind: 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
// when
|
|
||||||
let action = apply_policy(&freshness, StaleBranchPolicy::AutoMergeForward);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(action, StaleBranchAction::MergeForward);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn policy_warn_for_diverged_branch() {
|
|
||||||
// given
|
|
||||||
let freshness = BranchFreshness::Diverged {
|
|
||||||
ahead: 3,
|
|
||||||
behind: 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
// when
|
|
||||||
let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly);
|
|
||||||
|
|
||||||
// then
|
|
||||||
match action {
|
|
||||||
StaleBranchAction::Warn { message } => {
|
|
||||||
assert!(message.contains("diverged"));
|
|
||||||
assert!(message.contains("3 commit(s) ahead"));
|
|
||||||
assert!(message.contains("1 commit(s) behind"));
|
|
||||||
}
|
|
||||||
other => panic!("expected Warn, got {other:?}"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,300 +0,0 @@
|
||||||
use std::collections::BTreeSet;
|
|
||||||
|
|
||||||
const DEFAULT_MAX_CHARS: usize = 1_200;
|
|
||||||
const DEFAULT_MAX_LINES: usize = 24;
|
|
||||||
const DEFAULT_MAX_LINE_CHARS: usize = 160;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub struct SummaryCompressionBudget {
|
|
||||||
pub max_chars: usize,
|
|
||||||
pub max_lines: usize,
|
|
||||||
pub max_line_chars: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for SummaryCompressionBudget {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
max_chars: DEFAULT_MAX_CHARS,
|
|
||||||
max_lines: DEFAULT_MAX_LINES,
|
|
||||||
max_line_chars: DEFAULT_MAX_LINE_CHARS,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct SummaryCompressionResult {
|
|
||||||
pub summary: String,
|
|
||||||
pub original_chars: usize,
|
|
||||||
pub compressed_chars: usize,
|
|
||||||
pub original_lines: usize,
|
|
||||||
pub compressed_lines: usize,
|
|
||||||
pub removed_duplicate_lines: usize,
|
|
||||||
pub omitted_lines: usize,
|
|
||||||
pub truncated: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn compress_summary(
|
|
||||||
summary: &str,
|
|
||||||
budget: SummaryCompressionBudget,
|
|
||||||
) -> SummaryCompressionResult {
|
|
||||||
let original_chars = summary.chars().count();
|
|
||||||
let original_lines = summary.lines().count();
|
|
||||||
|
|
||||||
let normalized = normalize_lines(summary, budget.max_line_chars);
|
|
||||||
if normalized.lines.is_empty() || budget.max_chars == 0 || budget.max_lines == 0 {
|
|
||||||
return SummaryCompressionResult {
|
|
||||||
summary: String::new(),
|
|
||||||
original_chars,
|
|
||||||
compressed_chars: 0,
|
|
||||||
original_lines,
|
|
||||||
compressed_lines: 0,
|
|
||||||
removed_duplicate_lines: normalized.removed_duplicate_lines,
|
|
||||||
omitted_lines: normalized.lines.len(),
|
|
||||||
truncated: original_chars > 0,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let selected = select_line_indexes(&normalized.lines, budget);
|
|
||||||
let mut compressed_lines = selected
|
|
||||||
.iter()
|
|
||||||
.map(|index| normalized.lines[*index].clone())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
if compressed_lines.is_empty() {
|
|
||||||
compressed_lines.push(truncate_line(&normalized.lines[0], budget.max_chars));
|
|
||||||
}
|
|
||||||
let omitted_lines = normalized
|
|
||||||
.lines
|
|
||||||
.len()
|
|
||||||
.saturating_sub(compressed_lines.len());
|
|
||||||
|
|
||||||
if omitted_lines > 0 {
|
|
||||||
let omission_notice = omission_notice(omitted_lines);
|
|
||||||
push_line_with_budget(&mut compressed_lines, omission_notice, budget);
|
|
||||||
}
|
|
||||||
|
|
||||||
let compressed_summary = compressed_lines.join("\n");
|
|
||||||
|
|
||||||
SummaryCompressionResult {
|
|
||||||
compressed_chars: compressed_summary.chars().count(),
|
|
||||||
compressed_lines: compressed_lines.len(),
|
|
||||||
removed_duplicate_lines: normalized.removed_duplicate_lines,
|
|
||||||
omitted_lines,
|
|
||||||
truncated: compressed_summary != summary.trim(),
|
|
||||||
summary: compressed_summary,
|
|
||||||
original_chars,
|
|
||||||
original_lines,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn compress_summary_text(summary: &str) -> String {
|
|
||||||
compress_summary(summary, SummaryCompressionBudget::default()).summary
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
|
||||||
struct NormalizedSummary {
|
|
||||||
lines: Vec<String>,
|
|
||||||
removed_duplicate_lines: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize_lines(summary: &str, max_line_chars: usize) -> NormalizedSummary {
|
|
||||||
let mut seen = BTreeSet::new();
|
|
||||||
let mut lines = Vec::new();
|
|
||||||
let mut removed_duplicate_lines = 0;
|
|
||||||
|
|
||||||
for raw_line in summary.lines() {
|
|
||||||
let normalized = collapse_inline_whitespace(raw_line);
|
|
||||||
if normalized.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let truncated = truncate_line(&normalized, max_line_chars);
|
|
||||||
let dedupe_key = dedupe_key(&truncated);
|
|
||||||
if !seen.insert(dedupe_key) {
|
|
||||||
removed_duplicate_lines += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
lines.push(truncated);
|
|
||||||
}
|
|
||||||
|
|
||||||
NormalizedSummary {
|
|
||||||
lines,
|
|
||||||
removed_duplicate_lines,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn select_line_indexes(lines: &[String], budget: SummaryCompressionBudget) -> Vec<usize> {
|
|
||||||
let mut selected = BTreeSet::<usize>::new();
|
|
||||||
|
|
||||||
for priority in 0..=3 {
|
|
||||||
for (index, line) in lines.iter().enumerate() {
|
|
||||||
if selected.contains(&index) || line_priority(line) != priority {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let candidate = selected
|
|
||||||
.iter()
|
|
||||||
.map(|selected_index| lines[*selected_index].as_str())
|
|
||||||
.chain(std::iter::once(line.as_str()))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
if candidate.len() > budget.max_lines {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if joined_char_count(&candidate) > budget.max_chars {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
selected.insert(index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
selected.into_iter().collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn push_line_with_budget(lines: &mut Vec<String>, line: String, budget: SummaryCompressionBudget) {
|
|
||||||
let candidate = lines
|
|
||||||
.iter()
|
|
||||||
.map(String::as_str)
|
|
||||||
.chain(std::iter::once(line.as_str()))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
if candidate.len() <= budget.max_lines && joined_char_count(&candidate) <= budget.max_chars {
|
|
||||||
lines.push(line);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn joined_char_count(lines: &[&str]) -> usize {
|
|
||||||
lines.iter().map(|line| line.chars().count()).sum::<usize>() + lines.len().saturating_sub(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn line_priority(line: &str) -> usize {
|
|
||||||
if line == "Summary:" || line == "Conversation summary:" || is_core_detail(line) {
|
|
||||||
0
|
|
||||||
} else if is_section_header(line) {
|
|
||||||
1
|
|
||||||
} else if line.starts_with("- ") || line.starts_with(" - ") {
|
|
||||||
2
|
|
||||||
} else {
|
|
||||||
3
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_core_detail(line: &str) -> bool {
|
|
||||||
[
|
|
||||||
"- Scope:",
|
|
||||||
"- Current work:",
|
|
||||||
"- Pending work:",
|
|
||||||
"- Key files referenced:",
|
|
||||||
"- Tools mentioned:",
|
|
||||||
"- Recent user requests:",
|
|
||||||
"- Previously compacted context:",
|
|
||||||
"- Newly compacted context:",
|
|
||||||
]
|
|
||||||
.iter()
|
|
||||||
.any(|prefix| line.starts_with(prefix))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_section_header(line: &str) -> bool {
|
|
||||||
line.ends_with(':')
|
|
||||||
}
|
|
||||||
|
|
||||||
fn omission_notice(omitted_lines: usize) -> String {
|
|
||||||
format!("- … {omitted_lines} additional line(s) omitted.")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn collapse_inline_whitespace(line: &str) -> String {
|
|
||||||
line.split_whitespace().collect::<Vec<_>>().join(" ")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn truncate_line(line: &str, max_chars: usize) -> String {
|
|
||||||
if max_chars == 0 || line.chars().count() <= max_chars {
|
|
||||||
return line.to_string();
|
|
||||||
}
|
|
||||||
|
|
||||||
if max_chars == 1 {
|
|
||||||
return "…".to_string();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut truncated = line
|
|
||||||
.chars()
|
|
||||||
.take(max_chars.saturating_sub(1))
|
|
||||||
.collect::<String>();
|
|
||||||
truncated.push('…');
|
|
||||||
truncated
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dedupe_key(line: &str) -> String {
|
|
||||||
line.to_ascii_lowercase()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::{compress_summary, compress_summary_text, SummaryCompressionBudget};
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn collapses_whitespace_and_duplicate_lines() {
|
|
||||||
// given
|
|
||||||
let summary = "Conversation summary:\n\n- Scope: compact earlier messages.\n- Scope: compact earlier messages.\n- Current work: update runtime module.\n";
|
|
||||||
|
|
||||||
// when
|
|
||||||
let result = compress_summary(summary, SummaryCompressionBudget::default());
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(result.removed_duplicate_lines, 1);
|
|
||||||
assert!(result
|
|
||||||
.summary
|
|
||||||
.contains("- Scope: compact earlier messages."));
|
|
||||||
assert!(!result.summary.contains(" compact earlier"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn keeps_core_lines_when_budget_is_tight() {
|
|
||||||
// given
|
|
||||||
let summary = [
|
|
||||||
"Conversation summary:",
|
|
||||||
"- Scope: 18 earlier messages compacted.",
|
|
||||||
"- Current work: finish summary compression.",
|
|
||||||
"- Key timeline:",
|
|
||||||
" - user: asked for a working implementation.",
|
|
||||||
" - assistant: inspected runtime compaction flow.",
|
|
||||||
" - tool: cargo check succeeded.",
|
|
||||||
]
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
// when
|
|
||||||
let result = compress_summary(
|
|
||||||
&summary,
|
|
||||||
SummaryCompressionBudget {
|
|
||||||
max_chars: 120,
|
|
||||||
max_lines: 3,
|
|
||||||
max_line_chars: 80,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(result.summary.contains("Conversation summary:"));
|
|
||||||
assert!(result
|
|
||||||
.summary
|
|
||||||
.contains("- Scope: 18 earlier messages compacted."));
|
|
||||||
assert!(result
|
|
||||||
.summary
|
|
||||||
.contains("- Current work: finish summary compression."));
|
|
||||||
assert!(result.omitted_lines > 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn provides_a_default_text_only_helper() {
|
|
||||||
// given
|
|
||||||
let summary = "Summary:\n\nA short line.";
|
|
||||||
|
|
||||||
// when
|
|
||||||
let compressed = compress_summary_text(summary);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(compressed, "Summary:\nA short line.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,591 +0,0 @@
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Value as JsonValue;
|
|
||||||
use std::collections::BTreeMap;
|
|
||||||
use std::fmt::{Display, Formatter};
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub struct RepoConfig {
|
|
||||||
pub repo_root: PathBuf,
|
|
||||||
pub worktree_root: Option<PathBuf>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RepoConfig {
|
|
||||||
#[must_use]
|
|
||||||
pub fn dispatch_root(&self) -> &Path {
|
|
||||||
self.worktree_root
|
|
||||||
.as_deref()
|
|
||||||
.unwrap_or(self.repo_root.as_path())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum TaskScope {
|
|
||||||
SingleFile { path: PathBuf },
|
|
||||||
Module { crate_name: String },
|
|
||||||
Workspace,
|
|
||||||
Custom { paths: Vec<PathBuf> },
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TaskScope {
|
|
||||||
#[must_use]
|
|
||||||
pub fn resolve_paths(&self, repo_config: &RepoConfig) -> Vec<PathBuf> {
|
|
||||||
let dispatch_root = repo_config.dispatch_root();
|
|
||||||
match self {
|
|
||||||
Self::SingleFile { path } => vec![resolve_path(dispatch_root, path)],
|
|
||||||
Self::Module { crate_name } => vec![dispatch_root.join("crates").join(crate_name)],
|
|
||||||
Self::Workspace => vec![dispatch_root.to_path_buf()],
|
|
||||||
Self::Custom { paths } => paths
|
|
||||||
.iter()
|
|
||||||
.map(|path| resolve_path(dispatch_root, path))
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for TaskScope {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::SingleFile { .. } => write!(f, "single_file"),
|
|
||||||
Self::Module { .. } => write!(f, "module"),
|
|
||||||
Self::Workspace => write!(f, "workspace"),
|
|
||||||
Self::Custom { .. } => write!(f, "custom"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum BranchPolicy {
|
|
||||||
CreateNew { prefix: String },
|
|
||||||
UseExisting { name: String },
|
|
||||||
WorktreeIsolated,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for BranchPolicy {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::CreateNew { .. } => write!(f, "create_new"),
|
|
||||||
Self::UseExisting { .. } => write!(f, "use_existing"),
|
|
||||||
Self::WorktreeIsolated => write!(f, "worktree_isolated"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum CommitPolicy {
|
|
||||||
CommitPerTask,
|
|
||||||
SquashOnMerge,
|
|
||||||
NoAutoCommit,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for CommitPolicy {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::CommitPerTask => write!(f, "commit_per_task"),
|
|
||||||
Self::SquashOnMerge => write!(f, "squash_on_merge"),
|
|
||||||
Self::NoAutoCommit => write!(f, "no_auto_commit"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum GreenLevel {
|
|
||||||
Package,
|
|
||||||
Workspace,
|
|
||||||
MergeReady,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for GreenLevel {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::Package => write!(f, "package"),
|
|
||||||
Self::Workspace => write!(f, "workspace"),
|
|
||||||
Self::MergeReady => write!(f, "merge_ready"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum AcceptanceTest {
|
|
||||||
CargoTest { filter: Option<String> },
|
|
||||||
CustomCommand { cmd: String },
|
|
||||||
GreenLevel { level: GreenLevel },
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for AcceptanceTest {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::CargoTest { .. } => write!(f, "cargo_test"),
|
|
||||||
Self::CustomCommand { .. } => write!(f, "custom_command"),
|
|
||||||
Self::GreenLevel { .. } => write!(f, "green_level"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum ReportingContract {
|
|
||||||
EventStream,
|
|
||||||
Summary,
|
|
||||||
Silent,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for ReportingContract {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::EventStream => write!(f, "event_stream"),
|
|
||||||
Self::Summary => write!(f, "summary"),
|
|
||||||
Self::Silent => write!(f, "silent"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum EscalationPolicy {
|
|
||||||
RetryThenEscalate { max_retries: u32 },
|
|
||||||
AutoEscalate,
|
|
||||||
NeverEscalate,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for EscalationPolicy {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::RetryThenEscalate { .. } => write!(f, "retry_then_escalate"),
|
|
||||||
Self::AutoEscalate => write!(f, "auto_escalate"),
|
|
||||||
Self::NeverEscalate => write!(f, "never_escalate"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
|
||||||
pub struct TaskPacket {
|
|
||||||
pub id: String,
|
|
||||||
pub objective: String,
|
|
||||||
pub scope: TaskScope,
|
|
||||||
pub repo_config: RepoConfig,
|
|
||||||
pub branch_policy: BranchPolicy,
|
|
||||||
pub acceptance_tests: Vec<AcceptanceTest>,
|
|
||||||
pub commit_policy: CommitPolicy,
|
|
||||||
pub reporting: ReportingContract,
|
|
||||||
pub escalation: EscalationPolicy,
|
|
||||||
pub created_at: u64,
|
|
||||||
pub metadata: BTreeMap<String, JsonValue>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TaskPacket {
|
|
||||||
#[must_use]
|
|
||||||
pub fn resolve_scope_paths(&self) -> Vec<PathBuf> {
|
|
||||||
self.scope.resolve_paths(&self.repo_config)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct TaskPacketValidationError {
|
|
||||||
errors: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TaskPacketValidationError {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(errors: Vec<String>) -> Self {
|
|
||||||
Self { errors }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn errors(&self) -> &[String] {
|
|
||||||
&self.errors
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for TaskPacketValidationError {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{}", self.errors.join("; "))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for TaskPacketValidationError {}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
pub struct ValidatedPacket(TaskPacket);
|
|
||||||
|
|
||||||
impl ValidatedPacket {
|
|
||||||
#[must_use]
|
|
||||||
pub fn packet(&self) -> &TaskPacket {
|
|
||||||
&self.0
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn into_inner(self) -> TaskPacket {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn resolve_scope_paths(&self) -> Vec<PathBuf> {
|
|
||||||
self.0.resolve_scope_paths()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn validate_packet(packet: TaskPacket) -> Result<ValidatedPacket, TaskPacketValidationError> {
|
|
||||||
let mut errors = Vec::new();
|
|
||||||
|
|
||||||
if packet.id.trim().is_empty() {
|
|
||||||
errors.push("packet id must not be empty".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
if packet.objective.trim().is_empty() {
|
|
||||||
errors.push("packet objective must not be empty".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
if packet.repo_config.repo_root.as_os_str().is_empty() {
|
|
||||||
errors.push("repo_config repo_root must not be empty".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
if packet
|
|
||||||
.repo_config
|
|
||||||
.worktree_root
|
|
||||||
.as_ref()
|
|
||||||
.is_some_and(|path| path.as_os_str().is_empty())
|
|
||||||
{
|
|
||||||
errors.push("repo_config worktree_root must not be empty when present".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
validate_scope(&packet.scope, &mut errors);
|
|
||||||
validate_branch_policy(&packet.branch_policy, &mut errors);
|
|
||||||
validate_acceptance_tests(&packet.acceptance_tests, &mut errors);
|
|
||||||
validate_escalation_policy(packet.escalation, &mut errors);
|
|
||||||
|
|
||||||
if errors.is_empty() {
|
|
||||||
Ok(ValidatedPacket(packet))
|
|
||||||
} else {
|
|
||||||
Err(TaskPacketValidationError::new(errors))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn validate_scope(scope: &TaskScope, errors: &mut Vec<String>) {
|
|
||||||
match scope {
|
|
||||||
TaskScope::SingleFile { path } if path.as_os_str().is_empty() => {
|
|
||||||
errors.push("single_file scope path must not be empty".to_string());
|
|
||||||
}
|
|
||||||
TaskScope::Module { crate_name } if crate_name.trim().is_empty() => {
|
|
||||||
errors.push("module scope crate_name must not be empty".to_string());
|
|
||||||
}
|
|
||||||
TaskScope::Custom { paths } if paths.is_empty() => {
|
|
||||||
errors.push("custom scope paths must not be empty".to_string());
|
|
||||||
}
|
|
||||||
TaskScope::Custom { paths } => {
|
|
||||||
for (index, path) in paths.iter().enumerate() {
|
|
||||||
if path.as_os_str().is_empty() {
|
|
||||||
errors.push(format!("custom scope contains empty path at index {index}"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TaskScope::SingleFile { .. } | TaskScope::Module { .. } | TaskScope::Workspace => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn validate_branch_policy(branch_policy: &BranchPolicy, errors: &mut Vec<String>) {
|
|
||||||
match branch_policy {
|
|
||||||
BranchPolicy::CreateNew { prefix } if prefix.trim().is_empty() => {
|
|
||||||
errors.push("create_new branch prefix must not be empty".to_string());
|
|
||||||
}
|
|
||||||
BranchPolicy::UseExisting { name } if name.trim().is_empty() => {
|
|
||||||
errors.push("use_existing branch name must not be empty".to_string());
|
|
||||||
}
|
|
||||||
BranchPolicy::CreateNew { .. }
|
|
||||||
| BranchPolicy::UseExisting { .. }
|
|
||||||
| BranchPolicy::WorktreeIsolated => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn validate_acceptance_tests(tests: &[AcceptanceTest], errors: &mut Vec<String>) {
|
|
||||||
for test in tests {
|
|
||||||
match test {
|
|
||||||
AcceptanceTest::CargoTest { filter } => {
|
|
||||||
if filter
|
|
||||||
.as_deref()
|
|
||||||
.is_some_and(|value| value.trim().is_empty())
|
|
||||||
{
|
|
||||||
errors.push("cargo_test filter must not be empty when present".to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
AcceptanceTest::CustomCommand { cmd } if cmd.trim().is_empty() => {
|
|
||||||
errors.push("custom_command cmd must not be empty".to_string());
|
|
||||||
}
|
|
||||||
AcceptanceTest::CustomCommand { .. } | AcceptanceTest::GreenLevel { .. } => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn validate_escalation_policy(escalation: EscalationPolicy, errors: &mut Vec<String>) {
|
|
||||||
if matches!(
|
|
||||||
escalation,
|
|
||||||
EscalationPolicy::RetryThenEscalate { max_retries: 0 }
|
|
||||||
) {
|
|
||||||
errors.push("retry_then_escalate max_retries must be greater than zero".to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resolve_path(dispatch_root: &Path, path: &Path) -> PathBuf {
|
|
||||||
if path.is_absolute() {
|
|
||||||
path.to_path_buf()
|
|
||||||
} else {
|
|
||||||
dispatch_root.join(path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use serde_json::json;
|
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
|
||||||
|
|
||||||
fn now_secs() -> u64 {
|
|
||||||
SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.unwrap_or_default()
|
|
||||||
.as_secs()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sample_repo_config() -> RepoConfig {
|
|
||||||
RepoConfig {
|
|
||||||
repo_root: PathBuf::from("/repo"),
|
|
||||||
worktree_root: Some(PathBuf::from("/repo/.worktrees/task-1")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sample_packet() -> TaskPacket {
|
|
||||||
let mut metadata = BTreeMap::new();
|
|
||||||
metadata.insert("attempt".to_string(), json!(1));
|
|
||||||
metadata.insert("lane".to_string(), json!("runtime"));
|
|
||||||
|
|
||||||
TaskPacket {
|
|
||||||
id: "packet_001".to_string(),
|
|
||||||
objective: "Implement typed task packet format".to_string(),
|
|
||||||
scope: TaskScope::Module {
|
|
||||||
crate_name: "runtime".to_string(),
|
|
||||||
},
|
|
||||||
repo_config: sample_repo_config(),
|
|
||||||
branch_policy: BranchPolicy::CreateNew {
|
|
||||||
prefix: "ultraclaw".to_string(),
|
|
||||||
},
|
|
||||||
acceptance_tests: vec![
|
|
||||||
AcceptanceTest::CargoTest {
|
|
||||||
filter: Some("task_packet".to_string()),
|
|
||||||
},
|
|
||||||
AcceptanceTest::GreenLevel {
|
|
||||||
level: GreenLevel::Workspace,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
commit_policy: CommitPolicy::CommitPerTask,
|
|
||||||
reporting: ReportingContract::EventStream,
|
|
||||||
escalation: EscalationPolicy::RetryThenEscalate { max_retries: 2 },
|
|
||||||
created_at: now_secs(),
|
|
||||||
metadata,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn valid_packet_passes_validation() {
|
|
||||||
// given
|
|
||||||
let packet = sample_packet();
|
|
||||||
|
|
||||||
// when
|
|
||||||
let validated = validate_packet(packet);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(validated.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn invalid_packet_accumulates_errors() {
|
|
||||||
// given
|
|
||||||
let packet = TaskPacket {
|
|
||||||
id: " ".to_string(),
|
|
||||||
objective: " ".to_string(),
|
|
||||||
scope: TaskScope::Custom {
|
|
||||||
paths: vec![PathBuf::new()],
|
|
||||||
},
|
|
||||||
repo_config: RepoConfig {
|
|
||||||
repo_root: PathBuf::new(),
|
|
||||||
worktree_root: Some(PathBuf::new()),
|
|
||||||
},
|
|
||||||
branch_policy: BranchPolicy::CreateNew {
|
|
||||||
prefix: " ".to_string(),
|
|
||||||
},
|
|
||||||
acceptance_tests: vec![
|
|
||||||
AcceptanceTest::CargoTest {
|
|
||||||
filter: Some(" ".to_string()),
|
|
||||||
},
|
|
||||||
AcceptanceTest::CustomCommand {
|
|
||||||
cmd: " ".to_string(),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
commit_policy: CommitPolicy::NoAutoCommit,
|
|
||||||
reporting: ReportingContract::Silent,
|
|
||||||
escalation: EscalationPolicy::RetryThenEscalate { max_retries: 0 },
|
|
||||||
created_at: 0,
|
|
||||||
metadata: BTreeMap::new(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// when
|
|
||||||
let error = validate_packet(packet).expect_err("packet should be rejected");
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(error.errors().len() >= 8);
|
|
||||||
assert!(error
|
|
||||||
.errors()
|
|
||||||
.contains(&"packet id must not be empty".to_string()));
|
|
||||||
assert!(error
|
|
||||||
.errors()
|
|
||||||
.contains(&"packet objective must not be empty".to_string()));
|
|
||||||
assert!(error
|
|
||||||
.errors()
|
|
||||||
.contains(&"repo_config repo_root must not be empty".to_string()));
|
|
||||||
assert!(error
|
|
||||||
.errors()
|
|
||||||
.contains(&"create_new branch prefix must not be empty".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn single_file_scope_resolves_against_worktree_root() {
|
|
||||||
// given
|
|
||||||
let repo_config = sample_repo_config();
|
|
||||||
let scope = TaskScope::SingleFile {
|
|
||||||
path: PathBuf::from("crates/runtime/src/task_packet.rs"),
|
|
||||||
};
|
|
||||||
|
|
||||||
// when
|
|
||||||
let paths = scope.resolve_paths(&repo_config);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
paths,
|
|
||||||
vec![PathBuf::from(
|
|
||||||
"/repo/.worktrees/task-1/crates/runtime/src/task_packet.rs"
|
|
||||||
)]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn workspace_scope_resolves_to_dispatch_root() {
|
|
||||||
// given
|
|
||||||
let repo_config = sample_repo_config();
|
|
||||||
let scope = TaskScope::Workspace;
|
|
||||||
|
|
||||||
// when
|
|
||||||
let paths = scope.resolve_paths(&repo_config);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(paths, vec![PathBuf::from("/repo/.worktrees/task-1")]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn module_scope_resolves_to_crate_directory() {
|
|
||||||
// given
|
|
||||||
let repo_config = sample_repo_config();
|
|
||||||
let scope = TaskScope::Module {
|
|
||||||
crate_name: "runtime".to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// when
|
|
||||||
let paths = scope.resolve_paths(&repo_config);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
paths,
|
|
||||||
vec![PathBuf::from("/repo/.worktrees/task-1/crates/runtime")]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn custom_scope_preserves_absolute_paths_and_resolves_relative_paths() {
|
|
||||||
// given
|
|
||||||
let repo_config = sample_repo_config();
|
|
||||||
let scope = TaskScope::Custom {
|
|
||||||
paths: vec![
|
|
||||||
PathBuf::from("Cargo.toml"),
|
|
||||||
PathBuf::from("/tmp/shared/script.sh"),
|
|
||||||
],
|
|
||||||
};
|
|
||||||
|
|
||||||
// when
|
|
||||||
let paths = scope.resolve_paths(&repo_config);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
paths,
|
|
||||||
vec![
|
|
||||||
PathBuf::from("/repo/.worktrees/task-1/Cargo.toml"),
|
|
||||||
PathBuf::from("/tmp/shared/script.sh"),
|
|
||||||
]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn serialization_roundtrip_preserves_packet() {
|
|
||||||
// given
|
|
||||||
let packet = sample_packet();
|
|
||||||
|
|
||||||
// when
|
|
||||||
let serialized = serde_json::to_string(&packet).expect("packet should serialize");
|
|
||||||
let deserialized: TaskPacket =
|
|
||||||
serde_json::from_str(&serialized).expect("packet should deserialize");
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(deserialized, packet);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn validated_packet_exposes_inner_packet_and_scope_paths() {
|
|
||||||
// given
|
|
||||||
let packet = sample_packet();
|
|
||||||
|
|
||||||
// when
|
|
||||||
let validated = validate_packet(packet.clone()).expect("packet should validate");
|
|
||||||
let resolved_paths = validated.resolve_scope_paths();
|
|
||||||
let inner = validated.into_inner();
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(
|
|
||||||
resolved_paths,
|
|
||||||
vec![PathBuf::from("/repo/.worktrees/task-1/crates/runtime")]
|
|
||||||
);
|
|
||||||
assert_eq!(inner, packet);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn display_impls_render_snake_case_variants() {
|
|
||||||
// given
|
|
||||||
let rendered = vec![
|
|
||||||
TaskScope::Workspace.to_string(),
|
|
||||||
BranchPolicy::WorktreeIsolated.to_string(),
|
|
||||||
CommitPolicy::SquashOnMerge.to_string(),
|
|
||||||
GreenLevel::MergeReady.to_string(),
|
|
||||||
AcceptanceTest::GreenLevel {
|
|
||||||
level: GreenLevel::Package,
|
|
||||||
}
|
|
||||||
.to_string(),
|
|
||||||
ReportingContract::EventStream.to_string(),
|
|
||||||
EscalationPolicy::AutoEscalate.to_string(),
|
|
||||||
];
|
|
||||||
|
|
||||||
// when
|
|
||||||
let expected = vec![
|
|
||||||
"workspace",
|
|
||||||
"worktree_isolated",
|
|
||||||
"squash_on_merge",
|
|
||||||
"merge_ready",
|
|
||||||
"green_level",
|
|
||||||
"event_stream",
|
|
||||||
"auto_escalate",
|
|
||||||
];
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(rendered, expected);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,299 +0,0 @@
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
|
|
||||||
const TRUST_PROMPT_CUES: &[&str] = &[
|
|
||||||
"do you trust the files in this folder",
|
|
||||||
"trust the files in this folder",
|
|
||||||
"trust this folder",
|
|
||||||
"allow and continue",
|
|
||||||
"yes, proceed",
|
|
||||||
];
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub enum TrustPolicy {
|
|
||||||
AutoTrust,
|
|
||||||
RequireApproval,
|
|
||||||
Deny,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum TrustEvent {
|
|
||||||
TrustRequired { cwd: String },
|
|
||||||
TrustResolved { cwd: String, policy: TrustPolicy },
|
|
||||||
TrustDenied { cwd: String, reason: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
|
||||||
pub struct TrustConfig {
|
|
||||||
allowlisted: Vec<PathBuf>,
|
|
||||||
denied: Vec<PathBuf>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrustConfig {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn with_allowlisted(mut self, path: impl Into<PathBuf>) -> Self {
|
|
||||||
self.allowlisted.push(path.into());
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn with_denied(mut self, path: impl Into<PathBuf>) -> Self {
|
|
||||||
self.denied.push(path.into());
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum TrustDecision {
|
|
||||||
NotRequired,
|
|
||||||
Required {
|
|
||||||
policy: TrustPolicy,
|
|
||||||
events: Vec<TrustEvent>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrustDecision {
|
|
||||||
#[must_use]
|
|
||||||
pub fn policy(&self) -> Option<TrustPolicy> {
|
|
||||||
match self {
|
|
||||||
Self::NotRequired => None,
|
|
||||||
Self::Required { policy, .. } => Some(*policy),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn events(&self) -> &[TrustEvent] {
|
|
||||||
match self {
|
|
||||||
Self::NotRequired => &[],
|
|
||||||
Self::Required { events, .. } => events,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct TrustResolver {
|
|
||||||
config: TrustConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrustResolver {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(config: TrustConfig) -> Self {
|
|
||||||
Self { config }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn resolve(&self, cwd: &str, screen_text: &str) -> TrustDecision {
|
|
||||||
if !detect_trust_prompt(screen_text) {
|
|
||||||
return TrustDecision::NotRequired;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut events = vec![TrustEvent::TrustRequired {
|
|
||||||
cwd: cwd.to_owned(),
|
|
||||||
}];
|
|
||||||
|
|
||||||
if let Some(matched_root) = self
|
|
||||||
.config
|
|
||||||
.denied
|
|
||||||
.iter()
|
|
||||||
.find(|root| path_matches(cwd, root))
|
|
||||||
{
|
|
||||||
let reason = format!("cwd matches denied trust root: {}", matched_root.display());
|
|
||||||
events.push(TrustEvent::TrustDenied {
|
|
||||||
cwd: cwd.to_owned(),
|
|
||||||
reason,
|
|
||||||
});
|
|
||||||
return TrustDecision::Required {
|
|
||||||
policy: TrustPolicy::Deny,
|
|
||||||
events,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if self
|
|
||||||
.config
|
|
||||||
.allowlisted
|
|
||||||
.iter()
|
|
||||||
.any(|root| path_matches(cwd, root))
|
|
||||||
{
|
|
||||||
events.push(TrustEvent::TrustResolved {
|
|
||||||
cwd: cwd.to_owned(),
|
|
||||||
policy: TrustPolicy::AutoTrust,
|
|
||||||
});
|
|
||||||
return TrustDecision::Required {
|
|
||||||
policy: TrustPolicy::AutoTrust,
|
|
||||||
events,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
TrustDecision::Required {
|
|
||||||
policy: TrustPolicy::RequireApproval,
|
|
||||||
events,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn trusts(&self, cwd: &str) -> bool {
|
|
||||||
!self
|
|
||||||
.config
|
|
||||||
.denied
|
|
||||||
.iter()
|
|
||||||
.any(|root| path_matches(cwd, root))
|
|
||||||
&& self
|
|
||||||
.config
|
|
||||||
.allowlisted
|
|
||||||
.iter()
|
|
||||||
.any(|root| path_matches(cwd, root))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn detect_trust_prompt(screen_text: &str) -> bool {
|
|
||||||
let lowered = screen_text.to_ascii_lowercase();
|
|
||||||
TRUST_PROMPT_CUES
|
|
||||||
.iter()
|
|
||||||
.any(|needle| lowered.contains(needle))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn path_matches_trusted_root(cwd: &str, trusted_root: &str) -> bool {
|
|
||||||
path_matches(cwd, &normalize_path(Path::new(trusted_root)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn path_matches(candidate: &str, root: &Path) -> bool {
|
|
||||||
let candidate = normalize_path(Path::new(candidate));
|
|
||||||
let root = normalize_path(root);
|
|
||||||
candidate == root || candidate.starts_with(&root)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize_path(path: &Path) -> PathBuf {
|
|
||||||
std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::{
|
|
||||||
detect_trust_prompt, path_matches_trusted_root, TrustConfig, TrustDecision, TrustEvent,
|
|
||||||
TrustPolicy, TrustResolver,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn detects_known_trust_prompt_copy() {
|
|
||||||
// given
|
|
||||||
let screen_text = "Do you trust the files in this folder?\n1. Yes, proceed\n2. No";
|
|
||||||
|
|
||||||
// when
|
|
||||||
let detected = detect_trust_prompt(screen_text);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(detected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn does_not_emit_events_when_prompt_is_absent() {
|
|
||||||
// given
|
|
||||||
let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees"));
|
|
||||||
|
|
||||||
// when
|
|
||||||
let decision = resolver.resolve("/tmp/worktrees/repo-a", "Ready for your input\n>");
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(decision, TrustDecision::NotRequired);
|
|
||||||
assert_eq!(decision.events(), &[]);
|
|
||||||
assert_eq!(decision.policy(), None);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn auto_trusts_allowlisted_cwd_after_prompt_detection() {
|
|
||||||
// given
|
|
||||||
let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees"));
|
|
||||||
|
|
||||||
// when
|
|
||||||
let decision = resolver.resolve(
|
|
||||||
"/tmp/worktrees/repo-a",
|
|
||||||
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
|
|
||||||
);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(decision.policy(), Some(TrustPolicy::AutoTrust));
|
|
||||||
assert_eq!(
|
|
||||||
decision.events(),
|
|
||||||
&[
|
|
||||||
TrustEvent::TrustRequired {
|
|
||||||
cwd: "/tmp/worktrees/repo-a".to_string(),
|
|
||||||
},
|
|
||||||
TrustEvent::TrustResolved {
|
|
||||||
cwd: "/tmp/worktrees/repo-a".to_string(),
|
|
||||||
policy: TrustPolicy::AutoTrust,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn requires_approval_for_unknown_cwd_after_prompt_detection() {
|
|
||||||
// given
|
|
||||||
let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees"));
|
|
||||||
|
|
||||||
// when
|
|
||||||
let decision = resolver.resolve(
|
|
||||||
"/tmp/other/repo-b",
|
|
||||||
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
|
|
||||||
);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(decision.policy(), Some(TrustPolicy::RequireApproval));
|
|
||||||
assert_eq!(
|
|
||||||
decision.events(),
|
|
||||||
&[TrustEvent::TrustRequired {
|
|
||||||
cwd: "/tmp/other/repo-b".to_string(),
|
|
||||||
}]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn denied_root_takes_precedence_over_allowlist() {
|
|
||||||
// given
|
|
||||||
let resolver = TrustResolver::new(
|
|
||||||
TrustConfig::new()
|
|
||||||
.with_allowlisted("/tmp/worktrees")
|
|
||||||
.with_denied("/tmp/worktrees/repo-c"),
|
|
||||||
);
|
|
||||||
|
|
||||||
// when
|
|
||||||
let decision = resolver.resolve(
|
|
||||||
"/tmp/worktrees/repo-c",
|
|
||||||
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
|
|
||||||
);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert_eq!(decision.policy(), Some(TrustPolicy::Deny));
|
|
||||||
assert_eq!(
|
|
||||||
decision.events(),
|
|
||||||
&[
|
|
||||||
TrustEvent::TrustRequired {
|
|
||||||
cwd: "/tmp/worktrees/repo-c".to_string(),
|
|
||||||
},
|
|
||||||
TrustEvent::TrustDenied {
|
|
||||||
cwd: "/tmp/worktrees/repo-c".to_string(),
|
|
||||||
reason: "cwd matches denied trust root: /tmp/worktrees/repo-c".to_string(),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn sibling_prefix_does_not_match_trusted_root() {
|
|
||||||
// given
|
|
||||||
let trusted_root = "/tmp/worktrees";
|
|
||||||
let sibling_path = "/tmp/worktrees-other/repo-d";
|
|
||||||
|
|
||||||
// when
|
|
||||||
let matched = path_matches_trusted_root(sibling_path, trusted_root);
|
|
||||||
|
|
||||||
// then
|
|
||||||
assert!(!matched);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,732 +0,0 @@
|
||||||
//! In-memory worker-boot state machine and control registry.
|
|
||||||
//!
|
|
||||||
//! This provides a foundational control plane for reliable worker startup:
|
|
||||||
//! trust-gate detection, ready-for-prompt handshakes, and prompt-misdelivery
|
|
||||||
//! detection/recovery all live above raw terminal transport.
|
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
fn now_secs() -> u64 {
|
|
||||||
SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.unwrap_or_default()
|
|
||||||
.as_secs()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum WorkerStatus {
|
|
||||||
Spawning,
|
|
||||||
TrustRequired,
|
|
||||||
ReadyForPrompt,
|
|
||||||
PromptAccepted,
|
|
||||||
Running,
|
|
||||||
Blocked,
|
|
||||||
Finished,
|
|
||||||
Failed,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for WorkerStatus {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::Spawning => write!(f, "spawning"),
|
|
||||||
Self::TrustRequired => write!(f, "trust_required"),
|
|
||||||
Self::ReadyForPrompt => write!(f, "ready_for_prompt"),
|
|
||||||
Self::PromptAccepted => write!(f, "prompt_accepted"),
|
|
||||||
Self::Running => write!(f, "running"),
|
|
||||||
Self::Blocked => write!(f, "blocked"),
|
|
||||||
Self::Finished => write!(f, "finished"),
|
|
||||||
Self::Failed => write!(f, "failed"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum WorkerFailureKind {
|
|
||||||
TrustGate,
|
|
||||||
PromptDelivery,
|
|
||||||
Protocol,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub struct WorkerFailure {
|
|
||||||
pub kind: WorkerFailureKind,
|
|
||||||
pub message: String,
|
|
||||||
pub created_at: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum WorkerEventKind {
|
|
||||||
Spawning,
|
|
||||||
TrustRequired,
|
|
||||||
TrustResolved,
|
|
||||||
ReadyForPrompt,
|
|
||||||
PromptAccepted,
|
|
||||||
PromptMisdelivery,
|
|
||||||
PromptReplayArmed,
|
|
||||||
Running,
|
|
||||||
Restarted,
|
|
||||||
Finished,
|
|
||||||
Failed,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub struct WorkerEvent {
|
|
||||||
pub seq: u64,
|
|
||||||
pub kind: WorkerEventKind,
|
|
||||||
pub status: WorkerStatus,
|
|
||||||
pub detail: Option<String>,
|
|
||||||
pub timestamp: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub struct Worker {
|
|
||||||
pub worker_id: String,
|
|
||||||
pub cwd: String,
|
|
||||||
pub status: WorkerStatus,
|
|
||||||
pub trust_auto_resolve: bool,
|
|
||||||
pub trust_gate_cleared: bool,
|
|
||||||
pub auto_recover_prompt_misdelivery: bool,
|
|
||||||
pub prompt_delivery_attempts: u32,
|
|
||||||
pub last_prompt: Option<String>,
|
|
||||||
pub replay_prompt: Option<String>,
|
|
||||||
pub last_error: Option<WorkerFailure>,
|
|
||||||
pub created_at: u64,
|
|
||||||
pub updated_at: u64,
|
|
||||||
pub events: Vec<WorkerEvent>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
|
||||||
pub struct WorkerRegistry {
|
|
||||||
inner: Arc<Mutex<WorkerRegistryInner>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
|
||||||
struct WorkerRegistryInner {
|
|
||||||
workers: HashMap<String, Worker>,
|
|
||||||
counter: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WorkerRegistry {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn create(
|
|
||||||
&self,
|
|
||||||
cwd: &str,
|
|
||||||
trusted_roots: &[String],
|
|
||||||
auto_recover_prompt_misdelivery: bool,
|
|
||||||
) -> Worker {
|
|
||||||
let mut inner = self.inner.lock().expect("worker registry lock poisoned");
|
|
||||||
inner.counter += 1;
|
|
||||||
let ts = now_secs();
|
|
||||||
let worker_id = format!("worker_{:08x}_{}", ts, inner.counter);
|
|
||||||
let trust_auto_resolve = trusted_roots
|
|
||||||
.iter()
|
|
||||||
.any(|root| path_matches_allowlist(cwd, root));
|
|
||||||
let mut worker = Worker {
|
|
||||||
worker_id: worker_id.clone(),
|
|
||||||
cwd: cwd.to_owned(),
|
|
||||||
status: WorkerStatus::Spawning,
|
|
||||||
trust_auto_resolve,
|
|
||||||
trust_gate_cleared: false,
|
|
||||||
auto_recover_prompt_misdelivery,
|
|
||||||
prompt_delivery_attempts: 0,
|
|
||||||
last_prompt: None,
|
|
||||||
replay_prompt: None,
|
|
||||||
last_error: None,
|
|
||||||
created_at: ts,
|
|
||||||
updated_at: ts,
|
|
||||||
events: Vec::new(),
|
|
||||||
};
|
|
||||||
push_event(
|
|
||||||
&mut worker,
|
|
||||||
WorkerEventKind::Spawning,
|
|
||||||
WorkerStatus::Spawning,
|
|
||||||
Some("worker created".to_string()),
|
|
||||||
);
|
|
||||||
inner.workers.insert(worker_id, worker.clone());
|
|
||||||
worker
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn get(&self, worker_id: &str) -> Option<Worker> {
|
|
||||||
let inner = self.inner.lock().expect("worker registry lock poisoned");
|
|
||||||
inner.workers.get(worker_id).cloned()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn observe(&self, worker_id: &str, screen_text: &str) -> Result<Worker, String> {
|
|
||||||
let mut inner = self.inner.lock().expect("worker registry lock poisoned");
|
|
||||||
let worker = inner
|
|
||||||
.workers
|
|
||||||
.get_mut(worker_id)
|
|
||||||
.ok_or_else(|| format!("worker not found: {worker_id}"))?;
|
|
||||||
let lowered = screen_text.to_ascii_lowercase();
|
|
||||||
|
|
||||||
if !worker.trust_gate_cleared && detect_trust_prompt(&lowered) {
|
|
||||||
worker.status = WorkerStatus::TrustRequired;
|
|
||||||
worker.last_error = Some(WorkerFailure {
|
|
||||||
kind: WorkerFailureKind::TrustGate,
|
|
||||||
message: "worker boot blocked on trust prompt".to_string(),
|
|
||||||
created_at: now_secs(),
|
|
||||||
});
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::TrustRequired,
|
|
||||||
WorkerStatus::TrustRequired,
|
|
||||||
Some("trust prompt detected".to_string()),
|
|
||||||
);
|
|
||||||
|
|
||||||
if worker.trust_auto_resolve {
|
|
||||||
worker.trust_gate_cleared = true;
|
|
||||||
worker.last_error = None;
|
|
||||||
worker.status = WorkerStatus::Spawning;
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::TrustResolved,
|
|
||||||
WorkerStatus::Spawning,
|
|
||||||
Some("allowlisted repo auto-resolved trust prompt".to_string()),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
return Ok(worker.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if prompt_misdelivery_is_relevant(worker)
|
|
||||||
&& detect_prompt_misdelivery(&lowered, worker.last_prompt.as_deref())
|
|
||||||
{
|
|
||||||
let detail = prompt_preview(worker.last_prompt.as_deref().unwrap_or_default());
|
|
||||||
worker.last_error = Some(WorkerFailure {
|
|
||||||
kind: WorkerFailureKind::PromptDelivery,
|
|
||||||
message: format!("worker prompt landed in shell instead of coding agent: {detail}"),
|
|
||||||
created_at: now_secs(),
|
|
||||||
});
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::PromptMisdelivery,
|
|
||||||
WorkerStatus::Blocked,
|
|
||||||
Some("shell misdelivery detected".to_string()),
|
|
||||||
);
|
|
||||||
if worker.auto_recover_prompt_misdelivery {
|
|
||||||
worker.replay_prompt = worker.last_prompt.clone();
|
|
||||||
worker.status = WorkerStatus::ReadyForPrompt;
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::PromptReplayArmed,
|
|
||||||
WorkerStatus::ReadyForPrompt,
|
|
||||||
Some("prompt replay armed after shell misdelivery".to_string()),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
worker.status = WorkerStatus::Blocked;
|
|
||||||
}
|
|
||||||
return Ok(worker.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
if detect_running_cue(&lowered)
|
|
||||||
&& matches!(
|
|
||||||
worker.status,
|
|
||||||
WorkerStatus::PromptAccepted | WorkerStatus::ReadyForPrompt
|
|
||||||
)
|
|
||||||
{
|
|
||||||
worker.status = WorkerStatus::Running;
|
|
||||||
worker.last_error = None;
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::Running,
|
|
||||||
WorkerStatus::Running,
|
|
||||||
Some("worker accepted prompt and started running".to_string()),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if detect_ready_for_prompt(screen_text, &lowered)
|
|
||||||
&& !matches!(
|
|
||||||
worker.status,
|
|
||||||
WorkerStatus::ReadyForPrompt | WorkerStatus::Running
|
|
||||||
)
|
|
||||||
{
|
|
||||||
worker.status = WorkerStatus::ReadyForPrompt;
|
|
||||||
if matches!(
|
|
||||||
worker.last_error.as_ref().map(|failure| failure.kind),
|
|
||||||
Some(WorkerFailureKind::TrustGate)
|
|
||||||
) {
|
|
||||||
worker.last_error = None;
|
|
||||||
}
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::ReadyForPrompt,
|
|
||||||
WorkerStatus::ReadyForPrompt,
|
|
||||||
Some("worker is ready for prompt delivery".to_string()),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(worker.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn resolve_trust(&self, worker_id: &str) -> Result<Worker, String> {
|
|
||||||
let mut inner = self.inner.lock().expect("worker registry lock poisoned");
|
|
||||||
let worker = inner
|
|
||||||
.workers
|
|
||||||
.get_mut(worker_id)
|
|
||||||
.ok_or_else(|| format!("worker not found: {worker_id}"))?;
|
|
||||||
|
|
||||||
if worker.status != WorkerStatus::TrustRequired {
|
|
||||||
return Err(format!(
|
|
||||||
"worker {worker_id} is not waiting on trust; current status: {}",
|
|
||||||
worker.status
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
worker.trust_gate_cleared = true;
|
|
||||||
worker.last_error = None;
|
|
||||||
worker.status = WorkerStatus::Spawning;
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::TrustResolved,
|
|
||||||
WorkerStatus::Spawning,
|
|
||||||
Some("trust prompt resolved manually".to_string()),
|
|
||||||
);
|
|
||||||
Ok(worker.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_prompt(&self, worker_id: &str, prompt: Option<&str>) -> Result<Worker, String> {
|
|
||||||
let mut inner = self.inner.lock().expect("worker registry lock poisoned");
|
|
||||||
let worker = inner
|
|
||||||
.workers
|
|
||||||
.get_mut(worker_id)
|
|
||||||
.ok_or_else(|| format!("worker not found: {worker_id}"))?;
|
|
||||||
|
|
||||||
if worker.status != WorkerStatus::ReadyForPrompt {
|
|
||||||
return Err(format!(
|
|
||||||
"worker {worker_id} is not ready for prompt delivery; current status: {}",
|
|
||||||
worker.status
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let next_prompt = prompt
|
|
||||||
.map(str::trim)
|
|
||||||
.filter(|value| !value.is_empty())
|
|
||||||
.map(str::to_owned)
|
|
||||||
.or_else(|| worker.replay_prompt.clone())
|
|
||||||
.ok_or_else(|| format!("worker {worker_id} has no prompt to send or replay"))?;
|
|
||||||
|
|
||||||
worker.prompt_delivery_attempts += 1;
|
|
||||||
worker.last_prompt = Some(next_prompt.clone());
|
|
||||||
worker.replay_prompt = None;
|
|
||||||
worker.last_error = None;
|
|
||||||
worker.status = WorkerStatus::PromptAccepted;
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::PromptAccepted,
|
|
||||||
WorkerStatus::PromptAccepted,
|
|
||||||
Some(format!(
|
|
||||||
"prompt accepted for delivery: {}",
|
|
||||||
prompt_preview(&next_prompt)
|
|
||||||
)),
|
|
||||||
);
|
|
||||||
Ok(worker.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn await_ready(&self, worker_id: &str) -> Result<WorkerReadySnapshot, String> {
|
|
||||||
let worker = self
|
|
||||||
.get(worker_id)
|
|
||||||
.ok_or_else(|| format!("worker not found: {worker_id}"))?;
|
|
||||||
|
|
||||||
Ok(WorkerReadySnapshot {
|
|
||||||
worker_id: worker.worker_id.clone(),
|
|
||||||
status: worker.status,
|
|
||||||
ready: worker.status == WorkerStatus::ReadyForPrompt,
|
|
||||||
blocked: matches!(
|
|
||||||
worker.status,
|
|
||||||
WorkerStatus::TrustRequired | WorkerStatus::Blocked
|
|
||||||
),
|
|
||||||
replay_prompt_ready: worker.replay_prompt.is_some(),
|
|
||||||
last_error: worker.last_error.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn restart(&self, worker_id: &str) -> Result<Worker, String> {
|
|
||||||
let mut inner = self.inner.lock().expect("worker registry lock poisoned");
|
|
||||||
let worker = inner
|
|
||||||
.workers
|
|
||||||
.get_mut(worker_id)
|
|
||||||
.ok_or_else(|| format!("worker not found: {worker_id}"))?;
|
|
||||||
worker.status = WorkerStatus::Spawning;
|
|
||||||
worker.trust_gate_cleared = false;
|
|
||||||
worker.last_prompt = None;
|
|
||||||
worker.replay_prompt = None;
|
|
||||||
worker.last_error = None;
|
|
||||||
worker.prompt_delivery_attempts = 0;
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::Restarted,
|
|
||||||
WorkerStatus::Spawning,
|
|
||||||
Some("worker restarted".to_string()),
|
|
||||||
);
|
|
||||||
Ok(worker.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn terminate(&self, worker_id: &str) -> Result<Worker, String> {
|
|
||||||
let mut inner = self.inner.lock().expect("worker registry lock poisoned");
|
|
||||||
let worker = inner
|
|
||||||
.workers
|
|
||||||
.get_mut(worker_id)
|
|
||||||
.ok_or_else(|| format!("worker not found: {worker_id}"))?;
|
|
||||||
worker.status = WorkerStatus::Finished;
|
|
||||||
push_event(
|
|
||||||
worker,
|
|
||||||
WorkerEventKind::Finished,
|
|
||||||
WorkerStatus::Finished,
|
|
||||||
Some("worker terminated by control plane".to_string()),
|
|
||||||
);
|
|
||||||
Ok(worker.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub struct WorkerReadySnapshot {
|
|
||||||
pub worker_id: String,
|
|
||||||
pub status: WorkerStatus,
|
|
||||||
pub ready: bool,
|
|
||||||
pub blocked: bool,
|
|
||||||
pub replay_prompt_ready: bool,
|
|
||||||
pub last_error: Option<WorkerFailure>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prompt_misdelivery_is_relevant(worker: &Worker) -> bool {
|
|
||||||
matches!(
|
|
||||||
worker.status,
|
|
||||||
WorkerStatus::PromptAccepted | WorkerStatus::Running
|
|
||||||
) && worker.last_prompt.is_some()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn push_event(
|
|
||||||
worker: &mut Worker,
|
|
||||||
kind: WorkerEventKind,
|
|
||||||
status: WorkerStatus,
|
|
||||||
detail: Option<String>,
|
|
||||||
) {
|
|
||||||
let timestamp = now_secs();
|
|
||||||
let seq = worker.events.len() as u64 + 1;
|
|
||||||
worker.updated_at = timestamp;
|
|
||||||
worker.events.push(WorkerEvent {
|
|
||||||
seq,
|
|
||||||
kind,
|
|
||||||
status,
|
|
||||||
detail,
|
|
||||||
timestamp,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
fn path_matches_allowlist(cwd: &str, trusted_root: &str) -> bool {
|
|
||||||
let cwd = normalize_path(cwd);
|
|
||||||
let trusted_root = normalize_path(trusted_root);
|
|
||||||
cwd == trusted_root || cwd.starts_with(&trusted_root)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize_path(path: &str) -> PathBuf {
|
|
||||||
std::fs::canonicalize(path).unwrap_or_else(|_| Path::new(path).to_path_buf())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn detect_trust_prompt(lowered: &str) -> bool {
|
|
||||||
[
|
|
||||||
"do you trust the files in this folder",
|
|
||||||
"trust the files in this folder",
|
|
||||||
"trust this folder",
|
|
||||||
"allow and continue",
|
|
||||||
"yes, proceed",
|
|
||||||
]
|
|
||||||
.iter()
|
|
||||||
.any(|needle| lowered.contains(needle))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn detect_ready_for_prompt(screen_text: &str, lowered: &str) -> bool {
|
|
||||||
if [
|
|
||||||
"ready for input",
|
|
||||||
"ready for your input",
|
|
||||||
"ready for prompt",
|
|
||||||
"send a message",
|
|
||||||
]
|
|
||||||
.iter()
|
|
||||||
.any(|needle| lowered.contains(needle))
|
|
||||||
{
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
let Some(last_non_empty) = screen_text
|
|
||||||
.lines()
|
|
||||||
.rev()
|
|
||||||
.find(|line| !line.trim().is_empty())
|
|
||||||
else {
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
let trimmed = last_non_empty.trim();
|
|
||||||
if is_shell_prompt(trimmed) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
trimmed == ">"
|
|
||||||
|| trimmed == "›"
|
|
||||||
|| trimmed == "❯"
|
|
||||||
|| trimmed.starts_with("> ")
|
|
||||||
|| trimmed.starts_with("› ")
|
|
||||||
|| trimmed.starts_with("❯ ")
|
|
||||||
|| trimmed.contains("│ >")
|
|
||||||
|| trimmed.contains("│ ›")
|
|
||||||
|| trimmed.contains("│ ❯")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn detect_running_cue(lowered: &str) -> bool {
|
|
||||||
[
|
|
||||||
"thinking",
|
|
||||||
"working",
|
|
||||||
"running tests",
|
|
||||||
"inspecting",
|
|
||||||
"analyzing",
|
|
||||||
]
|
|
||||||
.iter()
|
|
||||||
.any(|needle| lowered.contains(needle))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_shell_prompt(trimmed: &str) -> bool {
|
|
||||||
trimmed.ends_with('$')
|
|
||||||
|| trimmed.ends_with('%')
|
|
||||||
|| trimmed.ends_with('#')
|
|
||||||
|| trimmed.starts_with('$')
|
|
||||||
|| trimmed.starts_with('%')
|
|
||||||
|| trimmed.starts_with('#')
|
|
||||||
}
|
|
||||||
|
|
||||||
fn detect_prompt_misdelivery(lowered: &str, prompt: Option<&str>) -> bool {
|
|
||||||
let Some(prompt) = prompt else {
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
let shell_error = [
|
|
||||||
"command not found",
|
|
||||||
"syntax error near unexpected token",
|
|
||||||
"parse error near",
|
|
||||||
"no such file or directory",
|
|
||||||
"unknown command",
|
|
||||||
]
|
|
||||||
.iter()
|
|
||||||
.any(|needle| lowered.contains(needle));
|
|
||||||
|
|
||||||
if !shell_error {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let first_prompt_line = prompt
|
|
||||||
.lines()
|
|
||||||
.find(|line| !line.trim().is_empty())
|
|
||||||
.map(|line| line.trim().to_ascii_lowercase())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
first_prompt_line.is_empty() || lowered.contains(&first_prompt_line)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prompt_preview(prompt: &str) -> String {
|
|
||||||
let trimmed = prompt.trim();
|
|
||||||
if trimmed.chars().count() <= 48 {
|
|
||||||
return trimmed.to_string();
|
|
||||||
}
|
|
||||||
let preview = trimmed.chars().take(48).collect::<String>();
|
|
||||||
format!("{}…", preview.trim_end())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn allowlisted_trust_prompt_auto_resolves_then_reaches_ready_state() {
|
|
||||||
let registry = WorkerRegistry::new();
|
|
||||||
let worker = registry.create(
|
|
||||||
"/tmp/worktrees/repo-a",
|
|
||||||
&["/tmp/worktrees".to_string()],
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
|
|
||||||
let after_trust = registry
|
|
||||||
.observe(
|
|
||||||
&worker.worker_id,
|
|
||||||
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
|
|
||||||
)
|
|
||||||
.expect("trust observe should succeed");
|
|
||||||
assert_eq!(after_trust.status, WorkerStatus::Spawning);
|
|
||||||
assert!(after_trust.trust_gate_cleared);
|
|
||||||
assert!(after_trust
|
|
||||||
.events
|
|
||||||
.iter()
|
|
||||||
.any(|event| event.kind == WorkerEventKind::TrustRequired));
|
|
||||||
assert!(after_trust
|
|
||||||
.events
|
|
||||||
.iter()
|
|
||||||
.any(|event| event.kind == WorkerEventKind::TrustResolved));
|
|
||||||
|
|
||||||
let ready = registry
|
|
||||||
.observe(&worker.worker_id, "Ready for your input\n>")
|
|
||||||
.expect("ready observe should succeed");
|
|
||||||
assert_eq!(ready.status, WorkerStatus::ReadyForPrompt);
|
|
||||||
assert!(ready.last_error.is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn trust_prompt_blocks_non_allowlisted_worker_until_resolved() {
|
|
||||||
let registry = WorkerRegistry::new();
|
|
||||||
let worker = registry.create("/tmp/repo-b", &[], true);
|
|
||||||
|
|
||||||
let blocked = registry
|
|
||||||
.observe(
|
|
||||||
&worker.worker_id,
|
|
||||||
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
|
|
||||||
)
|
|
||||||
.expect("trust observe should succeed");
|
|
||||||
assert_eq!(blocked.status, WorkerStatus::TrustRequired);
|
|
||||||
assert_eq!(
|
|
||||||
blocked.last_error.expect("trust error should exist").kind,
|
|
||||||
WorkerFailureKind::TrustGate
|
|
||||||
);
|
|
||||||
|
|
||||||
let send_before_resolve = registry.send_prompt(&worker.worker_id, Some("ship it"));
|
|
||||||
assert!(send_before_resolve
|
|
||||||
.expect_err("prompt delivery should be gated")
|
|
||||||
.contains("not ready for prompt delivery"));
|
|
||||||
|
|
||||||
let resolved = registry
|
|
||||||
.resolve_trust(&worker.worker_id)
|
|
||||||
.expect("manual trust resolution should succeed");
|
|
||||||
assert_eq!(resolved.status, WorkerStatus::Spawning);
|
|
||||||
assert!(resolved.trust_gate_cleared);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn ready_detection_ignores_plain_shell_prompts() {
|
|
||||||
assert!(!detect_ready_for_prompt("bellman@host %", "bellman@host %"));
|
|
||||||
assert!(!detect_ready_for_prompt("/tmp/repo $", "/tmp/repo $"));
|
|
||||||
assert!(detect_ready_for_prompt("│ >", "│ >"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn prompt_misdelivery_is_detected_and_replay_can_be_rearmed() {
|
|
||||||
let registry = WorkerRegistry::new();
|
|
||||||
let worker = registry.create("/tmp/repo-c", &[], true);
|
|
||||||
registry
|
|
||||||
.observe(&worker.worker_id, "Ready for input\n>")
|
|
||||||
.expect("ready observe should succeed");
|
|
||||||
|
|
||||||
let accepted = registry
|
|
||||||
.send_prompt(&worker.worker_id, Some("Implement worker handshake"))
|
|
||||||
.expect("prompt send should succeed");
|
|
||||||
assert_eq!(accepted.status, WorkerStatus::PromptAccepted);
|
|
||||||
assert_eq!(accepted.prompt_delivery_attempts, 1);
|
|
||||||
|
|
||||||
let recovered = registry
|
|
||||||
.observe(
|
|
||||||
&worker.worker_id,
|
|
||||||
"% Implement worker handshake\nzsh: command not found: Implement",
|
|
||||||
)
|
|
||||||
.expect("misdelivery observe should succeed");
|
|
||||||
assert_eq!(recovered.status, WorkerStatus::ReadyForPrompt);
|
|
||||||
assert_eq!(
|
|
||||||
recovered
|
|
||||||
.last_error
|
|
||||||
.expect("misdelivery error should exist")
|
|
||||||
.kind,
|
|
||||||
WorkerFailureKind::PromptDelivery
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
recovered.replay_prompt.as_deref(),
|
|
||||||
Some("Implement worker handshake")
|
|
||||||
);
|
|
||||||
assert!(recovered
|
|
||||||
.events
|
|
||||||
.iter()
|
|
||||||
.any(|event| event.kind == WorkerEventKind::PromptMisdelivery));
|
|
||||||
assert!(recovered
|
|
||||||
.events
|
|
||||||
.iter()
|
|
||||||
.any(|event| event.kind == WorkerEventKind::PromptReplayArmed));
|
|
||||||
|
|
||||||
let replayed = registry
|
|
||||||
.send_prompt(&worker.worker_id, None)
|
|
||||||
.expect("replay send should succeed");
|
|
||||||
assert_eq!(replayed.status, WorkerStatus::PromptAccepted);
|
|
||||||
assert!(replayed.replay_prompt.is_none());
|
|
||||||
assert_eq!(replayed.prompt_delivery_attempts, 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn await_ready_surfaces_blocked_or_ready_worker_state() {
|
|
||||||
let registry = WorkerRegistry::new();
|
|
||||||
let worker = registry.create("/tmp/repo-d", &[], false);
|
|
||||||
|
|
||||||
let initial = registry
|
|
||||||
.await_ready(&worker.worker_id)
|
|
||||||
.expect("await should succeed");
|
|
||||||
assert!(!initial.ready);
|
|
||||||
assert!(!initial.blocked);
|
|
||||||
|
|
||||||
registry
|
|
||||||
.observe(
|
|
||||||
&worker.worker_id,
|
|
||||||
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
|
|
||||||
)
|
|
||||||
.expect("trust observe should succeed");
|
|
||||||
let blocked = registry
|
|
||||||
.await_ready(&worker.worker_id)
|
|
||||||
.expect("await should succeed");
|
|
||||||
assert!(!blocked.ready);
|
|
||||||
assert!(blocked.blocked);
|
|
||||||
|
|
||||||
registry
|
|
||||||
.resolve_trust(&worker.worker_id)
|
|
||||||
.expect("manual trust resolution should succeed");
|
|
||||||
registry
|
|
||||||
.observe(&worker.worker_id, "Ready for your input\n>")
|
|
||||||
.expect("ready observe should succeed");
|
|
||||||
let ready = registry
|
|
||||||
.await_ready(&worker.worker_id)
|
|
||||||
.expect("await should succeed");
|
|
||||||
assert!(ready.ready);
|
|
||||||
assert!(!ready.blocked);
|
|
||||||
assert!(ready.last_error.is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn restart_and_terminate_reset_or_finish_worker() {
|
|
||||||
let registry = WorkerRegistry::new();
|
|
||||||
let worker = registry.create("/tmp/repo-e", &[], true);
|
|
||||||
registry
|
|
||||||
.observe(&worker.worker_id, "Ready for input\n>")
|
|
||||||
.expect("ready observe should succeed");
|
|
||||||
registry
|
|
||||||
.send_prompt(&worker.worker_id, Some("Run tests"))
|
|
||||||
.expect("prompt send should succeed");
|
|
||||||
|
|
||||||
let restarted = registry
|
|
||||||
.restart(&worker.worker_id)
|
|
||||||
.expect("restart should succeed");
|
|
||||||
assert_eq!(restarted.status, WorkerStatus::Spawning);
|
|
||||||
assert_eq!(restarted.prompt_delivery_attempts, 0);
|
|
||||||
assert!(restarted.last_prompt.is_none());
|
|
||||||
|
|
||||||
let finished = registry
|
|
||||||
.terminate(&worker.worker_id)
|
|
||||||
.expect("terminate should succeed");
|
|
||||||
assert_eq!(finished.status, WorkerStatus::Finished);
|
|
||||||
assert!(finished
|
|
||||||
.events
|
|
||||||
.iter()
|
|
||||||
.any(|event| event.kind == WorkerEventKind::Finished));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
{"created_at_ms":1775230717464,"session_id":"session-1775230717464-3","type":"session_meta","updated_at_ms":1775230717464,"version":1}
|
|
||||||
|
|
@ -18,7 +18,6 @@ pulldown-cmark = "0.13"
|
||||||
rustyline = "15"
|
rustyline = "15"
|
||||||
runtime = { path = "../runtime" }
|
runtime = { path = "../runtime" }
|
||||||
plugins = { path = "../plugins" }
|
plugins = { path = "../plugins" }
|
||||||
serde = { version = "1", features = ["derive"] }
|
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
syntect = "5"
|
syntect = "5"
|
||||||
tokio = { version = "1", features = ["rt-multi-thread", "signal", "time"] }
|
tokio = { version = "1", features = ["rt-multi-thread", "signal", "time"] }
|
||||||
|
|
|
||||||
|
|
@ -39,18 +39,17 @@ use init::initialize_repo;
|
||||||
use plugins::{PluginHooks, PluginManager, PluginManagerConfig, PluginRegistry};
|
use plugins::{PluginHooks, PluginManager, PluginManagerConfig, PluginRegistry};
|
||||||
use render::{MarkdownStreamState, Spinner, TerminalRenderer};
|
use render::{MarkdownStreamState, Spinner, TerminalRenderer};
|
||||||
use runtime::{
|
use runtime::{
|
||||||
clear_oauth_credentials, format_usd, generate_pkce_pair, generate_state, load_system_prompt,
|
clear_oauth_credentials, generate_pkce_pair, generate_state, load_system_prompt,
|
||||||
parse_oauth_callback_request_target, pricing_for_model, resolve_sandbox_status,
|
parse_oauth_callback_request_target, resolve_sandbox_status, save_oauth_credentials,
|
||||||
save_oauth_credentials, ApiClient, ApiRequest, AssistantEvent, CompactionConfig, ConfigLoader,
|
ApiClient, ApiRequest, AssistantEvent,
|
||||||
ConfigSource, ContentBlock, ConversationMessage, ConversationRuntime, McpServerManager,
|
CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, ConversationMessage,
|
||||||
McpTool, MessageRole, ModelPricing, OAuthAuthorizationRequest, OAuthConfig,
|
ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig,
|
||||||
OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent,
|
OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent,
|
||||||
ResolvedPermissionMode, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor,
|
ResolvedPermissionMode, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor,
|
||||||
UsageTracker,
|
UsageTracker, ModelPricing, format_usd, pricing_for_model,
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tools::{GlobalToolRegistry, RuntimeToolDefinition, ToolSearchOutput};
|
use tools::GlobalToolRegistry;
|
||||||
|
|
||||||
const DEFAULT_MODEL: &str = "claude-opus-4-6";
|
const DEFAULT_MODEL: &str = "claude-opus-4-6";
|
||||||
fn max_tokens_for_model(model: &str) -> u32 {
|
fn max_tokens_for_model(model: &str) -> u32 {
|
||||||
|
|
@ -595,17 +594,11 @@ fn current_tool_registry() -> Result<GlobalToolRegistry, String> {
|
||||||
let cwd = env::current_dir().map_err(|error| error.to_string())?;
|
let cwd = env::current_dir().map_err(|error| error.to_string())?;
|
||||||
let loader = ConfigLoader::default_for(&cwd);
|
let loader = ConfigLoader::default_for(&cwd);
|
||||||
let runtime_config = loader.load().map_err(|error| error.to_string())?;
|
let runtime_config = loader.load().map_err(|error| error.to_string())?;
|
||||||
let state = build_runtime_plugin_state_with_loader(&cwd, &loader, &runtime_config)
|
let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config);
|
||||||
|
let plugin_tools = plugin_manager
|
||||||
|
.aggregated_tools()
|
||||||
.map_err(|error| error.to_string())?;
|
.map_err(|error| error.to_string())?;
|
||||||
let registry = state.tool_registry.clone();
|
GlobalToolRegistry::with_plugin_tools(plugin_tools)
|
||||||
if let Some(mcp_state) = state.mcp_state {
|
|
||||||
mcp_state
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.shutdown()
|
|
||||||
.map_err(|error| error.to_string())?;
|
|
||||||
}
|
|
||||||
Ok(registry)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_permission_mode_arg(value: &str) -> Result<PermissionMode, String> {
|
fn parse_permission_mode_arg(value: &str) -> Result<PermissionMode, String> {
|
||||||
|
|
@ -1594,35 +1587,23 @@ struct RuntimePluginState {
|
||||||
feature_config: runtime::RuntimeFeatureConfig,
|
feature_config: runtime::RuntimeFeatureConfig,
|
||||||
tool_registry: GlobalToolRegistry,
|
tool_registry: GlobalToolRegistry,
|
||||||
plugin_registry: PluginRegistry,
|
plugin_registry: PluginRegistry,
|
||||||
mcp_state: Option<Arc<Mutex<RuntimeMcpState>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct RuntimeMcpState {
|
|
||||||
runtime: tokio::runtime::Runtime,
|
|
||||||
manager: McpServerManager,
|
|
||||||
pending_servers: Vec<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct BuiltRuntime {
|
struct BuiltRuntime {
|
||||||
runtime: Option<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>>,
|
runtime: Option<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>>,
|
||||||
plugin_registry: PluginRegistry,
|
plugin_registry: PluginRegistry,
|
||||||
plugins_active: bool,
|
plugins_active: bool,
|
||||||
mcp_state: Option<Arc<Mutex<RuntimeMcpState>>>,
|
|
||||||
mcp_active: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BuiltRuntime {
|
impl BuiltRuntime {
|
||||||
fn new(
|
fn new(
|
||||||
runtime: ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>,
|
runtime: ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>,
|
||||||
plugin_registry: PluginRegistry,
|
plugin_registry: PluginRegistry,
|
||||||
mcp_state: Option<Arc<Mutex<RuntimeMcpState>>>,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
runtime: Some(runtime),
|
runtime: Some(runtime),
|
||||||
plugin_registry,
|
plugin_registry,
|
||||||
plugins_active: true,
|
plugins_active: true,
|
||||||
mcp_state,
|
|
||||||
mcp_active: true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1642,19 +1623,6 @@ impl BuiltRuntime {
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn shutdown_mcp(&mut self) -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
if self.mcp_active {
|
|
||||||
if let Some(mcp_state) = &self.mcp_state {
|
|
||||||
mcp_state
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.shutdown()?;
|
|
||||||
}
|
|
||||||
self.mcp_active = false;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Deref for BuiltRuntime {
|
impl Deref for BuiltRuntime {
|
||||||
|
|
@ -1677,284 +1645,10 @@ impl DerefMut for BuiltRuntime {
|
||||||
|
|
||||||
impl Drop for BuiltRuntime {
|
impl Drop for BuiltRuntime {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
let _ = self.shutdown_mcp();
|
|
||||||
let _ = self.shutdown_plugins();
|
let _ = self.shutdown_plugins();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct ToolSearchRequest {
|
|
||||||
query: String,
|
|
||||||
max_results: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct McpToolRequest {
|
|
||||||
#[serde(rename = "qualifiedName")]
|
|
||||||
qualified_name: Option<String>,
|
|
||||||
tool: Option<String>,
|
|
||||||
arguments: Option<serde_json::Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct ListMcpResourcesRequest {
|
|
||||||
server: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct ReadMcpResourceRequest {
|
|
||||||
server: String,
|
|
||||||
uri: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RuntimeMcpState {
|
|
||||||
fn new(
|
|
||||||
runtime_config: &runtime::RuntimeConfig,
|
|
||||||
) -> Result<Option<(Self, runtime::McpToolDiscoveryReport)>, Box<dyn std::error::Error>> {
|
|
||||||
let mut manager = McpServerManager::from_runtime_config(runtime_config);
|
|
||||||
if manager.server_names().is_empty() && manager.unsupported_servers().is_empty() {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
let runtime = tokio::runtime::Runtime::new()?;
|
|
||||||
let discovery = runtime.block_on(manager.discover_tools_best_effort());
|
|
||||||
let pending_servers = discovery
|
|
||||||
.failed_servers
|
|
||||||
.iter()
|
|
||||||
.map(|failure| failure.server_name.clone())
|
|
||||||
.chain(
|
|
||||||
discovery
|
|
||||||
.unsupported_servers
|
|
||||||
.iter()
|
|
||||||
.map(|server| server.server_name.clone()),
|
|
||||||
)
|
|
||||||
.collect::<BTreeSet<_>>()
|
|
||||||
.into_iter()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
Ok(Some((
|
|
||||||
Self {
|
|
||||||
runtime,
|
|
||||||
manager,
|
|
||||||
pending_servers,
|
|
||||||
},
|
|
||||||
discovery,
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn shutdown(&mut self) -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
self.runtime.block_on(self.manager.shutdown())?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn pending_servers(&self) -> Option<Vec<String>> {
|
|
||||||
(!self.pending_servers.is_empty()).then(|| self.pending_servers.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn server_names(&self) -> Vec<String> {
|
|
||||||
self.manager.server_names()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn call_tool(
|
|
||||||
&mut self,
|
|
||||||
qualified_tool_name: &str,
|
|
||||||
arguments: Option<serde_json::Value>,
|
|
||||||
) -> Result<String, ToolError> {
|
|
||||||
let response = self
|
|
||||||
.runtime
|
|
||||||
.block_on(self.manager.call_tool(qualified_tool_name, arguments))
|
|
||||||
.map_err(|error| ToolError::new(error.to_string()))?;
|
|
||||||
if let Some(error) = response.error {
|
|
||||||
return Err(ToolError::new(format!(
|
|
||||||
"MCP tool `{qualified_tool_name}` returned JSON-RPC error: {} ({})",
|
|
||||||
error.message, error.code
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = response.result.ok_or_else(|| {
|
|
||||||
ToolError::new(format!(
|
|
||||||
"MCP tool `{qualified_tool_name}` returned no result payload"
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
serde_json::to_string_pretty(&result).map_err(|error| ToolError::new(error.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn list_resources_for_server(&mut self, server_name: &str) -> Result<String, ToolError> {
|
|
||||||
let result = self
|
|
||||||
.runtime
|
|
||||||
.block_on(self.manager.list_resources(server_name))
|
|
||||||
.map_err(|error| ToolError::new(error.to_string()))?;
|
|
||||||
serde_json::to_string_pretty(&json!({
|
|
||||||
"server": server_name,
|
|
||||||
"resources": result.resources,
|
|
||||||
}))
|
|
||||||
.map_err(|error| ToolError::new(error.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn list_resources_for_all_servers(&mut self) -> Result<String, ToolError> {
|
|
||||||
let mut resources = Vec::new();
|
|
||||||
let mut failures = Vec::new();
|
|
||||||
|
|
||||||
for server_name in self.server_names() {
|
|
||||||
match self
|
|
||||||
.runtime
|
|
||||||
.block_on(self.manager.list_resources(&server_name))
|
|
||||||
{
|
|
||||||
Ok(result) => resources.push(json!({
|
|
||||||
"server": server_name,
|
|
||||||
"resources": result.resources,
|
|
||||||
})),
|
|
||||||
Err(error) => failures.push(json!({
|
|
||||||
"server": server_name,
|
|
||||||
"error": error.to_string(),
|
|
||||||
})),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if resources.is_empty() && !failures.is_empty() {
|
|
||||||
let message = failures
|
|
||||||
.iter()
|
|
||||||
.filter_map(|failure| failure.get("error").and_then(serde_json::Value::as_str))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("; ");
|
|
||||||
return Err(ToolError::new(message));
|
|
||||||
}
|
|
||||||
|
|
||||||
serde_json::to_string_pretty(&json!({
|
|
||||||
"resources": resources,
|
|
||||||
"failures": failures,
|
|
||||||
}))
|
|
||||||
.map_err(|error| ToolError::new(error.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_resource(&mut self, server_name: &str, uri: &str) -> Result<String, ToolError> {
|
|
||||||
let result = self
|
|
||||||
.runtime
|
|
||||||
.block_on(self.manager.read_resource(server_name, uri))
|
|
||||||
.map_err(|error| ToolError::new(error.to_string()))?;
|
|
||||||
serde_json::to_string_pretty(&json!({
|
|
||||||
"server": server_name,
|
|
||||||
"contents": result.contents,
|
|
||||||
}))
|
|
||||||
.map_err(|error| ToolError::new(error.to_string()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_runtime_mcp_state(
|
|
||||||
runtime_config: &runtime::RuntimeConfig,
|
|
||||||
) -> Result<
|
|
||||||
(
|
|
||||||
Option<Arc<Mutex<RuntimeMcpState>>>,
|
|
||||||
Vec<RuntimeToolDefinition>,
|
|
||||||
),
|
|
||||||
Box<dyn std::error::Error>,
|
|
||||||
> {
|
|
||||||
let Some((mcp_state, discovery)) = RuntimeMcpState::new(runtime_config)? else {
|
|
||||||
return Ok((None, Vec::new()));
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut runtime_tools = discovery
|
|
||||||
.tools
|
|
||||||
.iter()
|
|
||||||
.map(mcp_runtime_tool_definition)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
if !mcp_state.server_names().is_empty() {
|
|
||||||
runtime_tools.extend(mcp_wrapper_tool_definitions());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok((Some(Arc::new(Mutex::new(mcp_state))), runtime_tools))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mcp_runtime_tool_definition(tool: &runtime::ManagedMcpTool) -> RuntimeToolDefinition {
|
|
||||||
RuntimeToolDefinition {
|
|
||||||
name: tool.qualified_name.clone(),
|
|
||||||
description: Some(
|
|
||||||
tool.tool
|
|
||||||
.description
|
|
||||||
.clone()
|
|
||||||
.unwrap_or_else(|| format!("Invoke MCP tool `{}`.", tool.qualified_name)),
|
|
||||||
),
|
|
||||||
input_schema: tool
|
|
||||||
.tool
|
|
||||||
.input_schema
|
|
||||||
.clone()
|
|
||||||
.unwrap_or_else(|| json!({ "type": "object", "additionalProperties": true })),
|
|
||||||
required_permission: permission_mode_for_mcp_tool(&tool.tool),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mcp_wrapper_tool_definitions() -> Vec<RuntimeToolDefinition> {
|
|
||||||
vec![
|
|
||||||
RuntimeToolDefinition {
|
|
||||||
name: "MCPTool".to_string(),
|
|
||||||
description: Some(
|
|
||||||
"Call a configured MCP tool by its qualified name and JSON arguments.".to_string(),
|
|
||||||
),
|
|
||||||
input_schema: json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"qualifiedName": { "type": "string" },
|
|
||||||
"arguments": {}
|
|
||||||
},
|
|
||||||
"required": ["qualifiedName"],
|
|
||||||
"additionalProperties": false
|
|
||||||
}),
|
|
||||||
required_permission: PermissionMode::DangerFullAccess,
|
|
||||||
},
|
|
||||||
RuntimeToolDefinition {
|
|
||||||
name: "ListMcpResourcesTool".to_string(),
|
|
||||||
description: Some(
|
|
||||||
"List MCP resources from one configured server or from every connected server."
|
|
||||||
.to_string(),
|
|
||||||
),
|
|
||||||
input_schema: json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"server": { "type": "string" }
|
|
||||||
},
|
|
||||||
"additionalProperties": false
|
|
||||||
}),
|
|
||||||
required_permission: PermissionMode::ReadOnly,
|
|
||||||
},
|
|
||||||
RuntimeToolDefinition {
|
|
||||||
name: "ReadMcpResourceTool".to_string(),
|
|
||||||
description: Some("Read a specific MCP resource from a configured server.".to_string()),
|
|
||||||
input_schema: json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"server": { "type": "string" },
|
|
||||||
"uri": { "type": "string" }
|
|
||||||
},
|
|
||||||
"required": ["server", "uri"],
|
|
||||||
"additionalProperties": false
|
|
||||||
}),
|
|
||||||
required_permission: PermissionMode::ReadOnly,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn permission_mode_for_mcp_tool(tool: &McpTool) -> PermissionMode {
|
|
||||||
let read_only = mcp_annotation_flag(tool, "readOnlyHint");
|
|
||||||
let destructive = mcp_annotation_flag(tool, "destructiveHint");
|
|
||||||
let open_world = mcp_annotation_flag(tool, "openWorldHint");
|
|
||||||
|
|
||||||
if read_only && !destructive && !open_world {
|
|
||||||
PermissionMode::ReadOnly
|
|
||||||
} else if destructive || open_world {
|
|
||||||
PermissionMode::DangerFullAccess
|
|
||||||
} else {
|
|
||||||
PermissionMode::WorkspaceWrite
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mcp_annotation_flag(tool: &McpTool, key: &str) -> bool {
|
|
||||||
tool.annotations
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|annotations| annotations.get(key))
|
|
||||||
.and_then(serde_json::Value::as_bool)
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
struct HookAbortMonitor {
|
struct HookAbortMonitor {
|
||||||
stop_tx: Option<Sender<()>>,
|
stop_tx: Option<Sender<()>>,
|
||||||
join_handle: Option<JoinHandle<()>>,
|
join_handle: Option<JoinHandle<()>>,
|
||||||
|
|
@ -3867,14 +3561,11 @@ fn build_runtime_plugin_state_with_loader(
|
||||||
.feature_config()
|
.feature_config()
|
||||||
.clone()
|
.clone()
|
||||||
.with_hooks(runtime_config.hooks().merged(&plugin_hook_config));
|
.with_hooks(runtime_config.hooks().merged(&plugin_hook_config));
|
||||||
let (mcp_state, runtime_tools) = build_runtime_mcp_state(runtime_config)?;
|
let tool_registry = GlobalToolRegistry::with_plugin_tools(plugin_registry.aggregated_tools()?)?;
|
||||||
let tool_registry = GlobalToolRegistry::with_plugin_tools(plugin_registry.aggregated_tools()?)?
|
|
||||||
.with_runtime_tools(runtime_tools)?;
|
|
||||||
Ok(RuntimePluginState {
|
Ok(RuntimePluginState {
|
||||||
feature_config,
|
feature_config,
|
||||||
tool_registry,
|
tool_registry,
|
||||||
plugin_registry,
|
plugin_registry,
|
||||||
mcp_state,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -4296,7 +3987,6 @@ fn build_runtime_with_plugin_state(
|
||||||
feature_config,
|
feature_config,
|
||||||
tool_registry,
|
tool_registry,
|
||||||
plugin_registry,
|
plugin_registry,
|
||||||
mcp_state,
|
|
||||||
} = runtime_plugin_state;
|
} = runtime_plugin_state;
|
||||||
plugin_registry.initialize()?;
|
plugin_registry.initialize()?;
|
||||||
let policy = permission_policy(permission_mode, &feature_config, &tool_registry)
|
let policy = permission_policy(permission_mode, &feature_config, &tool_registry)
|
||||||
|
|
@ -4312,12 +4002,7 @@ fn build_runtime_with_plugin_state(
|
||||||
tool_registry.clone(),
|
tool_registry.clone(),
|
||||||
progress_reporter,
|
progress_reporter,
|
||||||
)?,
|
)?,
|
||||||
CliToolExecutor::new(
|
CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry),
|
||||||
allowed_tools.clone(),
|
|
||||||
emit_output,
|
|
||||||
tool_registry.clone(),
|
|
||||||
mcp_state.clone(),
|
|
||||||
),
|
|
||||||
policy,
|
policy,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
&feature_config,
|
&feature_config,
|
||||||
|
|
@ -4325,7 +4010,7 @@ fn build_runtime_with_plugin_state(
|
||||||
if emit_output {
|
if emit_output {
|
||||||
runtime = runtime.with_hook_progress_reporter(Box::new(CliHookProgressReporter));
|
runtime = runtime.with_hook_progress_reporter(Box::new(CliHookProgressReporter));
|
||||||
}
|
}
|
||||||
Ok(BuiltRuntime::new(runtime, plugin_registry, mcp_state))
|
Ok(BuiltRuntime::new(runtime, plugin_registry))
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CliHookProgressReporter;
|
struct CliHookProgressReporter;
|
||||||
|
|
@ -5264,7 +4949,6 @@ struct CliToolExecutor {
|
||||||
emit_output: bool,
|
emit_output: bool,
|
||||||
allowed_tools: Option<AllowedToolSet>,
|
allowed_tools: Option<AllowedToolSet>,
|
||||||
tool_registry: GlobalToolRegistry,
|
tool_registry: GlobalToolRegistry,
|
||||||
mcp_state: Option<Arc<Mutex<RuntimeMcpState>>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CliToolExecutor {
|
impl CliToolExecutor {
|
||||||
|
|
@ -5272,72 +4956,12 @@ impl CliToolExecutor {
|
||||||
allowed_tools: Option<AllowedToolSet>,
|
allowed_tools: Option<AllowedToolSet>,
|
||||||
emit_output: bool,
|
emit_output: bool,
|
||||||
tool_registry: GlobalToolRegistry,
|
tool_registry: GlobalToolRegistry,
|
||||||
mcp_state: Option<Arc<Mutex<RuntimeMcpState>>>,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
renderer: TerminalRenderer::new(),
|
renderer: TerminalRenderer::new(),
|
||||||
emit_output,
|
emit_output,
|
||||||
allowed_tools,
|
allowed_tools,
|
||||||
tool_registry,
|
tool_registry,
|
||||||
mcp_state,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn execute_search_tool(&self, value: serde_json::Value) -> Result<String, ToolError> {
|
|
||||||
let input: ToolSearchRequest = serde_json::from_value(value)
|
|
||||||
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
|
|
||||||
let pending_mcp_servers = self.mcp_state.as_ref().and_then(|state| {
|
|
||||||
state
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.pending_servers()
|
|
||||||
});
|
|
||||||
serde_json::to_string_pretty(&self.tool_registry.search(
|
|
||||||
&input.query,
|
|
||||||
input.max_results.unwrap_or(5),
|
|
||||||
pending_mcp_servers,
|
|
||||||
))
|
|
||||||
.map_err(|error| ToolError::new(error.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn execute_runtime_tool(
|
|
||||||
&self,
|
|
||||||
tool_name: &str,
|
|
||||||
value: serde_json::Value,
|
|
||||||
) -> Result<String, ToolError> {
|
|
||||||
let Some(mcp_state) = &self.mcp_state else {
|
|
||||||
return Err(ToolError::new(format!(
|
|
||||||
"runtime tool `{tool_name}` is unavailable without configured MCP servers"
|
|
||||||
)));
|
|
||||||
};
|
|
||||||
let mut mcp_state = mcp_state
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
|
|
||||||
match tool_name {
|
|
||||||
"MCPTool" => {
|
|
||||||
let input: McpToolRequest = serde_json::from_value(value)
|
|
||||||
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
|
|
||||||
let qualified_name = input
|
|
||||||
.qualified_name
|
|
||||||
.or(input.tool)
|
|
||||||
.ok_or_else(|| ToolError::new("missing required field `qualifiedName`"))?;
|
|
||||||
mcp_state.call_tool(&qualified_name, input.arguments)
|
|
||||||
}
|
|
||||||
"ListMcpResourcesTool" => {
|
|
||||||
let input: ListMcpResourcesRequest = serde_json::from_value(value)
|
|
||||||
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
|
|
||||||
match input.server {
|
|
||||||
Some(server_name) => mcp_state.list_resources_for_server(&server_name),
|
|
||||||
None => mcp_state.list_resources_for_all_servers(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"ReadMcpResourceTool" => {
|
|
||||||
let input: ReadMcpResourceRequest = serde_json::from_value(value)
|
|
||||||
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
|
|
||||||
mcp_state.read_resource(&input.server, &input.uri)
|
|
||||||
}
|
|
||||||
_ => mcp_state.call_tool(tool_name, Some(value)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -5355,16 +4979,7 @@ impl ToolExecutor for CliToolExecutor {
|
||||||
}
|
}
|
||||||
let value = serde_json::from_str(input)
|
let value = serde_json::from_str(input)
|
||||||
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
|
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
|
||||||
let result = if tool_name == "ToolSearch" {
|
match self.tool_registry.execute(tool_name, &value) {
|
||||||
self.execute_search_tool(value)
|
|
||||||
} else if self.tool_registry.has_runtime_tool(tool_name) {
|
|
||||||
self.execute_runtime_tool(tool_name, value)
|
|
||||||
} else {
|
|
||||||
self.tool_registry
|
|
||||||
.execute(tool_name, &value)
|
|
||||||
.map_err(ToolError::new)
|
|
||||||
};
|
|
||||||
match result {
|
|
||||||
Ok(output) => {
|
Ok(output) => {
|
||||||
if self.emit_output {
|
if self.emit_output {
|
||||||
let markdown = format_tool_result(tool_name, &output, false);
|
let markdown = format_tool_result(tool_name, &output, false);
|
||||||
|
|
@ -5376,12 +4991,12 @@ impl ToolExecutor for CliToolExecutor {
|
||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
if self.emit_output {
|
if self.emit_output {
|
||||||
let markdown = format_tool_result(tool_name, &error.to_string(), true);
|
let markdown = format_tool_result(tool_name, &error, true);
|
||||||
self.renderer
|
self.renderer
|
||||||
.stream_markdown(&markdown, &mut io::stdout())
|
.stream_markdown(&markdown, &mut io::stdout())
|
||||||
.map_err(|stream_error| ToolError::new(stream_error.to_string()))?;
|
.map_err(|stream_error| ToolError::new(stream_error.to_string()))?;
|
||||||
}
|
}
|
||||||
Err(error)
|
Err(ToolError::new(error))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -5580,13 +5195,12 @@ mod tests {
|
||||||
format_unknown_slash_command_message, normalize_permission_mode, parse_args,
|
format_unknown_slash_command_message, normalize_permission_mode, parse_args,
|
||||||
parse_git_status_branch, parse_git_status_metadata_for, parse_git_workspace_summary,
|
parse_git_status_branch, parse_git_status_metadata_for, parse_git_workspace_summary,
|
||||||
permission_policy, print_help_to, push_output_block, render_config_report,
|
permission_policy, print_help_to, push_output_block, render_config_report,
|
||||||
render_diff_report, render_diff_report_for, render_memory_report, render_repl_help, render_resume_usage,
|
render_diff_report, render_memory_report, render_repl_help, render_resume_usage,
|
||||||
resolve_model_alias, resolve_session_reference, response_to_events,
|
resolve_model_alias, resolve_session_reference, response_to_events,
|
||||||
resume_supported_slash_commands, run_resume_command,
|
resume_supported_slash_commands, run_resume_command,
|
||||||
slash_command_completion_candidates_with_sessions, status_context, validate_no_args,
|
slash_command_completion_candidates_with_sessions, status_context, validate_no_args,
|
||||||
write_mcp_server_fixture, CliAction, CliOutputFormat, CliToolExecutor, GitWorkspaceSummary,
|
CliAction, CliOutputFormat, GitWorkspaceSummary, InternalPromptProgressEvent,
|
||||||
InternalPromptProgressEvent, InternalPromptProgressState, LiveCli, SlashCommand,
|
InternalPromptProgressState, LiveCli, SlashCommand, StatusUsage, DEFAULT_MODEL,
|
||||||
StatusUsage, DEFAULT_MODEL,
|
|
||||||
};
|
};
|
||||||
use api::{MessageResponse, OutputContentBlock, Usage};
|
use api::{MessageResponse, OutputContentBlock, Usage};
|
||||||
use plugins::{
|
use plugins::{
|
||||||
|
|
@ -5594,7 +5208,7 @@ mod tests {
|
||||||
};
|
};
|
||||||
use runtime::{
|
use runtime::{
|
||||||
AssistantEvent, ConfigLoader, ContentBlock, ConversationMessage, MessageRole,
|
AssistantEvent, ConfigLoader, ContentBlock, ConversationMessage, MessageRole,
|
||||||
PermissionMode, Session, ToolExecutor,
|
PermissionMode, Session,
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
@ -6318,11 +5932,7 @@ mod tests {
|
||||||
.map(|spec| spec.name)
|
.map(|spec| spec.name)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
// Now with 135+ slash commands, verify minimum resume support
|
// Now with 135+ slash commands, verify minimum resume support
|
||||||
assert!(
|
assert!(names.len() >= 39, "expected at least 39 resume-supported commands, got {}", names.len());
|
||||||
names.len() >= 39,
|
|
||||||
"expected at least 39 resume-supported commands, got {}",
|
|
||||||
names.len()
|
|
||||||
);
|
|
||||||
// Verify key resume commands still exist
|
// Verify key resume commands still exist
|
||||||
assert!(names.contains(&"help"));
|
assert!(names.contains(&"help"));
|
||||||
assert!(names.contains(&"status"));
|
assert!(names.contains(&"status"));
|
||||||
|
|
@ -6632,7 +6242,9 @@ UU conflicted.rs",
|
||||||
git(&["add", "tracked.txt"], &root);
|
git(&["add", "tracked.txt"], &root);
|
||||||
git(&["commit", "-m", "init", "--quiet"], &root);
|
git(&["commit", "-m", "init", "--quiet"], &root);
|
||||||
|
|
||||||
let report = render_diff_report_for(&root).expect("diff report should render");
|
let report = with_current_dir(&root, || {
|
||||||
|
render_diff_report().expect("diff report should render")
|
||||||
|
});
|
||||||
assert!(report.contains("clean working tree"));
|
assert!(report.contains("clean working tree"));
|
||||||
|
|
||||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
|
@ -6655,7 +6267,9 @@ UU conflicted.rs",
|
||||||
fs::write(root.join("tracked.txt"), "hello\nstaged\nunstaged\n")
|
fs::write(root.join("tracked.txt"), "hello\nstaged\nunstaged\n")
|
||||||
.expect("update file twice");
|
.expect("update file twice");
|
||||||
|
|
||||||
let report = render_diff_report_for(&root).expect("diff report should render");
|
let report = with_current_dir(&root, || {
|
||||||
|
render_diff_report().expect("diff report should render")
|
||||||
|
});
|
||||||
assert!(report.contains("Staged changes:"));
|
assert!(report.contains("Staged changes:"));
|
||||||
assert!(report.contains("Unstaged changes:"));
|
assert!(report.contains("Unstaged changes:"));
|
||||||
assert!(report.contains("tracked.txt"));
|
assert!(report.contains("tracked.txt"));
|
||||||
|
|
@ -6680,7 +6294,9 @@ UU conflicted.rs",
|
||||||
fs::write(root.join("ignored.txt"), "secret\n").expect("write ignored file");
|
fs::write(root.join("ignored.txt"), "secret\n").expect("write ignored file");
|
||||||
fs::write(root.join("tracked.txt"), "hello\nworld\n").expect("write tracked change");
|
fs::write(root.join("tracked.txt"), "hello\nworld\n").expect("write tracked change");
|
||||||
|
|
||||||
let report = render_diff_report_for(&root).expect("diff report should render");
|
let report = with_current_dir(&root, || {
|
||||||
|
render_diff_report().expect("diff report should render")
|
||||||
|
});
|
||||||
assert!(report.contains("tracked.txt"));
|
assert!(report.contains("tracked.txt"));
|
||||||
assert!(!report.contains("+++ b/ignored.txt"));
|
assert!(!report.contains("+++ b/ignored.txt"));
|
||||||
assert!(!report.contains("+++ b/.omx/state.json"));
|
assert!(!report.contains("+++ b/.omx/state.json"));
|
||||||
|
|
@ -6900,11 +6516,7 @@ UU conflicted.rs",
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn init_template_mentions_detected_rust_workspace() {
|
fn init_template_mentions_detected_rust_workspace() {
|
||||||
let _guard = cwd_lock()
|
let rendered = crate::init::render_init_claude_md(std::path::Path::new("."));
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../..");
|
|
||||||
let rendered = crate::init::render_init_claude_md(&workspace_root);
|
|
||||||
assert!(rendered.contains("# CLAUDE.md"));
|
assert!(rendered.contains("# CLAUDE.md"));
|
||||||
assert!(rendered.contains("cargo clippy --workspace --all-targets -- -D warnings"));
|
assert!(rendered.contains("cargo clippy --workspace --all-targets -- -D warnings"));
|
||||||
}
|
}
|
||||||
|
|
@ -7295,111 +6907,6 @@ UU conflicted.rs",
|
||||||
let _ = fs::remove_dir_all(source_root);
|
let _ = fs::remove_dir_all(source_root);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn build_runtime_plugin_state_discovers_mcp_tools_and_surfaces_pending_servers() {
|
|
||||||
let config_home = temp_dir();
|
|
||||||
let workspace = temp_dir();
|
|
||||||
fs::create_dir_all(&config_home).expect("config home");
|
|
||||||
fs::create_dir_all(&workspace).expect("workspace");
|
|
||||||
let script_path = workspace.join("fixture-mcp.py");
|
|
||||||
write_mcp_server_fixture(&script_path);
|
|
||||||
fs::write(
|
|
||||||
config_home.join("settings.json"),
|
|
||||||
format!(
|
|
||||||
r#"{{
|
|
||||||
"mcpServers": {{
|
|
||||||
"alpha": {{
|
|
||||||
"command": "python3",
|
|
||||||
"args": ["{}"]
|
|
||||||
}},
|
|
||||||
"broken": {{
|
|
||||||
"command": "python3",
|
|
||||||
"args": ["-c", "import sys; sys.exit(0)"]
|
|
||||||
}}
|
|
||||||
}}
|
|
||||||
}}"#,
|
|
||||||
script_path.to_string_lossy()
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.expect("write mcp settings");
|
|
||||||
|
|
||||||
let loader = ConfigLoader::new(&workspace, &config_home);
|
|
||||||
let runtime_config = loader.load().expect("runtime config should load");
|
|
||||||
let state = build_runtime_plugin_state_with_loader(&workspace, &loader, &runtime_config)
|
|
||||||
.expect("runtime plugin state should load");
|
|
||||||
|
|
||||||
let allowed = state
|
|
||||||
.tool_registry
|
|
||||||
.normalize_allowed_tools(&["mcp__alpha__echo".to_string(), "MCPTool".to_string()])
|
|
||||||
.expect("mcp tools should be allow-listable")
|
|
||||||
.expect("allow-list should exist");
|
|
||||||
assert!(allowed.contains("mcp__alpha__echo"));
|
|
||||||
assert!(allowed.contains("MCPTool"));
|
|
||||||
|
|
||||||
let mut executor = CliToolExecutor::new(
|
|
||||||
None,
|
|
||||||
false,
|
|
||||||
state.tool_registry.clone(),
|
|
||||||
state.mcp_state.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let tool_output = executor
|
|
||||||
.execute("mcp__alpha__echo", r#"{"text":"hello"}"#)
|
|
||||||
.expect("discovered mcp tool should execute");
|
|
||||||
let tool_json: serde_json::Value =
|
|
||||||
serde_json::from_str(&tool_output).expect("tool output should be json");
|
|
||||||
assert_eq!(tool_json["structuredContent"]["echoed"], "hello");
|
|
||||||
|
|
||||||
let wrapped_output = executor
|
|
||||||
.execute(
|
|
||||||
"MCPTool",
|
|
||||||
r#"{"qualifiedName":"mcp__alpha__echo","arguments":{"text":"wrapped"}}"#,
|
|
||||||
)
|
|
||||||
.expect("generic mcp wrapper should execute");
|
|
||||||
let wrapped_json: serde_json::Value =
|
|
||||||
serde_json::from_str(&wrapped_output).expect("wrapped output should be json");
|
|
||||||
assert_eq!(wrapped_json["structuredContent"]["echoed"], "wrapped");
|
|
||||||
|
|
||||||
let search_output = executor
|
|
||||||
.execute("ToolSearch", r#"{"query":"alpha echo","max_results":5}"#)
|
|
||||||
.expect("tool search should execute");
|
|
||||||
let search_json: serde_json::Value =
|
|
||||||
serde_json::from_str(&search_output).expect("search output should be json");
|
|
||||||
assert_eq!(search_json["matches"][0], "mcp__alpha__echo");
|
|
||||||
assert_eq!(search_json["pending_mcp_servers"][0], "broken");
|
|
||||||
|
|
||||||
let listed = executor
|
|
||||||
.execute("ListMcpResourcesTool", r#"{"server":"alpha"}"#)
|
|
||||||
.expect("resources should list");
|
|
||||||
let listed_json: serde_json::Value =
|
|
||||||
serde_json::from_str(&listed).expect("resource output should be json");
|
|
||||||
assert_eq!(listed_json["resources"][0]["uri"], "file://guide.txt");
|
|
||||||
|
|
||||||
let read = executor
|
|
||||||
.execute(
|
|
||||||
"ReadMcpResourceTool",
|
|
||||||
r#"{"server":"alpha","uri":"file://guide.txt"}"#,
|
|
||||||
)
|
|
||||||
.expect("resource should read");
|
|
||||||
let read_json: serde_json::Value =
|
|
||||||
serde_json::from_str(&read).expect("resource read output should be json");
|
|
||||||
assert_eq!(
|
|
||||||
read_json["contents"][0]["text"],
|
|
||||||
"contents for file://guide.txt"
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(mcp_state) = state.mcp_state {
|
|
||||||
mcp_state
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.shutdown()
|
|
||||||
.expect("mcp shutdown should succeed");
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = fs::remove_dir_all(config_home);
|
|
||||||
let _ = fs::remove_dir_all(workspace);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn build_runtime_runs_plugin_lifecycle_init_and_shutdown() {
|
fn build_runtime_runs_plugin_lifecycle_init_and_shutdown() {
|
||||||
let config_home = temp_dir();
|
let config_home = temp_dir();
|
||||||
|
|
@ -7458,105 +6965,6 @@ UU conflicted.rs",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn write_mcp_server_fixture(script_path: &Path) {
|
|
||||||
let script = [
|
|
||||||
"#!/usr/bin/env python3",
|
|
||||||
"import json, sys",
|
|
||||||
"",
|
|
||||||
"def read_message():",
|
|
||||||
" header = b''",
|
|
||||||
r" while not header.endswith(b'\r\n\r\n'):",
|
|
||||||
" chunk = sys.stdin.buffer.read(1)",
|
|
||||||
" if not chunk:",
|
|
||||||
" return None",
|
|
||||||
" header += chunk",
|
|
||||||
" length = 0",
|
|
||||||
r" for line in header.decode().split('\r\n'):",
|
|
||||||
r" if line.lower().startswith('content-length:'):",
|
|
||||||
" length = int(line.split(':', 1)[1].strip())",
|
|
||||||
" payload = sys.stdin.buffer.read(length)",
|
|
||||||
" return json.loads(payload.decode())",
|
|
||||||
"",
|
|
||||||
"def send_message(message):",
|
|
||||||
" payload = json.dumps(message).encode()",
|
|
||||||
r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
|
|
||||||
" sys.stdout.buffer.flush()",
|
|
||||||
"",
|
|
||||||
"while True:",
|
|
||||||
" request = read_message()",
|
|
||||||
" if request is None:",
|
|
||||||
" break",
|
|
||||||
" method = request['method']",
|
|
||||||
" if method == 'initialize':",
|
|
||||||
" send_message({",
|
|
||||||
" 'jsonrpc': '2.0',",
|
|
||||||
" 'id': request['id'],",
|
|
||||||
" 'result': {",
|
|
||||||
" 'protocolVersion': request['params']['protocolVersion'],",
|
|
||||||
" 'capabilities': {'tools': {}, 'resources': {}},",
|
|
||||||
" 'serverInfo': {'name': 'fixture', 'version': '1.0.0'}",
|
|
||||||
" }",
|
|
||||||
" })",
|
|
||||||
" elif method == 'tools/list':",
|
|
||||||
" send_message({",
|
|
||||||
" 'jsonrpc': '2.0',",
|
|
||||||
" 'id': request['id'],",
|
|
||||||
" 'result': {",
|
|
||||||
" 'tools': [",
|
|
||||||
" {",
|
|
||||||
" 'name': 'echo',",
|
|
||||||
" 'description': 'Echo from MCP fixture',",
|
|
||||||
" 'inputSchema': {",
|
|
||||||
" 'type': 'object',",
|
|
||||||
" 'properties': {'text': {'type': 'string'}},",
|
|
||||||
" 'required': ['text'],",
|
|
||||||
" 'additionalProperties': False",
|
|
||||||
" },",
|
|
||||||
" 'annotations': {'readOnlyHint': True}",
|
|
||||||
" }",
|
|
||||||
" ]",
|
|
||||||
" }",
|
|
||||||
" })",
|
|
||||||
" elif method == 'tools/call':",
|
|
||||||
" args = request['params'].get('arguments') or {}",
|
|
||||||
" send_message({",
|
|
||||||
" 'jsonrpc': '2.0',",
|
|
||||||
" 'id': request['id'],",
|
|
||||||
" 'result': {",
|
|
||||||
" 'content': [{'type': 'text', 'text': f\"echo:{args.get('text', '')}\"}],",
|
|
||||||
" 'structuredContent': {'echoed': args.get('text', '')},",
|
|
||||||
" 'isError': False",
|
|
||||||
" }",
|
|
||||||
" })",
|
|
||||||
" elif method == 'resources/list':",
|
|
||||||
" send_message({",
|
|
||||||
" 'jsonrpc': '2.0',",
|
|
||||||
" 'id': request['id'],",
|
|
||||||
" 'result': {",
|
|
||||||
" 'resources': [{'uri': 'file://guide.txt', 'name': 'guide', 'mimeType': 'text/plain'}]",
|
|
||||||
" }",
|
|
||||||
" })",
|
|
||||||
" elif method == 'resources/read':",
|
|
||||||
" uri = request['params']['uri']",
|
|
||||||
" send_message({",
|
|
||||||
" 'jsonrpc': '2.0',",
|
|
||||||
" 'id': request['id'],",
|
|
||||||
" 'result': {",
|
|
||||||
" 'contents': [{'uri': uri, 'mimeType': 'text/plain', 'text': f'contents for {uri}'}]",
|
|
||||||
" }",
|
|
||||||
" })",
|
|
||||||
" else:",
|
|
||||||
" send_message({",
|
|
||||||
" 'jsonrpc': '2.0',",
|
|
||||||
" 'id': request['id'],",
|
|
||||||
" 'error': {'code': -32601, 'message': method}",
|
|
||||||
" })",
|
|
||||||
"",
|
|
||||||
]
|
|
||||||
.join("\n");
|
|
||||||
fs::write(script_path, script).expect("mcp fixture script should write");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod sandbox_report_tests {
|
mod sandbox_report_tests {
|
||||||
use super::{format_sandbox_report, HookAbortMonitor};
|
use super::{format_sandbox_report, HookAbortMonitor};
|
||||||
|
|
|
||||||
|
|
@ -695,10 +695,7 @@ fn assert_auto_compact_triggered(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||||
);
|
);
|
||||||
// auto_compaction key must be present in JSON (may be null for below-threshold sessions)
|
// auto_compaction key must be present in JSON (may be null for below-threshold sessions)
|
||||||
assert!(
|
assert!(
|
||||||
run.response
|
run.response.as_object().expect("response object").contains_key("auto_compaction"),
|
||||||
.as_object()
|
|
||||||
.expect("response object")
|
|
||||||
.contains_key("auto_compaction"),
|
|
||||||
"auto_compaction key must be present in JSON output"
|
"auto_compaction key must be present in JSON output"
|
||||||
);
|
);
|
||||||
// Verify input_tokens field reflects the large mock token counts
|
// Verify input_tokens field reflects the large mock token counts
|
||||||
|
|
@ -713,10 +710,12 @@ fn assert_auto_compact_triggered(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||||
|
|
||||||
fn assert_token_cost_reporting(_: &HarnessWorkspace, run: &ScenarioRun) {
|
fn assert_token_cost_reporting(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||||
assert_eq!(run.response["iterations"], Value::from(1));
|
assert_eq!(run.response["iterations"], Value::from(1));
|
||||||
assert!(run.response["message"]
|
assert!(
|
||||||
.as_str()
|
run.response["message"]
|
||||||
.expect("message text")
|
.as_str()
|
||||||
.contains("token cost reporting parity complete."),);
|
.expect("message text")
|
||||||
|
.contains("token cost reporting parity complete."),
|
||||||
|
);
|
||||||
let usage = &run.response["usage"];
|
let usage = &run.response["usage"];
|
||||||
assert!(
|
assert!(
|
||||||
usage["input_tokens"].as_u64().unwrap_or(0) > 0,
|
usage["input_tokens"].as_u64().unwrap_or(0) > 0,
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue