my_agent/compression.py

121 lines
4.6 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List
# Preamble prepended to the rolling-summary message so the compacted history
# is read as background reference rather than as fresh user input, and so the
# latest user message and the memory block take precedence over it.
SUMMARY_PREFIX = (
    "[CONTEXT COMPACTION - REFERENCE ONLY] "
    "Earlier turns were compacted into the summary below. "
    "Treat it as background reference, not as fresh user input. "
    "Prefer the latest user message and the authoritative memory block "
    "over this summary.\n\n"
)
@dataclass(slots=True)
class CompressionResult:
summary_message: Dict[str, str] | None
tail_messages: List[Dict[str, str]]
did_compact: bool
estimated_tokens: int
class RollingContextCompressor:
    """Rolling-summary context compressor for long multi-turn chats.

    While the estimated prompt size stays within ``max_input_tokens`` the
    history is passed through untouched. Once it exceeds the budget, every
    message older than the last ``keep_last_turns * 2`` messages is condensed
    by a rule-based summarizer (``_summarize_messages``) and merged into
    ``rolling_summary``, which persists on the instance across calls.

    NOTE(review): because the summary accumulates on the instance, callers
    presumably replace their history with the returned ``tail_messages``
    after a compaction; passing the full history again would summarize the
    same turns a second time — confirm against the caller.
    """

    def __init__(
        self,
        *,
        max_input_tokens: int = 12000,
        keep_last_turns: int = 3,
        summary_char_limit: int = 4000,
    ) -> None:
        """Configure the compressor.

        Args:
            max_input_tokens: Budget (as measured by ``estimate_tokens``) above
                which ``compact`` folds older messages into the summary.
            keep_last_turns: Number of trailing turns kept verbatim; each turn
                counts as two messages.
            summary_char_limit: Maximum length, in characters, of the stored
                rolling summary.
        """
        self.max_input_tokens = max_input_tokens
        self.keep_last_turns = keep_last_turns
        self.summary_char_limit = summary_char_limit
        # Accumulated summary of every message this instance has compacted so far.
        self.rolling_summary = ""

    def compact(self, history: List[Dict[str, str]], memory_block: str = "") -> CompressionResult:
        """Fold older messages into the rolling summary when over budget.

        Args:
            history: Chat messages, oldest first; ``role`` and ``content`` keys
                are read from each.
            memory_block: Extra prompt text counted toward the token budget.

        Returns:
            A ``CompressionResult``. If the estimate fits the budget, or nothing
            is older than the kept tail, ``did_compact`` is False and the current
            summary (if any) is returned unchanged. Otherwise the older messages
            are merged into ``rolling_summary`` and only the last
            ``keep_last_turns * 2`` messages are returned as the tail.
        """
        estimated_tokens = self.estimate_tokens(history, memory_block, self.rolling_summary)
        if estimated_tokens <= self.max_input_tokens:
            return CompressionResult(
                summary_message=self._summary_message() if self.rolling_summary else None,
                tail_messages=list(history),
                did_compact=False,
                estimated_tokens=estimated_tokens,
            )

        # Assumes one turn == one user message + one assistant reply (hence * 2).
        tail_count = max(0, self.keep_last_turns * 2)
        tail_messages = list(history[-tail_count:]) if tail_count else []
        middle_messages = list(history[:-tail_count]) if tail_count else list(history)
        if not middle_messages:
            # Over budget, but everything is inside the protected tail.
            return CompressionResult(
                summary_message=self._summary_message() if self.rolling_summary else None,
                tail_messages=tail_messages,
                did_compact=False,
                estimated_tokens=estimated_tokens,
            )

        fresh = self._summarize_messages(middle_messages)
        self.rolling_summary = self._fit_summary(self.rolling_summary, fresh)
        return CompressionResult(
            summary_message=self._summary_message(),
            tail_messages=tail_messages,
            did_compact=True,
            estimated_tokens=self.estimate_tokens(tail_messages, memory_block, self.rolling_summary),
        )

    def build_memory_summary(self, history: List[Dict[str, str]]) -> str:
        """Return a standalone summary of ``history`` for long-term memory.

        Returns an empty string when the summarizer yields nothing. This does
        not read or modify ``rolling_summary``.
        """
        body = self._summarize_messages(history)
        if not body:
            return ""
        return (
            "Auto-consolidated long-context summary.\n\n"
            f"{body}"
        ).strip()

    def estimate_tokens(self, history: List[Dict[str, str]], memory_block: str = "", summary: str = "") -> int:
        """Roughly estimate tokens as characters // 4 (minimum 1).

        Counts ``memory_block``, ``summary`` and each message rendered as
        ``"role: content"``. ``SUMMARY_PREFIX`` is not included in the count.
        """
        text = memory_block + summary + "\n".join(
            f"{item.get('role', '')}: {self._content_of(item)}" for item in history
        )
        return max(1, len(text) // 4)

    @staticmethod
    def _content_of(message: Dict[str, str]) -> str:
        """Return a message's ``content`` as text; a missing or None value becomes ``""``."""
        content = message.get("content")
        # str(None) would otherwise leak the literal text "None" into summaries.
        return "" if content is None else str(content)

    def _summary_message(self) -> Dict[str, str]:
        """Wrap the rolling summary in an assistant-role chat message."""
        return {"role": "assistant", "content": SUMMARY_PREFIX + self.rolling_summary}

    def _merge_summary(self, existing: str, fresh: str) -> str:
        """Combine a prior summary and a new one under markdown section headings."""
        if not existing:
            return fresh
        return (
            "## Prior Summary\n"
            f"{existing.strip()}\n\n"
            "## Newly Compacted Turns\n"
            f"{fresh.strip()}"
        )

    def _fit_summary(self, existing: str, fresh: str) -> str:
        """Merge ``fresh`` into ``existing`` without exceeding ``summary_char_limit``.

        When the merged text is too long, whole lines are dropped from the start
        of the prior summary (its oldest content) so the newly compacted turns
        are kept. If dropping the entire prior summary is still not enough, the
        fresh summary itself is cut at the limit.
        """
        limit = max(0, self.summary_char_limit)
        merged = self._merge_summary(existing, fresh)
        if len(merged) <= limit:
            return merged.strip()
        prior = existing.strip()
        overflow = len(merged) - limit
        # Cut at the first line break at/after the overflow point so at least
        # `overflow` characters go and no partial line remains; if there is no
        # such break, discard the prior summary entirely.
        cut = prior.find("\n", overflow)
        prior = prior[cut + 1:].strip() if cut != -1 else ""
        return self._merge_summary(prior, fresh)[:limit].strip()

    def _summarize_messages(self, messages: List[Dict[str, str]]) -> str:
        """Condense messages into a sectioned markdown digest.

        Each non-empty message is whitespace-collapsed and cut to 240 characters.
        The output has three sections: the latest user message as the active
        task, the last 12 per-message facts, and the latest assistant reply.
        Fallback text is used when a section would otherwise be empty.
        """
        facts: List[str] = []
        latest_user = ""
        latest_assistant = ""
        for message in messages:
            role = message.get("role", "")
            content = self._content_of(message).strip()
            if not content:
                continue
            compact = " ".join(content.split())
            if len(compact) > 240:
                compact = compact[:240] + "..."
            if role == "user":
                latest_user = compact
                facts.append(f"- User asked: {compact}")
            elif role == "assistant":
                latest_assistant = compact
                facts.append(f"- Assistant responded: {compact}")
            else:
                facts.append(f"- {role}: {compact}")
        lines = ["## Active Task", latest_user or "Continue the latest user request.", ""]
        lines.append("## Recent Progress")
        lines.extend(facts[-12:] or ["- No earlier progress captured."])
        lines.extend(["", "## Remaining Work", latest_assistant or "Use the latest visible conversation state to continue."])
        return "\n".join(lines).strip()