my_agent/session.py

227 lines
9.5 KiB
Python
Raw Normal View History

from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Generator, Iterator, List, Optional

from core_agent.compression import RollingContextCompressor
from core_agent.dispatch import TaskDispatcher
from core_agent.memory import SimpleMemoryStore
from core_agent.prompts import DEFAULT_SYSTEM_PROMPT, build_system_prompt
from core_agent.providers.base import AgentProvider, AssistantTurn, StreamEvent, ToolCall
from core_agent.skills import SkillStore
from core_agent.tools.builtin import build_default_registry
from core_agent.tools.registry import ToolContext, ToolRegistry
@dataclass(slots=True)
class ChatEvent:
type: str
delta: str = ""
tool_name: str = ""
tool_args: Dict[str, Any] = field(default_factory=dict)
tool_result: str = ""
turn: Optional[AssistantTurn] = None
final_response: str = ""
raw: Any = None
@dataclass(slots=True)
class AgentRunResult:
final_response: str
messages: List[Dict[str, Any]]
session: Dict[str, Any] = field(default_factory=dict)
class ConversationSession:
    """Multi-turn conversation state with bounded history and tool loops.

    Every request builds a message list from the system prompt, the rendered
    memory block, the (possibly compacted) recent history and the new user
    message. The session then alternates between the provider and the tool
    registry until the provider returns a turn without tool calls, or until
    ``max_iterations`` rounds have been used.
    """

    def __init__(
        self,
        *,
        provider: AgentProvider,
        workspace: str | Path,
        skill_dirs: Optional[List[str | Path]] = None,
        tool_registry: Optional[ToolRegistry] = None,
        dispatcher: Optional[TaskDispatcher] = None,
        system_prompt: Optional[str] = None,
        max_iterations: int = 12,
        max_history_turns: int = 5,
        memory_store: Optional[SimpleMemoryStore] = None,
        context_compressor: Optional[RollingContextCompressor] = None,
        auto_memory_threshold_tokens: int = 30000,
        active_skills: Optional[List[str]] = None,
    ) -> None:
        """Create a session bound to a provider and a workspace directory.

        Args:
            provider: Model backend; its ``stream_generate`` is called once per round.
            workspace: Workspace directory; stored as a resolved absolute path.
            skill_dirs: Directories handed to ``SkillStore``. Defaults to
                ``<workspace>/skills`` when omitted or empty.
            tool_registry: Tool registry; ``build_default_registry()`` when ``None``.
            dispatcher: Task dispatcher; a fresh ``TaskDispatcher`` when ``None``.
            system_prompt: Base system prompt; ``DEFAULT_SYSTEM_PROMPT`` when
                omitted or empty.
            max_iterations: Maximum provider rounds per request.
            max_history_turns: Controls how much stored history is replayed
                (see ``_recent_history``).
            memory_store: Memory store; defaults to a ``SimpleMemoryStore`` at
                ``<workspace>/.core_agent/memory.json`` when ``None``.
            context_compressor: History compressor; a fresh
                ``RollingContextCompressor`` when ``None``.
            auto_memory_threshold_tokens: Estimated token count at which older
                history is summarised into the memory store.
            active_skills: Initial value of ``session_state["active_skills"]``.
        """
        self.provider = provider
        self.workspace = Path(workspace).resolve()
        self.skill_store = SkillStore(skill_dirs or [self.workspace / "skills"])

        # Collaborators are compared against None rather than tested for
        # truthiness, so an explicitly injected object that happens to be
        # falsy (e.g. an empty container-like store or registry) is kept.
        self.tool_registry = tool_registry if tool_registry is not None else build_default_registry()
        self.dispatcher = dispatcher if dispatcher is not None else TaskDispatcher()
        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        self.max_iterations = max_iterations
        self.max_history_turns = max_history_turns
        self.memory_store = (
            memory_store
            if memory_store is not None
            else SimpleMemoryStore(self.workspace / ".core_agent" / "memory.json")
        )
        self.context_compressor = (
            context_compressor if context_compressor is not None else RollingContextCompressor()
        )
        self.auto_memory_threshold_tokens = auto_memory_threshold_tokens

        self.history: List[Dict[str, str]] = []
        self.last_messages: List[Dict[str, Any]] = []
        self._last_auto_memory_signature: str = ""
        self.session_state: Dict[str, Any] = {
            "active_skills": list(active_skills or []),
            "workspace": str(self.workspace),
        }

    def ask(self, user_message: str) -> AgentRunResult:
        """Run one request to completion and return its final result.

        Drains ``stream_ask`` and keeps only the last event.

        Raises:
            RuntimeError: If the stream produced no events or its last event
                is not of type ``"final"``.
        """
        final_event: Optional[ChatEvent] = None
        for event in self.stream_ask(user_message):
            final_event = event
        if final_event is None or final_event.type != "final":
            raise RuntimeError("Conversation ended without a final response")
        return AgentRunResult(
            final_response=final_event.final_response,
            messages=list(self.last_messages),
            session=dict(self.session_state),
        )

    def stream_ask(self, user_message: str) -> Iterator[ChatEvent]:
        """Run one request, yielding events as they happen.

        Each round yields ``"round_start"``, then the provider's streamed
        events. If the resulting turn has tool calls, every call is executed
        and reported with a ``"tool_call"`` / ``"tool_result"`` pair and a new
        round starts; otherwise a ``"final"`` event ends the stream. When
        ``max_iterations`` rounds pass without a tool-free turn, a ``"final"``
        event carrying an explanatory message is yielded instead.

        In both cases the user message and the final text are appended to
        ``history`` and ``last_messages`` is updated.
        """
        messages = self.build_messages(user_message)
        self.last_messages = messages

        for _ in range(self.max_iterations):
            yield ChatEvent(type="round_start")
            assistant_turn = yield from self._stream_turn(messages)
            messages.append(_assistant_message_to_dict(assistant_turn))

            if not assistant_turn.tool_calls:
                final_content = assistant_turn.content or ""
                self._append_history(user_message, final_content)
                self.last_messages = list(messages)
                yield ChatEvent(type="final", final_response=final_content, turn=assistant_turn)
                return

            ctx = self._tool_context()
            for call in assistant_turn.tool_calls:
                yield ChatEvent(
                    type="tool_call",
                    tool_name=call.name,
                    tool_args=call.arguments,
                    turn=assistant_turn,
                )
                result = self.tool_registry.execute(call.name, call.arguments, ctx)
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call.id,
                        "name": call.name,
                        "content": result,
                    }
                )
                yield ChatEvent(
                    type="tool_result",
                    tool_name=call.name,
                    tool_args=call.arguments,
                    tool_result=result,
                )

            # Tools receive the session state and the memory store through the
            # tool context and may have changed them, so the system message is
            # rebuilt before the next round. The memory block is included
            # again; refreshing with the bare system prompt would silently
            # drop the memory context that build_messages() put there.
            messages[0]["content"] = self._compose_system_content(self._render_memory_block())

        final_content = f"Agent exceeded max_iterations={self.max_iterations}"
        self._append_history(user_message, final_content)
        self.last_messages = list(messages)
        yield ChatEvent(type="final", final_response=final_content)

    def build_messages(self, user_message: str) -> List[Dict[str, Any]]:
        """Assemble the initial message list for a new user message.

        The list consists of the system message (system prompt plus memory
        block, if any), the compressor's summary message (if any), the
        compressor's tail messages, and finally the user message. Compression
        statistics are recorded under ``session_state["compression"]``.
        """
        recent_history = self._recent_history()
        self._maybe_consolidate_history_to_memory(recent_history, self._render_memory_block())

        # Rendered again because consolidation may just have added an entry.
        memory_block = self._render_memory_block()
        compression = self.context_compressor.compact(recent_history, memory_block=memory_block)

        messages: List[Dict[str, Any]] = [
            {"role": "system", "content": self._compose_system_content(memory_block)}
        ]
        if compression.summary_message:
            messages.append(compression.summary_message)
        messages.extend(compression.tail_messages)
        messages.append({"role": "user", "content": user_message})

        self.session_state["compression"] = {
            "did_compact": compression.did_compact,
            "estimated_tokens": compression.estimated_tokens,
            "has_summary": bool(compression.summary_message),
        }
        return messages

    def _recent_history(self) -> List[Dict[str, str]]:
        """Return the trailing window of ``history`` that is replayed to the model.

        The window is ``max_history_turns * 6`` entries long. A non-positive
        window yields an empty list; slicing with ``-0`` would otherwise
        return the entire history.
        """
        # NOTE(review): _append_history stores two entries per exchange, so
        # this window spans more than max_history_turns exchanges - confirm
        # that the factor of 6 is intentional.
        window = self.max_history_turns * 6
        if window <= 0:
            return []
        return self.history[-window:]

    def _render_memory_block(self) -> str:
        """Return the memory store's rendered context, or ``""`` without a store."""
        if self.memory_store is None:
            return ""
        return self.memory_store.render_context()

    def _compose_system_content(self, memory_block: str) -> str:
        """Return the system prompt, followed by ``memory_block`` when it is non-empty."""
        system_content = self._system_prompt()
        if memory_block:
            return f"{system_content}\n\n{memory_block}"
        return system_content

    def _maybe_consolidate_history_to_memory(self, history: List[Dict[str, str]], memory_block: str) -> None:
        """Summarise older history into the memory store once it grows large.

        Nothing happens without a memory store or history, while the token
        estimate is below ``auto_memory_threshold_tokens``, when the same
        history slice was already consolidated, or when the compressor
        produces no summary. The token estimate is always recorded under
        ``session_state["compression"]``; a successful consolidation is
        recorded under ``session_state["auto_memory"]``.
        """
        if self.memory_store is None or not history:
            return

        compressor = self.context_compressor
        estimated = compressor.estimate_tokens(history, memory_block, compressor.rolling_summary)
        self.session_state["compression"] = {
            "did_compact": False,
            "estimated_tokens": estimated,
            "has_summary": bool(compressor.rolling_summary),
        }
        if estimated < self.auto_memory_threshold_tokens:
            return

        # Leave the two most recent entries out of the summary when there is
        # anything older; otherwise summarise everything.
        older = history[:-2] or history
        signature = json.dumps(older, ensure_ascii=False, sort_keys=True)
        if signature == self._last_auto_memory_signature:
            return

        summary = compressor.build_memory_summary(older)
        if not summary:
            return

        entry = self.memory_store.add_if_new(summary, kind="memory")
        if entry is None:
            return
        self._last_auto_memory_signature = signature
        self.session_state["auto_memory"] = {
            "triggered": True,
            "entry_id": entry.id,
            "threshold_tokens": self.auto_memory_threshold_tokens,
        }

    def _stream_turn(self, messages: List[Dict[str, Any]]) -> Generator[ChatEvent, None, AssistantTurn]:
        """Stream one provider round and return the completed assistant turn.

        ``"reasoning"`` and ``"content"`` provider events are re-yielded as
        ``ChatEvent`` objects, and the turn of a ``"turn"`` event is kept as
        the generator's return value. Events of any other type are ignored.

        Raises:
            RuntimeError: If the provider stream ends without a ``"turn"`` event.
        """
        collected_turn: Optional[AssistantTurn] = None
        for event in self.provider.stream_generate(messages, self.tool_registry.definitions()):
            if event.type == "reasoning":
                yield ChatEvent(type="reasoning", delta=event.delta, raw=event.raw)
            elif event.type == "content":
                yield ChatEvent(type="content", delta=event.delta, raw=event.raw)
            elif event.type == "turn":
                collected_turn = event.turn
        if collected_turn is None:
            raise RuntimeError("Provider stream ended without a final turn")
        return collected_turn

    def _append_history(self, user_message: str, final_content: str) -> None:
        """Record a completed exchange as a user entry followed by an assistant entry."""
        self.history.append({"role": "user", "content": user_message})
        self.history.append({"role": "assistant", "content": final_content})

    def _system_prompt(self) -> str:
        """Build the system prompt from the base prompt, the active skills and the tools."""
        return build_system_prompt(
            skill_store=self.skill_store,
            active_skills=self.session_state["active_skills"],
            tool_registry=self.tool_registry,
            base_prompt=self.system_prompt,
        )

    def _tool_context(self) -> ToolContext:
        """Create the context object that is passed to every tool execution."""
        return ToolContext(
            workspace=self.workspace,
            skill_store=self.skill_store,
            dispatcher=self.dispatcher,
            memory_store=self.memory_store,
            session=self.session_state,
        )
def _assistant_message_to_dict(turn: AssistantTurn) -> Dict[str, Any]:
message: Dict[str, Any] = {"role": "assistant", "content": turn.content}
if turn.tool_calls:
message["tool_calls"] = [
{
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": json.dumps(call.arguments, ensure_ascii=False),
},
}
for call in turn.tool_calls
]
return message