227 lines
9.5 KiB
Python
227 lines
9.5 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Dict, Generator, Iterator, List, Optional
|
||
|
|
|
||
|
|
from core_agent.compression import RollingContextCompressor
|
||
|
|
from core_agent.dispatch import TaskDispatcher
|
||
|
|
from core_agent.memory import SimpleMemoryStore
|
||
|
|
from core_agent.prompts import DEFAULT_SYSTEM_PROMPT, build_system_prompt
|
||
|
|
from core_agent.providers.base import AgentProvider, AssistantTurn, StreamEvent, ToolCall
|
||
|
|
from core_agent.skills import SkillStore
|
||
|
|
from core_agent.tools.builtin import build_default_registry
|
||
|
|
from core_agent.tools.registry import ToolContext, ToolRegistry
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(slots=True)
|
||
|
|
class ChatEvent:
|
||
|
|
type: str
|
||
|
|
delta: str = ""
|
||
|
|
tool_name: str = ""
|
||
|
|
tool_args: Dict[str, Any] = field(default_factory=dict)
|
||
|
|
tool_result: str = ""
|
||
|
|
turn: Optional[AssistantTurn] = None
|
||
|
|
final_response: str = ""
|
||
|
|
raw: Any = None
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(slots=True)
|
||
|
|
class AgentRunResult:
|
||
|
|
final_response: str
|
||
|
|
messages: List[Dict[str, Any]]
|
||
|
|
session: Dict[str, Any] = field(default_factory=dict)
|
||
|
|
|
||
|
|
|
||
|
|
class ConversationSession:
    """Multi-turn conversation state with bounded history and tool loops.

    Each ``ask`` / ``stream_ask`` call builds a fresh message list (system
    prompt plus rendered memory, compacted recent history, the new user
    message). It then alternates provider turns with tool execution until the
    provider answers without tool calls or ``max_iterations`` rounds have run.
    Only the user message and the final answer are appended to ``history``;
    the full transcript including tool traffic is kept in ``last_messages``.
    """

    def __init__(
        self,
        *,
        provider: AgentProvider,
        workspace: str | Path,
        skill_dirs: Optional[List[str | Path]] = None,
        tool_registry: Optional[ToolRegistry] = None,
        dispatcher: Optional[TaskDispatcher] = None,
        system_prompt: Optional[str] = None,
        max_iterations: int = 12,
        max_history_turns: int = 5,
        memory_store: Optional[SimpleMemoryStore] = None,
        context_compressor: Optional[RollingContextCompressor] = None,
        auto_memory_threshold_tokens: int = 30000,
        active_skills: Optional[List[str]] = None,
    ) -> None:
        """Wire up the session and its collaborators.

        Args:
            provider: Backend whose ``stream_generate`` drives every turn.
            workspace: Root directory, resolved to an absolute path. It is
                also the base for the default skills dir and memory file.
            skill_dirs: Skill directories. When ``None`` or empty, defaults
                to ``<workspace>/skills``.
            tool_registry: Tools offered to the provider. Defaults to
                ``build_default_registry()``.
            dispatcher: Handed to tools through ``ToolContext``. Defaults to
                a new ``TaskDispatcher``.
            system_prompt: Base prompt text. When ``None`` or empty, falls
                back to ``DEFAULT_SYSTEM_PROMPT``.
            max_iterations: Maximum provider rounds per user message.
            max_history_turns: ``max_history_turns * 6`` of the most recent
                ``history`` messages are fed to the compressor. A value of
                0 or less feeds none.
            memory_store: Persistent memory. Defaults to a
                ``SimpleMemoryStore`` at ``<workspace>/.core_agent/memory.json``.
            context_compressor: Compacts recent history before each request.
                Defaults to a new ``RollingContextCompressor``.
            auto_memory_threshold_tokens: Estimated token count at or above
                which older history is summarised into ``memory_store``.
            active_skills: Skill names initially stored in
                ``session_state["active_skills"]``.
        """
        self.provider = provider
        self.workspace = Path(workspace).resolve()
        self.skill_store = SkillStore(skill_dirs or [self.workspace / "skills"])
        # Explicit ``is None`` checks: an injected collaborator that happens to
        # be falsy (e.g. an empty container-like store) must not be replaced.
        self.tool_registry = build_default_registry() if tool_registry is None else tool_registry
        self.dispatcher = TaskDispatcher() if dispatcher is None else dispatcher
        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        self.max_iterations = max_iterations
        self.max_history_turns = max_history_turns
        self.memory_store = (
            SimpleMemoryStore(self.workspace / ".core_agent" / "memory.json")
            if memory_store is None
            else memory_store
        )
        self.context_compressor = (
            RollingContextCompressor() if context_compressor is None else context_compressor
        )
        self.auto_memory_threshold_tokens = auto_memory_threshold_tokens
        # Completed exchanges, two messages (user, assistant) per answered ask.
        self.history: List[Dict[str, str]] = []
        # Transcript of the most recent stream_ask call, tool messages included.
        self.last_messages: List[Dict[str, Any]] = []
        # JSON signature of the history slice last auto-saved to memory.
        self._last_auto_memory_signature: str = ""
        self.session_state: Dict[str, Any] = {
            "active_skills": list(active_skills or []),
            "workspace": str(self.workspace),
        }

    def ask(self, user_message: str) -> AgentRunResult:
        """Drive ``stream_ask`` to completion and return its outcome.

        Raises:
            RuntimeError: If the stream's last event is missing or is not a
                ``final`` event.
        """
        final_event: Optional[ChatEvent] = None
        for event in self.stream_ask(user_message):
            final_event = event
        if final_event is None or final_event.type != "final":
            raise RuntimeError("Conversation ended without a final response")
        return AgentRunResult(
            final_response=final_event.final_response,
            messages=list(self.last_messages),
            session=dict(self.session_state),
        )

    def stream_ask(self, user_message: str) -> Iterator[ChatEvent]:
        """Answer ``user_message`` through the provider/tool loop, yielding events.

        Each round yields ``round_start``, the provider's ``reasoning`` /
        ``content`` deltas and, if tools were requested, one ``tool_call`` plus
        one ``tool_result`` event per call. On normal completion the stream
        ends with a single ``final`` event. That event carries either the
        provider's tool-free answer or a ``max_iterations`` notice; in both
        cases the exchange is appended to ``history``.
        """
        messages = self.build_messages(user_message)
        # Same list object as ``messages``, so it reflects appends while streaming.
        self.last_messages = messages

        for _ in range(self.max_iterations):
            yield ChatEvent(type="round_start")
            assistant_turn = yield from self._stream_turn(messages)
            messages.append(_assistant_message_to_dict(assistant_turn))

            if not assistant_turn.tool_calls:
                final_content = assistant_turn.content or ""
                self._append_history(user_message, final_content)
                self.last_messages = list(messages)
                yield ChatEvent(type="final", final_response=final_content, turn=assistant_turn)
                return

            ctx = self._tool_context()
            for call in assistant_turn.tool_calls:
                yield ChatEvent(type="tool_call", tool_name=call.name, tool_args=call.arguments, turn=assistant_turn)
                result = self.tool_registry.execute(call.name, call.arguments, ctx)
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call.id,
                        "name": call.name,
                        "content": result,
                    }
                )
                yield ChatEvent(type="tool_result", tool_name=call.name, tool_args=call.arguments, tool_result=result)

            # Tools receive the live session_state and memory_store, so the
            # system message is rebuilt before the next round. Composing it the
            # same way build_messages does keeps the memory block in place
            # instead of reverting to the bare prompt.
            messages[0]["content"] = self._compose_system_content(self._render_memory())

        final_content = f"Agent exceeded max_iterations={self.max_iterations}"
        self._append_history(user_message, final_content)
        self.last_messages = list(messages)
        yield ChatEvent(type="final", final_response=final_content)

    def build_messages(self, user_message: str) -> List[Dict[str, Any]]:
        """Assemble the provider request for a new user message.

        Order: one system message (prompt plus rendered memory, if any), the
        compressor's summary message (if any), its retained tail messages,
        then the user message. Compression stats are recorded in
        ``session_state["compression"]``. An auto-memory entry may be written
        as a side effect.
        """
        # Guard the non-positive case: ``history[-0:]`` would return the whole
        # history rather than none of it.
        window = self.max_history_turns * 6
        recent_history = self.history[-window:] if window > 0 else []

        self._maybe_consolidate_history_to_memory(recent_history, self._render_memory())
        # Render again: consolidation may have just added a memory entry.
        memory_block = self._render_memory()
        compression = self.context_compressor.compact(recent_history, memory_block=memory_block)

        messages: List[Dict[str, Any]] = [
            {"role": "system", "content": self._compose_system_content(memory_block)}
        ]
        if compression.summary_message:
            messages.append(compression.summary_message)
        messages.extend(compression.tail_messages)
        messages.append({"role": "user", "content": user_message})

        self.session_state["compression"] = {
            "did_compact": compression.did_compact,
            "estimated_tokens": compression.estimated_tokens,
            "has_summary": bool(compression.summary_message),
        }
        return messages

    def _render_memory(self) -> str:
        """Return the memory store's rendered context, or "" when there is no store."""
        if self.memory_store is None:
            return ""
        return self.memory_store.render_context()

    def _compose_system_content(self, memory_block: str) -> str:
        """Return the system prompt, with ``memory_block`` appended when non-empty."""
        system_content = self._system_prompt()
        if memory_block:
            system_content = f"{system_content}\n\n{memory_block}"
        return system_content

    def _maybe_consolidate_history_to_memory(self, history: List[Dict[str, str]], memory_block: str) -> None:
        """Summarise older history into ``memory_store`` once it grows large.

        First it estimates tokens for ``history``, ``memory_block`` and the
        compressor's rolling summary, recording the estimate in
        ``session_state["compression"]``. If the estimate reaches
        ``auto_memory_threshold_tokens``, it summarises ``history`` without
        its last two messages (or all of ``history`` if that leaves nothing)
        and stores the summary via ``add_if_new``. A JSON signature of that
        slice prevents re-summarising unchanged history.
        """
        if self.memory_store is None or not history:
            return

        compressor = self.context_compressor
        estimated = compressor.estimate_tokens(history, memory_block, compressor.rolling_summary)
        self.session_state["compression"] = {
            "did_compact": False,
            "estimated_tokens": estimated,
            "has_summary": bool(compressor.rolling_summary),
        }
        if estimated < self.auto_memory_threshold_tokens:
            return

        # Keep the newest exchange out of the summary; fall back to the whole
        # slice when there is nothing older.
        older = history[:-2] or history
        signature = json.dumps(older, ensure_ascii=False, sort_keys=True)
        if signature == self._last_auto_memory_signature:
            return

        summary = compressor.build_memory_summary(older)
        if not summary:
            return
        entry = self.memory_store.add_if_new(summary, kind="memory")
        if entry is not None:
            self._last_auto_memory_signature = signature
            self.session_state["auto_memory"] = {
                "triggered": True,
                "entry_id": entry.id,
                "threshold_tokens": self.auto_memory_threshold_tokens,
            }

    def _stream_turn(self, messages: List[Dict[str, Any]]) -> Generator[ChatEvent, None, AssistantTurn]:
        """Stream one provider call, forwarding deltas and returning the final turn.

        ``reasoning`` and ``content`` provider events are re-yielded as
        ``ChatEvent``s. The payload of the last ``turn`` event becomes the
        generator's return value, which callers receive via ``yield from``.
        Other event types are ignored.

        Raises:
            RuntimeError: If no ``turn`` event carrying a turn was received.
        """
        collected_turn: Optional[AssistantTurn] = None
        for event in self.provider.stream_generate(messages, self.tool_registry.definitions()):
            if event.type == "reasoning":
                yield ChatEvent(type="reasoning", delta=event.delta, raw=event.raw)
            elif event.type == "content":
                yield ChatEvent(type="content", delta=event.delta, raw=event.raw)
            elif event.type == "turn":
                collected_turn = event.turn
        if collected_turn is None:
            raise RuntimeError("Provider stream ended without a final turn")
        return collected_turn

    def _append_history(self, user_message: str, final_content: str) -> None:
        """Record one completed exchange as a user message followed by an assistant message."""
        self.history.append({"role": "user", "content": user_message})
        self.history.append({"role": "assistant", "content": final_content})

    def _system_prompt(self) -> str:
        """Build the base system prompt.

        Delegates to ``build_system_prompt`` with the skill store, the
        currently active skills, the tool registry and the base prompt.
        """
        return build_system_prompt(
            skill_store=self.skill_store,
            active_skills=self.session_state["active_skills"],
            tool_registry=self.tool_registry,
            base_prompt=self.system_prompt,
        )

    def _tool_context(self) -> ToolContext:
        """Build the context handed to tools.

        ``session`` is the live ``session_state`` dict, not a copy, so tool
        mutations are visible to the session.
        """
        return ToolContext(
            workspace=self.workspace,
            skill_store=self.skill_store,
            dispatcher=self.dispatcher,
            memory_store=self.memory_store,
            session=self.session_state,
        )
|
||
|
|
|
||
|
|
|
||
|
|
def _assistant_message_to_dict(turn: AssistantTurn) -> Dict[str, Any]:
|
||
|
|
message: Dict[str, Any] = {"role": "assistant", "content": turn.content}
|
||
|
|
if turn.tool_calls:
|
||
|
|
message["tool_calls"] = [
|
||
|
|
{
|
||
|
|
"id": call.id,
|
||
|
|
"type": "function",
|
||
|
|
"function": {
|
||
|
|
"name": call.name,
|
||
|
|
"arguments": json.dumps(call.arguments, ensure_ascii=False),
|
||
|
|
},
|
||
|
|
}
|
||
|
|
for call in turn.tool_calls
|
||
|
|
]
|
||
|
|
return message
|