121 lines
4.6 KiB
Python
121 lines
4.6 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List
|
|
|
|
|
|
# Banner placed in front of the rolling summary text when it is sent as a
# message; it tells the reader to treat the summary as background reference.
SUMMARY_PREFIX = "".join(
    (
        "[CONTEXT COMPACTION - REFERENCE ONLY] ",
        "Earlier turns were compacted into the summary below. ",
        "Treat it as background reference, not as fresh user input. ",
        "Prefer the latest user message and the authoritative memory block ",
        "over this summary.\n\n",
    )
)
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class CompressionResult:
|
|
summary_message: Dict[str, str] | None
|
|
tail_messages: List[Dict[str, str]]
|
|
did_compact: bool
|
|
estimated_tokens: int
|
|
|
|
|
|
class RollingContextCompressor:
    """Rolling summary compressor for long multi-turn chats.

    When the estimated prompt size exceeds ``max_input_tokens``, the most
    recent ``keep_last_turns * 2`` messages are kept verbatim and everything
    older is folded into ``rolling_summary``, a heuristic text digest capped
    at ``summary_char_limit`` characters.
    """

    # Section headings used when an existing summary is merged with a new one.
    _PRIOR_HEADING = "## Prior Summary\n"
    _FRESH_HEADING = "## Newly Compacted Turns\n"
    # Per-message character cap applied before a message is recorded as a fact.
    _FACT_CHAR_LIMIT = 240
    # Number of most recent facts listed under "Recent Progress".
    _MAX_FACTS = 12

    def __init__(
        self,
        *,
        max_input_tokens: int = 12000,
        keep_last_turns: int = 3,
        summary_char_limit: int = 4000,
    ) -> None:
        """Store the compaction thresholds and start with an empty summary.

        Args:
            max_input_tokens: Estimated-token budget above which ``compact``
                starts folding older messages into the summary.
            keep_last_turns: Number of turns kept verbatim; two messages are
                kept per turn.
            summary_char_limit: Maximum length of ``rolling_summary``.
        """
        self.max_input_tokens = max_input_tokens
        self.keep_last_turns = keep_last_turns
        self.summary_char_limit = summary_char_limit
        self.rolling_summary = ""

    def compact(self, history: List[Dict[str, str]], memory_block: str = "") -> CompressionResult:
        """Compact ``history`` if its estimated size exceeds the token budget.

        Updates ``self.rolling_summary`` only when older messages are actually
        folded in.

        Args:
            history: Chat messages, oldest first.
            memory_block: Extra text counted towards the size estimate.

        Returns:
            A ``CompressionResult``. ``did_compact`` is False when the history
            already fits or when there is nothing older than the kept tail.
        """
        estimated_tokens = self.estimate_tokens(history, memory_block, self.rolling_summary)
        if estimated_tokens <= self.max_input_tokens:
            return CompressionResult(
                summary_message=self._summary_message() if self.rolling_summary else None,
                tail_messages=list(history),
                did_compact=False,
                estimated_tokens=estimated_tokens,
            )

        tail_count = max(0, self.keep_last_turns * 2)
        # ``history[-0:]`` would be the whole list, hence the explicit guards.
        tail_messages = list(history[-tail_count:]) if tail_count else []
        middle_messages = history[:-tail_count] if tail_count else list(history)
        if not middle_messages:
            # Everything is inside the protected tail: nothing to fold away.
            return CompressionResult(
                summary_message=self._summary_message() if self.rolling_summary else None,
                tail_messages=tail_messages,
                did_compact=False,
                estimated_tokens=estimated_tokens,
            )

        fresh = self._summarize_messages(middle_messages)
        self.rolling_summary = self._fit_summary(self.rolling_summary, fresh)
        return CompressionResult(
            summary_message=self._summary_message(),
            tail_messages=tail_messages,
            did_compact=True,
            estimated_tokens=self.estimate_tokens(tail_messages, memory_block, self.rolling_summary),
        )

    def build_memory_summary(self, history: List[Dict[str, str]]) -> str:
        """Return a standalone summary of ``history``.

        Returns an empty string when no message carries any content, so
        callers never receive a summary made only of placeholder text.
        """
        if not any(self._message_text(message) for message in history):
            return ""
        body = self._summarize_messages(history)
        return (
            "Auto-consolidated long-context summary.\n\n"
            f"{body}"
        ).strip()

    def estimate_tokens(self, history: List[Dict[str, str]], memory_block: str = "", summary: str = "") -> int:
        """Estimate the token count as one token per four characters.

        The text measured is ``memory_block`` + ``summary`` + one
        ``role: content`` line per message. The result is never below 1.
        """
        text = memory_block + summary + "\n".join(
            f"{item.get('role', '')}: {item.get('content', '')}" for item in history
        )
        return max(1, len(text) // 4)

    def _summary_message(self) -> Dict[str, str]:
        """Wrap the rolling summary in an assistant message with the banner."""
        return {"role": "assistant", "content": SUMMARY_PREFIX + self.rolling_summary}

    def _merge_summary(self, existing: str, fresh: str) -> str:
        """Join a prior summary and a new one under their section headings.

        With no prior summary, ``fresh`` is returned unchanged.
        """
        if not existing:
            return fresh
        return (
            f"{self._PRIOR_HEADING}"
            f"{existing.strip()}\n\n"
            f"{self._FRESH_HEADING}"
            f"{fresh.strip()}"
        )

    def _fit_summary(self, existing: str, fresh: str) -> str:
        """Merge ``existing`` and ``fresh`` within ``summary_char_limit``.

        When the merged text is too long, the prior summary is shortened
        (keeping its most recent tail) before the fresh section is touched.
        Cutting the merged text from the end instead would discard the newest
        turns and leave the summary unchanged once it reaches the limit.
        """
        limit = max(0, self.summary_char_limit)
        merged = self._merge_summary(existing, fresh)
        if not existing or len(merged) <= limit:
            return merged[:limit].strip()

        fresh_text = fresh.strip()
        scaffold = len(self._PRIOR_HEADING) + len("\n\n") + len(self._FRESH_HEADING)
        prior_budget = limit - scaffold - len(fresh_text)
        if prior_budget <= 0:
            # No room for any prior text: keep as much of the fresh part as fits.
            return fresh_text[:limit].strip()

        prior_tail = existing.strip()[-prior_budget:]
        # The cut usually lands mid-line; drop that partial first line.
        _, newline, remainder = prior_tail.partition("\n")
        if newline and remainder.strip():
            prior_tail = remainder
        prior_tail = prior_tail.strip()
        if not prior_tail:
            return fresh_text[:limit].strip()
        return self._merge_summary(prior_tail, fresh_text).strip()

    @staticmethod
    def _message_text(message: Dict[str, str]) -> str:
        """Return the message content with whitespace runs collapsed.

        Missing or ``None`` content yields an empty string, so it is never
        rendered as the literal text "None".
        """
        raw = message.get("content")
        if raw is None:
            return ""
        return " ".join(str(raw).split())

    def _summarize_messages(self, messages: List[Dict[str, str]]) -> str:
        """Build a heuristic three-section digest of ``messages``.

        Sections: the latest user message ("Active Task"), the most recent
        facts ("Recent Progress"), and the latest assistant message
        ("Remaining Work"). Messages without content are skipped, and each
        message is shortened to ``_FACT_CHAR_LIMIT`` characters.
        """
        facts: List[str] = []
        latest_user = ""
        latest_assistant = ""
        for message in messages:
            text = self._message_text(message)
            if not text:
                continue
            if len(text) > self._FACT_CHAR_LIMIT:
                text = text[: self._FACT_CHAR_LIMIT] + "..."
            role = message.get("role", "")
            if role == "user":
                latest_user = text
                facts.append(f"- User asked: {text}")
            elif role == "assistant":
                latest_assistant = text
                facts.append(f"- Assistant responded: {text}")
            else:
                facts.append(f"- {role}: {text}")

        lines = ["## Active Task", latest_user or "Continue the latest user request.", ""]
        lines.append("## Recent Progress")
        lines.extend(facts[-self._MAX_FACTS:] or ["- No earlier progress captured."])
        lines.extend(["", "## Remaining Work", latest_assistant or "Use the latest visible conversation state to continue."])
        return "\n".join(lines).strip()
|