Compare commits
16 commits
main
...
fix/clippy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9a571ae57 | ||
|
|
29b99b21f7 | ||
|
|
0d2cd3229e | ||
|
|
6a01aa52e6 | ||
|
|
1e8e152148 | ||
|
|
7976bf994c | ||
|
|
d9c1883879 | ||
|
|
722eee3ea5 | ||
|
|
23ba21bdd5 | ||
|
|
464a870180 | ||
|
|
2959cd1e51 | ||
|
|
86bc510722 | ||
|
|
920842fffe | ||
|
|
ec5d3a546b | ||
|
|
9cfbb8ceed | ||
|
|
ace763eede |
7 changed files with 109 additions and 86 deletions
|
|
@ -21,7 +21,7 @@ async fn stream_via_provider<P: Provider>(
|
|||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProviderClient {
|
||||
Anthropic(AnthropicClient),
|
||||
Anthropic(Box<AnthropicClient>),
|
||||
Xai(OpenAiCompatClient),
|
||||
OpenAi(OpenAiCompatClient),
|
||||
}
|
||||
|
|
@ -37,10 +37,10 @@ impl ProviderClient {
|
|||
) -> Result<Self, ApiError> {
|
||||
let resolved_model = providers::resolve_model_alias(model);
|
||||
match providers::detect_provider_kind(&resolved_model) {
|
||||
ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth {
|
||||
ProviderKind::Anthropic => Ok(Self::Anthropic(Box::new(match anthropic_auth {
|
||||
Some(auth) => AnthropicClient::from_auth(auth),
|
||||
None => AnthropicClient::from_env()?,
|
||||
})),
|
||||
}))),
|
||||
ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env(
|
||||
OpenAiCompatConfig::xai(),
|
||||
)?)),
|
||||
|
|
@ -62,7 +62,9 @@ impl ProviderClient {
|
|||
#[must_use]
|
||||
pub fn with_prompt_cache(self, prompt_cache: PromptCache) -> Self {
|
||||
match self {
|
||||
Self::Anthropic(client) => Self::Anthropic(client.with_prompt_cache(prompt_cache)),
|
||||
Self::Anthropic(client) => {
|
||||
Self::Anthropic(Box::new((*client).with_prompt_cache(prompt_cache)))
|
||||
}
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
|
@ -88,7 +90,7 @@ impl ProviderClient {
|
|||
request: &MessageRequest,
|
||||
) -> Result<MessageResponse, ApiError> {
|
||||
match self {
|
||||
Self::Anthropic(client) => send_via_provider(client, request).await,
|
||||
Self::Anthropic(client) => send_via_provider(client.as_ref(), request).await,
|
||||
Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await,
|
||||
}
|
||||
}
|
||||
|
|
@ -98,7 +100,7 @@ impl ProviderClient {
|
|||
request: &MessageRequest,
|
||||
) -> Result<MessageStream, ApiError> {
|
||||
match self {
|
||||
Self::Anthropic(client) => stream_via_provider(client, request)
|
||||
Self::Anthropic(client) => stream_via_provider(client.as_ref(), request)
|
||||
.await
|
||||
.map(MessageStream::Anthropic),
|
||||
Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request)
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ impl MessageStream {
|
|||
}
|
||||
|
||||
if self.done {
|
||||
self.pending.extend(self.state.finish()?);
|
||||
self.pending.extend(self.state.finish());
|
||||
if let Some(event) = self.pending.pop_front() {
|
||||
return Ok(Some(event));
|
||||
}
|
||||
|
|
@ -261,7 +261,7 @@ impl MessageStream {
|
|||
match self.response.chunk().await? {
|
||||
Some(chunk) => {
|
||||
for parsed in self.parser.push(&chunk)? {
|
||||
self.pending.extend(self.state.ingest_chunk(parsed)?);
|
||||
self.pending.extend(self.state.ingest_chunk(parsed));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
|
|
@ -299,33 +299,41 @@ impl OpenAiSseParser {
|
|||
#[derive(Debug)]
|
||||
struct StreamState {
|
||||
model: String,
|
||||
message_started: bool,
|
||||
text_started: bool,
|
||||
text_finished: bool,
|
||||
finished: bool,
|
||||
message: MessageState,
|
||||
text: TextState,
|
||||
stop_reason: Option<String>,
|
||||
usage: Option<Usage>,
|
||||
tool_calls: BTreeMap<u32, ToolCallState>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct MessageState {
|
||||
started: bool,
|
||||
finished: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct TextState {
|
||||
started: bool,
|
||||
finished: bool,
|
||||
}
|
||||
|
||||
impl StreamState {
|
||||
fn new(model: String) -> Self {
|
||||
Self {
|
||||
model,
|
||||
message_started: false,
|
||||
text_started: false,
|
||||
text_finished: false,
|
||||
finished: false,
|
||||
message: MessageState::default(),
|
||||
text: TextState::default(),
|
||||
stop_reason: None,
|
||||
usage: None,
|
||||
tool_calls: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result<Vec<StreamEvent>, ApiError> {
|
||||
fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Vec<StreamEvent> {
|
||||
let mut events = Vec::new();
|
||||
if !self.message_started {
|
||||
self.message_started = true;
|
||||
if !self.message.started {
|
||||
self.message.started = true;
|
||||
events.push(StreamEvent::MessageStart(MessageStartEvent {
|
||||
message: MessageResponse {
|
||||
id: chunk.id.clone(),
|
||||
|
|
@ -357,8 +365,8 @@ impl StreamState {
|
|||
|
||||
for choice in chunk.choices {
|
||||
if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
|
||||
if !self.text_started {
|
||||
self.text_started = true;
|
||||
if !self.text.started {
|
||||
self.text.started = true;
|
||||
events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||
index: 0,
|
||||
content_block: OutputContentBlock::Text {
|
||||
|
|
@ -377,7 +385,7 @@ impl StreamState {
|
|||
state.apply(tool_call);
|
||||
let block_index = state.block_index();
|
||||
if !state.started {
|
||||
if let Some(start_event) = state.start_event()? {
|
||||
if let Some(start_event) = state.start_event() {
|
||||
state.started = true;
|
||||
events.push(StreamEvent::ContentBlockStart(start_event));
|
||||
} else {
|
||||
|
|
@ -410,18 +418,18 @@ impl StreamState {
|
|||
}
|
||||
}
|
||||
|
||||
Ok(events)
|
||||
events
|
||||
}
|
||||
|
||||
fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
|
||||
if self.finished {
|
||||
return Ok(Vec::new());
|
||||
fn finish(&mut self) -> Vec<StreamEvent> {
|
||||
if self.message.finished {
|
||||
return Vec::new();
|
||||
}
|
||||
self.finished = true;
|
||||
self.message.finished = true;
|
||||
|
||||
let mut events = Vec::new();
|
||||
if self.text_started && !self.text_finished {
|
||||
self.text_finished = true;
|
||||
if self.text.started && !self.text.finished {
|
||||
self.text.finished = true;
|
||||
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
|
||||
index: 0,
|
||||
}));
|
||||
|
|
@ -429,7 +437,7 @@ impl StreamState {
|
|||
|
||||
for state in self.tool_calls.values_mut() {
|
||||
if !state.started {
|
||||
if let Some(start_event) = state.start_event()? {
|
||||
if let Some(start_event) = state.start_event() {
|
||||
state.started = true;
|
||||
events.push(StreamEvent::ContentBlockStart(start_event));
|
||||
if let Some(delta_event) = state.delta_event() {
|
||||
|
|
@ -445,7 +453,7 @@ impl StreamState {
|
|||
}
|
||||
}
|
||||
|
||||
if self.message_started {
|
||||
if self.message.started {
|
||||
events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
|
||||
delta: MessageDelta {
|
||||
stop_reason: Some(
|
||||
|
|
@ -464,7 +472,7 @@ impl StreamState {
|
|||
}));
|
||||
events.push(StreamEvent::MessageStop(MessageStopEvent {}));
|
||||
}
|
||||
Ok(events)
|
||||
events
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -497,22 +505,20 @@ impl ToolCallState {
|
|||
self.openai_index + 1
|
||||
}
|
||||
|
||||
fn start_event(&self) -> Result<Option<ContentBlockStartEvent>, ApiError> {
|
||||
let Some(name) = self.name.clone() else {
|
||||
return Ok(None);
|
||||
};
|
||||
fn start_event(&self) -> Option<ContentBlockStartEvent> {
|
||||
let name = self.name.clone()?;
|
||||
let id = self
|
||||
.id
|
||||
.clone()
|
||||
.unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
|
||||
Ok(Some(ContentBlockStartEvent {
|
||||
Some(ContentBlockStartEvent {
|
||||
index: self.block_index(),
|
||||
content_block: OutputContentBlock::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input: json!({}),
|
||||
},
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
|
||||
|
|
|
|||
|
|
@ -407,7 +407,7 @@ async fn provider_client_dispatches_anthropic_requests() {
|
|||
.expect("anthropic provider client should be constructed");
|
||||
let client = match client {
|
||||
ProviderClient::Anthropic(client) => {
|
||||
ProviderClient::Anthropic(client.with_base_url(server.base_url()))
|
||||
ProviderClient::Anthropic(Box::new((*client).with_base_url(server.base_url())))
|
||||
}
|
||||
other => panic!("expected anthropic provider, got {other:?}"),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ where
|
|||
system_prompt,
|
||||
max_iterations: usize::MAX,
|
||||
usage_tracker,
|
||||
hook_runner: HookRunner::from_feature_config(&feature_config),
|
||||
hook_runner: HookRunner::from_feature_config(feature_config),
|
||||
auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
|
||||
hook_abort_signal: HookAbortSignal::default(),
|
||||
hook_progress_reporter: None,
|
||||
|
|
@ -349,7 +349,7 @@ where
|
|||
);
|
||||
|
||||
self.session
|
||||
.push_message(assistant_message.clone())
|
||||
.push_message(&assistant_message)
|
||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||
assistant_messages.push(assistant_message);
|
||||
|
||||
|
|
@ -440,7 +440,7 @@ where
|
|||
),
|
||||
};
|
||||
self.session
|
||||
.push_message(result_message.clone())
|
||||
.push_message(&result_message)
|
||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||
self.record_tool_finished(iterations, &result_message);
|
||||
tool_results.push(result_message);
|
||||
|
|
|
|||
|
|
@ -168,20 +168,19 @@ impl Session {
|
|||
{
|
||||
Self::from_json(&value)?
|
||||
}
|
||||
Err(_) => Self::from_jsonl(&contents)?,
|
||||
Ok(_) => Self::from_jsonl(&contents)?,
|
||||
Err(_) | Ok(_) => Self::from_jsonl(&contents)?,
|
||||
};
|
||||
Ok(session.with_persistence_path(path.to_path_buf()))
|
||||
}
|
||||
|
||||
pub fn push_message(&mut self, message: ConversationMessage) -> Result<(), SessionError> {
|
||||
pub fn push_message(&mut self, message: &ConversationMessage) -> Result<(), SessionError> {
|
||||
self.touch();
|
||||
self.messages.push(message.clone());
|
||||
self.append_persisted_message(&message)
|
||||
self.append_persisted_message(message)
|
||||
}
|
||||
|
||||
pub fn push_user_text(&mut self, text: impl Into<String>) -> Result<(), SessionError> {
|
||||
self.push_message(ConversationMessage::user_text(text))
|
||||
self.push_message(&ConversationMessage::user_text(text))
|
||||
}
|
||||
|
||||
pub fn record_compaction(&mut self, summary: impl Into<String>, removed_message_count: usize) {
|
||||
|
|
@ -270,8 +269,7 @@ impl Session {
|
|||
let session_id = object
|
||||
.get("session_id")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(ToOwned::to_owned)
|
||||
.unwrap_or_else(generate_session_id);
|
||||
.map_or_else(generate_session_id, ToOwned::to_owned);
|
||||
let created_at_ms = object
|
||||
.get("created_at_ms")
|
||||
.map(|value| required_u64_from_value(value, "created_at_ms"))
|
||||
|
|
@ -813,10 +811,16 @@ fn normalize_optional_string(value: Option<String>) -> Option<String> {
|
|||
fn current_time_millis() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|duration| duration.as_millis() as u64)
|
||||
.map(|duration| u64::try_from(duration.as_millis()).unwrap_or(u64::MAX))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn has_jsonl_extension(path: &Path) -> bool {
|
||||
path.extension()
|
||||
.and_then(|value| value.to_str())
|
||||
.is_some_and(|extension| extension.eq_ignore_ascii_case("jsonl"))
|
||||
}
|
||||
|
||||
fn generate_session_id() -> String {
|
||||
let millis = current_time_millis();
|
||||
let counter = SESSION_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
|
|
@ -881,7 +885,8 @@ fn cleanup_rotated_logs(path: &Path) -> Result<(), SessionError> {
|
|||
entry_path
|
||||
.file_name()
|
||||
.and_then(|value| value.to_str())
|
||||
.is_some_and(|name| name.starts_with(&prefix) && name.ends_with(".jsonl"))
|
||||
.is_some_and(|name| name.starts_with(&prefix))
|
||||
&& has_jsonl_extension(entry_path)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
|
|
@ -907,7 +912,7 @@ mod tests {
|
|||
use crate::json::JsonValue;
|
||||
use crate::usage::TokenUsage;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
#[test]
|
||||
|
|
@ -917,7 +922,7 @@ mod tests {
|
|||
.push_user_text("hello")
|
||||
.expect("user message should append");
|
||||
session
|
||||
.push_message(ConversationMessage::assistant_with_usage(
|
||||
.push_message(&ConversationMessage::assistant_with_usage(
|
||||
vec![
|
||||
ContentBlock::Text {
|
||||
text: "thinking".to_string(),
|
||||
|
|
@ -937,7 +942,7 @@ mod tests {
|
|||
))
|
||||
.expect("assistant message should append");
|
||||
session
|
||||
.push_message(ConversationMessage::tool_result(
|
||||
.push_message(&ConversationMessage::tool_result(
|
||||
"tool-1", "bash", "hi", false,
|
||||
))
|
||||
.expect("tool result should append");
|
||||
|
|
@ -994,7 +999,7 @@ mod tests {
|
|||
.push_user_text("hi")
|
||||
.expect("user append should succeed");
|
||||
session
|
||||
.push_message(ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
.push_message(&ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "hello".to_string(),
|
||||
}]))
|
||||
.expect("assistant append should succeed");
|
||||
|
|
@ -1057,8 +1062,14 @@ mod tests {
|
|||
#[test]
|
||||
fn rotates_and_cleans_up_large_session_logs() {
|
||||
let path = temp_session_path("rotation");
|
||||
fs::write(&path, "x".repeat((super::ROTATE_AFTER_BYTES + 10) as usize))
|
||||
.expect("oversized file should write");
|
||||
fs::write(
|
||||
&path,
|
||||
"x".repeat(
|
||||
usize::try_from(super::ROTATE_AFTER_BYTES + 10)
|
||||
.expect("rotation threshold should fit usize"),
|
||||
),
|
||||
)
|
||||
.expect("oversized file should write");
|
||||
rotate_session_file_if_needed(&path).expect("rotation should succeed");
|
||||
assert!(
|
||||
!path.exists(),
|
||||
|
|
@ -1086,7 +1097,7 @@ mod tests {
|
|||
std::env::temp_dir().join(format!("runtime-session-{label}-{nanos}.json"))
|
||||
}
|
||||
|
||||
fn rotation_files(path: &PathBuf) -> Vec<PathBuf> {
|
||||
fn rotation_files(path: &Path) -> Vec<PathBuf> {
|
||||
let stem = path
|
||||
.file_stem()
|
||||
.and_then(|value| value.to_str())
|
||||
|
|
@ -1100,9 +1111,8 @@ mod tests {
|
|||
entry_path
|
||||
.file_name()
|
||||
.and_then(|value| value.to_str())
|
||||
.is_some_and(|name| {
|
||||
name.starts_with(&format!("{stem}.rot-")) && name.ends_with(".jsonl")
|
||||
})
|
||||
.is_some_and(|name| name.starts_with(&format!("{stem}.rot-")))
|
||||
&& super::has_jsonl_extension(entry_path)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2058,7 +2058,17 @@ fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, Box<dyn std::er
|
|||
.map(|duration| duration.as_secs())
|
||||
.unwrap_or_default();
|
||||
let (id, message_count, parent_session_id, branch_name) = Session::load_from_path(&path)
|
||||
.map(|session| {
|
||||
.map_or_else(|_| {
|
||||
(
|
||||
path.file_stem()
|
||||
.and_then(|value| value.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string(),
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}, |session| {
|
||||
let parent_session_id = session
|
||||
.fork
|
||||
.as_ref()
|
||||
|
|
@ -2073,17 +2083,6 @@ fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, Box<dyn std::er
|
|||
parent_session_id,
|
||||
branch_name,
|
||||
)
|
||||
})
|
||||
.unwrap_or_else(|_| {
|
||||
(
|
||||
path.file_stem()
|
||||
.and_then(|value| value.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string(),
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
});
|
||||
sessions.push(ManagedSessionSummary {
|
||||
id,
|
||||
|
|
|
|||
|
|
@ -1632,7 +1632,7 @@ fn build_agent_runtime(
|
|||
.clone()
|
||||
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
|
||||
let allowed_tools = job.allowed_tools.clone();
|
||||
let api_client = ProviderRuntimeClient::new(model, allowed_tools.clone())?;
|
||||
let api_client = ProviderRuntimeClient::new(&model, allowed_tools.clone())?;
|
||||
let tool_executor = SubagentToolExecutor::new(allowed_tools);
|
||||
Ok(ConversationRuntime::new(
|
||||
Session::new(),
|
||||
|
|
@ -1809,8 +1809,8 @@ struct ProviderRuntimeClient {
|
|||
}
|
||||
|
||||
impl ProviderRuntimeClient {
|
||||
fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
|
||||
let model = resolve_model_alias(&model).to_string();
|
||||
fn new(model: &str, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
|
||||
let model = resolve_model_alias(model).clone();
|
||||
let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?;
|
||||
Ok(Self {
|
||||
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
|
||||
|
|
@ -1902,19 +1902,11 @@ impl ApiClient for ProviderRuntimeClient {
|
|||
|
||||
push_prompt_cache_record(&self.client, &mut events);
|
||||
|
||||
if !saw_stop
|
||||
&& events.iter().any(|event| {
|
||||
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
||||
|| matches!(event, AssistantEvent::ToolUse { .. })
|
||||
})
|
||||
{
|
||||
if should_append_message_stop(&events, saw_stop) {
|
||||
events.push(AssistantEvent::MessageStop);
|
||||
}
|
||||
|
||||
if events
|
||||
.iter()
|
||||
.any(|event| matches!(event, AssistantEvent::MessageStop))
|
||||
{
|
||||
if has_message_stop(&events) {
|
||||
return Ok(events);
|
||||
}
|
||||
|
||||
|
|
@ -1933,6 +1925,20 @@ impl ApiClient for ProviderRuntimeClient {
|
|||
}
|
||||
}
|
||||
|
||||
fn should_append_message_stop(events: &[AssistantEvent], saw_stop: bool) -> bool {
|
||||
!saw_stop
|
||||
&& events.iter().any(|event| {
|
||||
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
||||
|| matches!(event, AssistantEvent::ToolUse { .. })
|
||||
})
|
||||
}
|
||||
|
||||
fn has_message_stop(events: &[AssistantEvent]) -> bool {
|
||||
events
|
||||
.iter()
|
||||
.any(|event| matches!(event, AssistantEvent::MessageStop))
|
||||
}
|
||||
|
||||
struct SubagentToolExecutor {
|
||||
allowed_tools: BTreeSet<String>,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue