Compare commits
16 commits
main
...
fix/clippy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9a571ae57 | ||
|
|
29b99b21f7 | ||
|
|
0d2cd3229e | ||
|
|
6a01aa52e6 | ||
|
|
1e8e152148 | ||
|
|
7976bf994c | ||
|
|
d9c1883879 | ||
|
|
722eee3ea5 | ||
|
|
23ba21bdd5 | ||
|
|
464a870180 | ||
|
|
2959cd1e51 | ||
|
|
86bc510722 | ||
|
|
920842fffe | ||
|
|
ec5d3a546b | ||
|
|
9cfbb8ceed | ||
|
|
ace763eede |
7 changed files with 109 additions and 86 deletions
|
|
@ -21,7 +21,7 @@ async fn stream_via_provider<P: Provider>(
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ProviderClient {
|
pub enum ProviderClient {
|
||||||
Anthropic(AnthropicClient),
|
Anthropic(Box<AnthropicClient>),
|
||||||
Xai(OpenAiCompatClient),
|
Xai(OpenAiCompatClient),
|
||||||
OpenAi(OpenAiCompatClient),
|
OpenAi(OpenAiCompatClient),
|
||||||
}
|
}
|
||||||
|
|
@ -37,10 +37,10 @@ impl ProviderClient {
|
||||||
) -> Result<Self, ApiError> {
|
) -> Result<Self, ApiError> {
|
||||||
let resolved_model = providers::resolve_model_alias(model);
|
let resolved_model = providers::resolve_model_alias(model);
|
||||||
match providers::detect_provider_kind(&resolved_model) {
|
match providers::detect_provider_kind(&resolved_model) {
|
||||||
ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth {
|
ProviderKind::Anthropic => Ok(Self::Anthropic(Box::new(match anthropic_auth {
|
||||||
Some(auth) => AnthropicClient::from_auth(auth),
|
Some(auth) => AnthropicClient::from_auth(auth),
|
||||||
None => AnthropicClient::from_env()?,
|
None => AnthropicClient::from_env()?,
|
||||||
})),
|
}))),
|
||||||
ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env(
|
ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env(
|
||||||
OpenAiCompatConfig::xai(),
|
OpenAiCompatConfig::xai(),
|
||||||
)?)),
|
)?)),
|
||||||
|
|
@ -62,7 +62,9 @@ impl ProviderClient {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn with_prompt_cache(self, prompt_cache: PromptCache) -> Self {
|
pub fn with_prompt_cache(self, prompt_cache: PromptCache) -> Self {
|
||||||
match self {
|
match self {
|
||||||
Self::Anthropic(client) => Self::Anthropic(client.with_prompt_cache(prompt_cache)),
|
Self::Anthropic(client) => {
|
||||||
|
Self::Anthropic(Box::new((*client).with_prompt_cache(prompt_cache)))
|
||||||
|
}
|
||||||
other => other,
|
other => other,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -88,7 +90,7 @@ impl ProviderClient {
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
) -> Result<MessageResponse, ApiError> {
|
) -> Result<MessageResponse, ApiError> {
|
||||||
match self {
|
match self {
|
||||||
Self::Anthropic(client) => send_via_provider(client, request).await,
|
Self::Anthropic(client) => send_via_provider(client.as_ref(), request).await,
|
||||||
Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await,
|
Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -98,7 +100,7 @@ impl ProviderClient {
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
) -> Result<MessageStream, ApiError> {
|
) -> Result<MessageStream, ApiError> {
|
||||||
match self {
|
match self {
|
||||||
Self::Anthropic(client) => stream_via_provider(client, request)
|
Self::Anthropic(client) => stream_via_provider(client.as_ref(), request)
|
||||||
.await
|
.await
|
||||||
.map(MessageStream::Anthropic),
|
.map(MessageStream::Anthropic),
|
||||||
Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request)
|
Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request)
|
||||||
|
|
|
||||||
|
|
@ -251,7 +251,7 @@ impl MessageStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.done {
|
if self.done {
|
||||||
self.pending.extend(self.state.finish()?);
|
self.pending.extend(self.state.finish());
|
||||||
if let Some(event) = self.pending.pop_front() {
|
if let Some(event) = self.pending.pop_front() {
|
||||||
return Ok(Some(event));
|
return Ok(Some(event));
|
||||||
}
|
}
|
||||||
|
|
@ -261,7 +261,7 @@ impl MessageStream {
|
||||||
match self.response.chunk().await? {
|
match self.response.chunk().await? {
|
||||||
Some(chunk) => {
|
Some(chunk) => {
|
||||||
for parsed in self.parser.push(&chunk)? {
|
for parsed in self.parser.push(&chunk)? {
|
||||||
self.pending.extend(self.state.ingest_chunk(parsed)?);
|
self.pending.extend(self.state.ingest_chunk(parsed));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
|
|
@ -299,33 +299,41 @@ impl OpenAiSseParser {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct StreamState {
|
struct StreamState {
|
||||||
model: String,
|
model: String,
|
||||||
message_started: bool,
|
message: MessageState,
|
||||||
text_started: bool,
|
text: TextState,
|
||||||
text_finished: bool,
|
|
||||||
finished: bool,
|
|
||||||
stop_reason: Option<String>,
|
stop_reason: Option<String>,
|
||||||
usage: Option<Usage>,
|
usage: Option<Usage>,
|
||||||
tool_calls: BTreeMap<u32, ToolCallState>,
|
tool_calls: BTreeMap<u32, ToolCallState>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct MessageState {
|
||||||
|
started: bool,
|
||||||
|
finished: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct TextState {
|
||||||
|
started: bool,
|
||||||
|
finished: bool,
|
||||||
|
}
|
||||||
|
|
||||||
impl StreamState {
|
impl StreamState {
|
||||||
fn new(model: String) -> Self {
|
fn new(model: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
message_started: false,
|
message: MessageState::default(),
|
||||||
text_started: false,
|
text: TextState::default(),
|
||||||
text_finished: false,
|
|
||||||
finished: false,
|
|
||||||
stop_reason: None,
|
stop_reason: None,
|
||||||
usage: None,
|
usage: None,
|
||||||
tool_calls: BTreeMap::new(),
|
tool_calls: BTreeMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result<Vec<StreamEvent>, ApiError> {
|
fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Vec<StreamEvent> {
|
||||||
let mut events = Vec::new();
|
let mut events = Vec::new();
|
||||||
if !self.message_started {
|
if !self.message.started {
|
||||||
self.message_started = true;
|
self.message.started = true;
|
||||||
events.push(StreamEvent::MessageStart(MessageStartEvent {
|
events.push(StreamEvent::MessageStart(MessageStartEvent {
|
||||||
message: MessageResponse {
|
message: MessageResponse {
|
||||||
id: chunk.id.clone(),
|
id: chunk.id.clone(),
|
||||||
|
|
@ -357,8 +365,8 @@ impl StreamState {
|
||||||
|
|
||||||
for choice in chunk.choices {
|
for choice in chunk.choices {
|
||||||
if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
|
if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
|
||||||
if !self.text_started {
|
if !self.text.started {
|
||||||
self.text_started = true;
|
self.text.started = true;
|
||||||
events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||||
index: 0,
|
index: 0,
|
||||||
content_block: OutputContentBlock::Text {
|
content_block: OutputContentBlock::Text {
|
||||||
|
|
@ -377,7 +385,7 @@ impl StreamState {
|
||||||
state.apply(tool_call);
|
state.apply(tool_call);
|
||||||
let block_index = state.block_index();
|
let block_index = state.block_index();
|
||||||
if !state.started {
|
if !state.started {
|
||||||
if let Some(start_event) = state.start_event()? {
|
if let Some(start_event) = state.start_event() {
|
||||||
state.started = true;
|
state.started = true;
|
||||||
events.push(StreamEvent::ContentBlockStart(start_event));
|
events.push(StreamEvent::ContentBlockStart(start_event));
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -410,18 +418,18 @@ impl StreamState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(events)
|
events
|
||||||
}
|
}
|
||||||
|
|
||||||
fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
|
fn finish(&mut self) -> Vec<StreamEvent> {
|
||||||
if self.finished {
|
if self.message.finished {
|
||||||
return Ok(Vec::new());
|
return Vec::new();
|
||||||
}
|
}
|
||||||
self.finished = true;
|
self.message.finished = true;
|
||||||
|
|
||||||
let mut events = Vec::new();
|
let mut events = Vec::new();
|
||||||
if self.text_started && !self.text_finished {
|
if self.text.started && !self.text.finished {
|
||||||
self.text_finished = true;
|
self.text.finished = true;
|
||||||
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
|
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
|
||||||
index: 0,
|
index: 0,
|
||||||
}));
|
}));
|
||||||
|
|
@ -429,7 +437,7 @@ impl StreamState {
|
||||||
|
|
||||||
for state in self.tool_calls.values_mut() {
|
for state in self.tool_calls.values_mut() {
|
||||||
if !state.started {
|
if !state.started {
|
||||||
if let Some(start_event) = state.start_event()? {
|
if let Some(start_event) = state.start_event() {
|
||||||
state.started = true;
|
state.started = true;
|
||||||
events.push(StreamEvent::ContentBlockStart(start_event));
|
events.push(StreamEvent::ContentBlockStart(start_event));
|
||||||
if let Some(delta_event) = state.delta_event() {
|
if let Some(delta_event) = state.delta_event() {
|
||||||
|
|
@ -445,7 +453,7 @@ impl StreamState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.message_started {
|
if self.message.started {
|
||||||
events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
|
events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
|
||||||
delta: MessageDelta {
|
delta: MessageDelta {
|
||||||
stop_reason: Some(
|
stop_reason: Some(
|
||||||
|
|
@ -464,7 +472,7 @@ impl StreamState {
|
||||||
}));
|
}));
|
||||||
events.push(StreamEvent::MessageStop(MessageStopEvent {}));
|
events.push(StreamEvent::MessageStop(MessageStopEvent {}));
|
||||||
}
|
}
|
||||||
Ok(events)
|
events
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -497,22 +505,20 @@ impl ToolCallState {
|
||||||
self.openai_index + 1
|
self.openai_index + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_event(&self) -> Result<Option<ContentBlockStartEvent>, ApiError> {
|
fn start_event(&self) -> Option<ContentBlockStartEvent> {
|
||||||
let Some(name) = self.name.clone() else {
|
let name = self.name.clone()?;
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
let id = self
|
let id = self
|
||||||
.id
|
.id
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
|
.unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
|
||||||
Ok(Some(ContentBlockStartEvent {
|
Some(ContentBlockStartEvent {
|
||||||
index: self.block_index(),
|
index: self.block_index(),
|
||||||
content_block: OutputContentBlock::ToolUse {
|
content_block: OutputContentBlock::ToolUse {
|
||||||
id,
|
id,
|
||||||
name,
|
name,
|
||||||
input: json!({}),
|
input: json!({}),
|
||||||
},
|
},
|
||||||
}))
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
|
fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
|
||||||
|
|
|
||||||
|
|
@ -407,7 +407,7 @@ async fn provider_client_dispatches_anthropic_requests() {
|
||||||
.expect("anthropic provider client should be constructed");
|
.expect("anthropic provider client should be constructed");
|
||||||
let client = match client {
|
let client = match client {
|
||||||
ProviderClient::Anthropic(client) => {
|
ProviderClient::Anthropic(client) => {
|
||||||
ProviderClient::Anthropic(client.with_base_url(server.base_url()))
|
ProviderClient::Anthropic(Box::new((*client).with_base_url(server.base_url())))
|
||||||
}
|
}
|
||||||
other => panic!("expected anthropic provider, got {other:?}"),
|
other => panic!("expected anthropic provider, got {other:?}"),
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,7 @@ where
|
||||||
system_prompt,
|
system_prompt,
|
||||||
max_iterations: usize::MAX,
|
max_iterations: usize::MAX,
|
||||||
usage_tracker,
|
usage_tracker,
|
||||||
hook_runner: HookRunner::from_feature_config(&feature_config),
|
hook_runner: HookRunner::from_feature_config(feature_config),
|
||||||
auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
|
auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
|
||||||
hook_abort_signal: HookAbortSignal::default(),
|
hook_abort_signal: HookAbortSignal::default(),
|
||||||
hook_progress_reporter: None,
|
hook_progress_reporter: None,
|
||||||
|
|
@ -349,7 +349,7 @@ where
|
||||||
);
|
);
|
||||||
|
|
||||||
self.session
|
self.session
|
||||||
.push_message(assistant_message.clone())
|
.push_message(&assistant_message)
|
||||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||||
assistant_messages.push(assistant_message);
|
assistant_messages.push(assistant_message);
|
||||||
|
|
||||||
|
|
@ -440,7 +440,7 @@ where
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
self.session
|
self.session
|
||||||
.push_message(result_message.clone())
|
.push_message(&result_message)
|
||||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||||
self.record_tool_finished(iterations, &result_message);
|
self.record_tool_finished(iterations, &result_message);
|
||||||
tool_results.push(result_message);
|
tool_results.push(result_message);
|
||||||
|
|
|
||||||
|
|
@ -168,20 +168,19 @@ impl Session {
|
||||||
{
|
{
|
||||||
Self::from_json(&value)?
|
Self::from_json(&value)?
|
||||||
}
|
}
|
||||||
Err(_) => Self::from_jsonl(&contents)?,
|
Err(_) | Ok(_) => Self::from_jsonl(&contents)?,
|
||||||
Ok(_) => Self::from_jsonl(&contents)?,
|
|
||||||
};
|
};
|
||||||
Ok(session.with_persistence_path(path.to_path_buf()))
|
Ok(session.with_persistence_path(path.to_path_buf()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push_message(&mut self, message: ConversationMessage) -> Result<(), SessionError> {
|
pub fn push_message(&mut self, message: &ConversationMessage) -> Result<(), SessionError> {
|
||||||
self.touch();
|
self.touch();
|
||||||
self.messages.push(message.clone());
|
self.messages.push(message.clone());
|
||||||
self.append_persisted_message(&message)
|
self.append_persisted_message(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push_user_text(&mut self, text: impl Into<String>) -> Result<(), SessionError> {
|
pub fn push_user_text(&mut self, text: impl Into<String>) -> Result<(), SessionError> {
|
||||||
self.push_message(ConversationMessage::user_text(text))
|
self.push_message(&ConversationMessage::user_text(text))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn record_compaction(&mut self, summary: impl Into<String>, removed_message_count: usize) {
|
pub fn record_compaction(&mut self, summary: impl Into<String>, removed_message_count: usize) {
|
||||||
|
|
@ -270,8 +269,7 @@ impl Session {
|
||||||
let session_id = object
|
let session_id = object
|
||||||
.get("session_id")
|
.get("session_id")
|
||||||
.and_then(JsonValue::as_str)
|
.and_then(JsonValue::as_str)
|
||||||
.map(ToOwned::to_owned)
|
.map_or_else(generate_session_id, ToOwned::to_owned);
|
||||||
.unwrap_or_else(generate_session_id);
|
|
||||||
let created_at_ms = object
|
let created_at_ms = object
|
||||||
.get("created_at_ms")
|
.get("created_at_ms")
|
||||||
.map(|value| required_u64_from_value(value, "created_at_ms"))
|
.map(|value| required_u64_from_value(value, "created_at_ms"))
|
||||||
|
|
@ -813,10 +811,16 @@ fn normalize_optional_string(value: Option<String>) -> Option<String> {
|
||||||
fn current_time_millis() -> u64 {
|
fn current_time_millis() -> u64 {
|
||||||
SystemTime::now()
|
SystemTime::now()
|
||||||
.duration_since(UNIX_EPOCH)
|
.duration_since(UNIX_EPOCH)
|
||||||
.map(|duration| duration.as_millis() as u64)
|
.map(|duration| u64::try_from(duration.as_millis()).unwrap_or(u64::MAX))
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn has_jsonl_extension(path: &Path) -> bool {
|
||||||
|
path.extension()
|
||||||
|
.and_then(|value| value.to_str())
|
||||||
|
.is_some_and(|extension| extension.eq_ignore_ascii_case("jsonl"))
|
||||||
|
}
|
||||||
|
|
||||||
fn generate_session_id() -> String {
|
fn generate_session_id() -> String {
|
||||||
let millis = current_time_millis();
|
let millis = current_time_millis();
|
||||||
let counter = SESSION_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
|
let counter = SESSION_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
@ -881,7 +885,8 @@ fn cleanup_rotated_logs(path: &Path) -> Result<(), SessionError> {
|
||||||
entry_path
|
entry_path
|
||||||
.file_name()
|
.file_name()
|
||||||
.and_then(|value| value.to_str())
|
.and_then(|value| value.to_str())
|
||||||
.is_some_and(|name| name.starts_with(&prefix) && name.ends_with(".jsonl"))
|
.is_some_and(|name| name.starts_with(&prefix))
|
||||||
|
&& has_jsonl_extension(entry_path)
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
|
@ -907,7 +912,7 @@ mod tests {
|
||||||
use crate::json::JsonValue;
|
use crate::json::JsonValue;
|
||||||
use crate::usage::TokenUsage;
|
use crate::usage::TokenUsage;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::PathBuf;
|
use std::path::{Path, PathBuf};
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -917,7 +922,7 @@ mod tests {
|
||||||
.push_user_text("hello")
|
.push_user_text("hello")
|
||||||
.expect("user message should append");
|
.expect("user message should append");
|
||||||
session
|
session
|
||||||
.push_message(ConversationMessage::assistant_with_usage(
|
.push_message(&ConversationMessage::assistant_with_usage(
|
||||||
vec![
|
vec![
|
||||||
ContentBlock::Text {
|
ContentBlock::Text {
|
||||||
text: "thinking".to_string(),
|
text: "thinking".to_string(),
|
||||||
|
|
@ -937,7 +942,7 @@ mod tests {
|
||||||
))
|
))
|
||||||
.expect("assistant message should append");
|
.expect("assistant message should append");
|
||||||
session
|
session
|
||||||
.push_message(ConversationMessage::tool_result(
|
.push_message(&ConversationMessage::tool_result(
|
||||||
"tool-1", "bash", "hi", false,
|
"tool-1", "bash", "hi", false,
|
||||||
))
|
))
|
||||||
.expect("tool result should append");
|
.expect("tool result should append");
|
||||||
|
|
@ -994,7 +999,7 @@ mod tests {
|
||||||
.push_user_text("hi")
|
.push_user_text("hi")
|
||||||
.expect("user append should succeed");
|
.expect("user append should succeed");
|
||||||
session
|
session
|
||||||
.push_message(ConversationMessage::assistant(vec![ContentBlock::Text {
|
.push_message(&ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||||
text: "hello".to_string(),
|
text: "hello".to_string(),
|
||||||
}]))
|
}]))
|
||||||
.expect("assistant append should succeed");
|
.expect("assistant append should succeed");
|
||||||
|
|
@ -1057,7 +1062,13 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn rotates_and_cleans_up_large_session_logs() {
|
fn rotates_and_cleans_up_large_session_logs() {
|
||||||
let path = temp_session_path("rotation");
|
let path = temp_session_path("rotation");
|
||||||
fs::write(&path, "x".repeat((super::ROTATE_AFTER_BYTES + 10) as usize))
|
fs::write(
|
||||||
|
&path,
|
||||||
|
"x".repeat(
|
||||||
|
usize::try_from(super::ROTATE_AFTER_BYTES + 10)
|
||||||
|
.expect("rotation threshold should fit usize"),
|
||||||
|
),
|
||||||
|
)
|
||||||
.expect("oversized file should write");
|
.expect("oversized file should write");
|
||||||
rotate_session_file_if_needed(&path).expect("rotation should succeed");
|
rotate_session_file_if_needed(&path).expect("rotation should succeed");
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -1086,7 +1097,7 @@ mod tests {
|
||||||
std::env::temp_dir().join(format!("runtime-session-{label}-{nanos}.json"))
|
std::env::temp_dir().join(format!("runtime-session-{label}-{nanos}.json"))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rotation_files(path: &PathBuf) -> Vec<PathBuf> {
|
fn rotation_files(path: &Path) -> Vec<PathBuf> {
|
||||||
let stem = path
|
let stem = path
|
||||||
.file_stem()
|
.file_stem()
|
||||||
.and_then(|value| value.to_str())
|
.and_then(|value| value.to_str())
|
||||||
|
|
@ -1100,9 +1111,8 @@ mod tests {
|
||||||
entry_path
|
entry_path
|
||||||
.file_name()
|
.file_name()
|
||||||
.and_then(|value| value.to_str())
|
.and_then(|value| value.to_str())
|
||||||
.is_some_and(|name| {
|
.is_some_and(|name| name.starts_with(&format!("{stem}.rot-")))
|
||||||
name.starts_with(&format!("{stem}.rot-")) && name.ends_with(".jsonl")
|
&& super::has_jsonl_extension(entry_path)
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2058,7 +2058,17 @@ fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, Box<dyn std::er
|
||||||
.map(|duration| duration.as_secs())
|
.map(|duration| duration.as_secs())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
let (id, message_count, parent_session_id, branch_name) = Session::load_from_path(&path)
|
let (id, message_count, parent_session_id, branch_name) = Session::load_from_path(&path)
|
||||||
.map(|session| {
|
.map_or_else(|_| {
|
||||||
|
(
|
||||||
|
path.file_stem()
|
||||||
|
.and_then(|value| value.to_str())
|
||||||
|
.unwrap_or("unknown")
|
||||||
|
.to_string(),
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
}, |session| {
|
||||||
let parent_session_id = session
|
let parent_session_id = session
|
||||||
.fork
|
.fork
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
@ -2073,17 +2083,6 @@ fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, Box<dyn std::er
|
||||||
parent_session_id,
|
parent_session_id,
|
||||||
branch_name,
|
branch_name,
|
||||||
)
|
)
|
||||||
})
|
|
||||||
.unwrap_or_else(|_| {
|
|
||||||
(
|
|
||||||
path.file_stem()
|
|
||||||
.and_then(|value| value.to_str())
|
|
||||||
.unwrap_or("unknown")
|
|
||||||
.to_string(),
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
});
|
});
|
||||||
sessions.push(ManagedSessionSummary {
|
sessions.push(ManagedSessionSummary {
|
||||||
id,
|
id,
|
||||||
|
|
|
||||||
|
|
@ -1632,7 +1632,7 @@ fn build_agent_runtime(
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
|
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
|
||||||
let allowed_tools = job.allowed_tools.clone();
|
let allowed_tools = job.allowed_tools.clone();
|
||||||
let api_client = ProviderRuntimeClient::new(model, allowed_tools.clone())?;
|
let api_client = ProviderRuntimeClient::new(&model, allowed_tools.clone())?;
|
||||||
let tool_executor = SubagentToolExecutor::new(allowed_tools);
|
let tool_executor = SubagentToolExecutor::new(allowed_tools);
|
||||||
Ok(ConversationRuntime::new(
|
Ok(ConversationRuntime::new(
|
||||||
Session::new(),
|
Session::new(),
|
||||||
|
|
@ -1809,8 +1809,8 @@ struct ProviderRuntimeClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProviderRuntimeClient {
|
impl ProviderRuntimeClient {
|
||||||
fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
|
fn new(model: &str, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
|
||||||
let model = resolve_model_alias(&model).to_string();
|
let model = resolve_model_alias(model).clone();
|
||||||
let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?;
|
let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
|
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
|
||||||
|
|
@ -1902,19 +1902,11 @@ impl ApiClient for ProviderRuntimeClient {
|
||||||
|
|
||||||
push_prompt_cache_record(&self.client, &mut events);
|
push_prompt_cache_record(&self.client, &mut events);
|
||||||
|
|
||||||
if !saw_stop
|
if should_append_message_stop(&events, saw_stop) {
|
||||||
&& events.iter().any(|event| {
|
|
||||||
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
|
||||||
|| matches!(event, AssistantEvent::ToolUse { .. })
|
|
||||||
})
|
|
||||||
{
|
|
||||||
events.push(AssistantEvent::MessageStop);
|
events.push(AssistantEvent::MessageStop);
|
||||||
}
|
}
|
||||||
|
|
||||||
if events
|
if has_message_stop(&events) {
|
||||||
.iter()
|
|
||||||
.any(|event| matches!(event, AssistantEvent::MessageStop))
|
|
||||||
{
|
|
||||||
return Ok(events);
|
return Ok(events);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1933,6 +1925,20 @@ impl ApiClient for ProviderRuntimeClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn should_append_message_stop(events: &[AssistantEvent], saw_stop: bool) -> bool {
|
||||||
|
!saw_stop
|
||||||
|
&& events.iter().any(|event| {
|
||||||
|
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
||||||
|
|| matches!(event, AssistantEvent::ToolUse { .. })
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_message_stop(events: &[AssistantEvent]) -> bool {
|
||||||
|
events
|
||||||
|
.iter()
|
||||||
|
.any(|event| matches!(event, AssistantEvent::MessageStop))
|
||||||
|
}
|
||||||
|
|
||||||
struct SubagentToolExecutor {
|
struct SubagentToolExecutor {
|
||||||
allowed_tools: BTreeSet<String>,
|
allowed_tools: BTreeSet<String>,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue