dashboard-nanobot/backend/services/bot_message_service.py

247 lines
9.4 KiB
Python
Raw Permalink Normal View History

2026-03-26 16:12:46 +00:00
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional

from fastapi import HTTPException
from sqlmodel import Session, select

from models.bot import BotInstance, BotMessage
# Signatures of the collaborators injected into BotMessageService.

# Called as (bot_id, limit); the returned string is used as the cache key of a
# plain message listing.
CacheKeyMessages = Callable[[str, int], str]
# Called as (bot_id, limit, before_id or None); the returned string is used as
# the cache key of one page of the paged listing.
CacheKeyMessagesPage = Callable[[str, int, Optional[int]], str]
# Called as (bot_id, row); turns one BotMessage row into the dict placed in
# API payloads.
SerializeMessageRow = Callable[[str, BotMessage], Dict[str, Any]]
# Called as (date, tz_offset_minutes or None); the result is unpacked as
# (utc_start, utc_end) and compared against BotMessage.created_at.
ResolveLocalDayRange = Callable[[str, Optional[int]], tuple[datetime, datetime]]
# Called as (bot_id) after a message row has been modified and committed.
InvalidateMessagesCache = Callable[[str], None]
# Called without arguments; supplies the default page size for message pulls.
GetChatPullPageSize = Callable[[], int]
class BotMessageService:
    """Query and annotate the stored chat messages of a bot.

    Every collaborator is injected through the constructor, so the service
    itself only contains query, paging and validation logic:

    * ``cache`` -- object exposing ``get_json(key)`` and
      ``set_json(key, value, ttl=...)``.
    * ``cache_key_bot_messages`` / ``cache_key_bot_messages_page`` -- build
      the cache keys for the plain and the paged listing.
    * ``serialize_bot_message_row`` -- converts a ``BotMessage`` row into the
      dict returned to callers.
    * ``resolve_local_day_range`` -- maps a date string plus an optional
      timezone offset to a ``(utc_start, utc_end)`` pair.
    * ``invalidate_bot_messages_cache`` -- called after a message was changed.
    * ``get_chat_pull_page_size`` -- supplies the configured default page size.
    """

    # Seconds a cached listing is kept (passed as ``ttl`` to ``set_json``).
    _CACHE_TTL_SECONDS = 30
    # Upper bound applied to ``limit`` in the plain and the paged listing.
    _MAX_LIST_LIMIT = 500

    def __init__(
        self,
        *,
        cache: Any,
        cache_key_bot_messages: CacheKeyMessages,
        cache_key_bot_messages_page: CacheKeyMessagesPage,
        serialize_bot_message_row: SerializeMessageRow,
        resolve_local_day_range: ResolveLocalDayRange,
        invalidate_bot_messages_cache: InvalidateMessagesCache,
        get_chat_pull_page_size: GetChatPullPageSize,
    ) -> None:
        """Store the injected collaborators; performs no I/O."""
        self._cache = cache
        self._cache_key_bot_messages = cache_key_bot_messages
        self._cache_key_bot_messages_page = cache_key_bot_messages_page
        self._serialize_bot_message_row = serialize_bot_message_row
        self._resolve_local_day_range = resolve_local_day_range
        self._invalidate_bot_messages_cache = invalidate_bot_messages_cache
        self._get_chat_pull_page_size = get_chat_pull_page_size

    @staticmethod
    def _epoch_millis(moment: datetime) -> int:
        """Return *moment* as integer milliseconds since the Unix epoch.

        ``datetime.timestamp()`` interprets a naive value as *local* time,
        which skews the result on any host whose timezone is not UTC. Naive
        values are therefore pinned to UTC first; aware values are used as-is.
        """
        # NOTE(review): assumes naive ``created_at`` values are stored in UTC,
        # consistent with their comparison against utc_start/utc_end below --
        # confirm against the BotMessage model.
        if moment.tzinfo is None:
            moment = moment.replace(tzinfo=timezone.utc)
        return int(moment.timestamp() * 1000)

    @staticmethod
    def _utcnow_naive() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same kind of value as the deprecated
        ``datetime.utcnow()`` without triggering its deprecation warning.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _parse_feedback(feedback: Optional[str]) -> Optional[str]:
        """Normalize a feedback value to ``"up"``, ``"down"`` or ``None``.

        Input is stripped and lower-cased. ``None``, the empty string,
        ``"none"`` and ``"null"`` all mean "clear the feedback".

        Raises:
            HTTPException: 400 for any other value.
        """
        raw = str(feedback or "").strip().lower()
        if raw in {"", "none", "null"}:
            return None
        if raw in {"up", "down"}:
            return raw
        raise HTTPException(status_code=400, detail="feedback must be 'up' or 'down'")

    def _require_bot(self, *, session: Session, bot_id: str) -> BotInstance:
        """Return the bot with *bot_id*.

        Raises:
            HTTPException: 404 when no such bot exists.
        """
        bot = session.get(BotInstance, bot_id)
        if not bot:
            raise HTTPException(status_code=404, detail="Bot not found")
        return bot

    def list_messages(
        self,
        *,
        session: Session,
        bot_id: str,
        limit: int = 200,
    ) -> list[Dict[str, Any]]:
        """Return the newest messages of a bot in chronological order.

        Args:
            session: Database session used for the lookup.
            bot_id: Bot whose messages are listed.
            limit: Requested number of messages; clamped to 1..500.

        Returns:
            Serialized messages, oldest first. A cached list is returned
            when the cache holds one for this bot and limit.

        Raises:
            HTTPException: 404 when the bot does not exist.
        """
        self._require_bot(session=session, bot_id=bot_id)
        safe_limit = max(1, min(int(limit), self._MAX_LIST_LIMIT))
        cache_key = self._cache_key_bot_messages(bot_id, safe_limit)

        cached = self._cache.get_json(cache_key)
        if isinstance(cached, list):
            return cached

        newest_first = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(safe_limit)
        ).all()
        # The query selects the newest rows; flip them for display order.
        payload = [
            self._serialize_bot_message_row(bot_id, row)
            for row in reversed(newest_first)
        ]
        self._cache.set_json(cache_key, payload, ttl=self._CACHE_TTL_SECONDS)
        return payload

    def list_messages_page(
        self,
        *,
        session: Session,
        bot_id: str,
        limit: Optional[int] = None,
        before_id: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return one page of messages, paging backwards through history.

        Args:
            session: Database session used for the lookup.
            bot_id: Bot whose messages are listed.
            limit: Page size; defaults to the configured pull page size and
                is clamped to 1..500.
            before_id: Only messages with a smaller id are returned. Values
                that are not positive integers are ignored.

        Returns:
            Dict with ``items`` (serialized, oldest first), ``has_more``,
            ``next_before_id`` (id of the oldest returned row, or ``None``
            for an empty page) and the effective ``limit``.

        Raises:
            HTTPException: 404 when the bot does not exist.
        """
        self._require_bot(session=session, bot_id=bot_id)
        requested = limit if limit is not None else self._get_chat_pull_page_size()
        safe_limit = max(1, min(int(requested), self._MAX_LIST_LIMIT))
        safe_before_id = (
            int(before_id) if isinstance(before_id, int) and before_id > 0 else None
        )

        cache_key = self._cache_key_bot_messages_page(bot_id, safe_limit, safe_before_id)
        cached = self._cache.get_json(cache_key)
        if isinstance(cached, dict) and isinstance(cached.get("items"), list):
            return cached

        stmt = select(BotMessage).where(BotMessage.bot_id == bot_id)
        if safe_before_id is not None:
            stmt = stmt.where(BotMessage.id < safe_before_id)
        # Fetch one row beyond the page to learn whether more history exists.
        stmt = stmt.order_by(
            BotMessage.created_at.desc(), BotMessage.id.desc()
        ).limit(safe_limit + 1)

        rows = session.exec(stmt).all()
        has_more = len(rows) > safe_limit
        page = rows[:safe_limit]

        payload = {
            "items": [
                self._serialize_bot_message_row(bot_id, row)
                for row in reversed(page)
            ],
            "has_more": bool(has_more),
            # ``page`` is newest-first, so its last row is the oldest one.
            "next_before_id": page[-1].id if page else None,
            "limit": safe_limit,
        }
        self._cache.set_json(cache_key, payload, ttl=self._CACHE_TTL_SECONDS)
        return payload

    def _find_anchor(
        self,
        *,
        session: Session,
        bot_id: str,
        utc_start: datetime,
        utc_end: datetime,
    ) -> tuple[Optional[BotMessage], bool]:
        """Pick the message a date jump should land on.

        Returns ``(anchor, matched_exact_date)``. The anchor is the first
        message inside ``[utc_start, utc_end)`` when there is one; otherwise
        the closest message outside the range (ties favour the later one),
        or ``None`` when the bot has no messages at all.
        """
        first_in_range = session.exec(
            select(BotMessage)
            .where(
                BotMessage.bot_id == bot_id,
                BotMessage.created_at >= utc_start,
                BotMessage.created_at < utc_end,
            )
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(1)
        ).first()
        if first_in_range is not None:
            return first_in_range, True

        next_row = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.created_at >= utc_end)
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(1)
        ).first()
        prev_row = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.created_at < utc_start)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(1)
        ).first()

        if next_row and prev_row:
            gap_after = next_row.created_at - utc_end
            gap_before = utc_start - prev_row.created_at
            return (next_row if gap_after <= gap_before else prev_row), False
        return (next_row or prev_row), False

    def _has_message_beyond(
        self,
        *,
        session: Session,
        bot_id: str,
        row: Optional[BotMessage],
        older: bool,
    ) -> bool:
        """Tell whether the bot has a message with an id beyond *row*'s.

        ``older=True`` looks for smaller ids, ``older=False`` for larger
        ones. Returns ``False`` when *row* or its id is missing.
        """
        if row is None or row.id is None:
            return False
        if older:
            stmt = (
                select(BotMessage.id)
                .where(BotMessage.bot_id == bot_id, BotMessage.id < row.id)
                .order_by(BotMessage.id.desc())
            )
        else:
            stmt = (
                select(BotMessage.id)
                .where(BotMessage.bot_id == bot_id, BotMessage.id > row.id)
                .order_by(BotMessage.id.asc())
            )
        return session.exec(stmt.limit(1)).first() is not None

    def list_messages_by_date(
        self,
        *,
        session: Session,
        bot_id: str,
        date: str,
        tz_offset_minutes: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return a window of messages around the first message of a day.

        Args:
            session: Database session used for the lookup.
            bot_id: Bot whose messages are listed.
            date: Day to jump to, in whatever format the injected
                ``resolve_local_day_range`` accepts.
            tz_offset_minutes: Forwarded unchanged to the day resolver.
            limit: Window size; defaults to the configured pull page size
                (at least 60) and is clamped to 12..240.

        Returns:
            Dict with ``items`` (serialized, oldest first), ``anchor_id``,
            ``resolved_ts`` (anchor creation time in epoch milliseconds,
            naive values read as UTC), ``matched_exact_date``,
            ``has_more_before`` and ``has_more_after``. When the bot has no
            messages, ``items`` is empty and the other fields are
            ``None`` / ``False``.

        Raises:
            HTTPException: 404 when the bot does not exist.
        """
        self._require_bot(session=session, bot_id=bot_id)
        utc_start, utc_end = self._resolve_local_day_range(date, tz_offset_minutes)

        configured_limit = max(60, self._get_chat_pull_page_size())
        requested = limit if limit is not None else configured_limit
        safe_limit = max(12, min(int(requested), 240))
        # A quarter of the window (clamped to 3..18) precedes the anchor;
        # the remainder, minus the anchor itself, follows it.
        before_limit = max(3, min(18, safe_limit // 4))
        after_limit = max(0, safe_limit - before_limit - 1)

        anchor, matched_exact_date = self._find_anchor(
            session=session,
            bot_id=bot_id,
            utc_start=utc_start,
            utc_end=utc_end,
        )
        if anchor is None or anchor.id is None:
            return {
                "items": [],
                "anchor_id": None,
                "resolved_ts": None,
                "matched_exact_date": False,
                "has_more_before": False,
                "has_more_after": False,
            }

        before_rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.id < anchor.id)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(before_limit)
        ).all()
        after_rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.id > anchor.id)
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(after_limit)
        ).all()

        # ``before_rows`` arrives newest-first; flip it so the whole window
        # reads oldest-to-newest with the anchor in between.
        ordered = [*reversed(before_rows), anchor, *after_rows]

        return {
            "items": [self._serialize_bot_message_row(bot_id, row) for row in ordered],
            "anchor_id": anchor.id,
            "resolved_ts": self._epoch_millis(anchor.created_at),
            "matched_exact_date": matched_exact_date,
            "has_more_before": self._has_message_beyond(
                session=session, bot_id=bot_id, row=ordered[0], older=True
            ),
            "has_more_after": self._has_message_beyond(
                session=session, bot_id=bot_id, row=ordered[-1], older=False
            ),
        }

    def update_feedback(
        self,
        *,
        session: Session,
        bot_id: str,
        message_id: int,
        feedback: Optional[str],
    ) -> Dict[str, Any]:
        """Set or clear the feedback on an assistant message.

        Args:
            session: Database session used for the update.
            bot_id: Bot that must own the message.
            message_id: Primary key of the message.
            feedback: ``"up"`` or ``"down"`` (case-insensitive); ``None``,
                ``""``, ``"none"`` or ``"null"`` clear the feedback.

        Returns:
            Dict with ``status``, ``bot_id``, ``message_id``, ``feedback``
            and ``feedback_at`` (ISO-8601 string or ``None``).

        Raises:
            HTTPException: 404 when the bot or message is missing or the
                message belongs to another bot; 400 when the message is not
                an assistant message or *feedback* is invalid.
        """
        self._require_bot(session=session, bot_id=bot_id)
        row = session.get(BotMessage, message_id)
        if not row or row.bot_id != bot_id:
            raise HTTPException(status_code=404, detail="Message not found")
        if row.role != "assistant":
            raise HTTPException(
                status_code=400, detail="Only assistant messages support feedback"
            )

        value = self._parse_feedback(feedback)
        row.feedback = value
        row.feedback_at = self._utcnow_naive() if value is not None else None

        session.add(row)
        session.commit()
        # Cached listings may embed the old feedback value; drop them.
        self._invalidate_bot_messages_cache(bot_id)

        return {
            "status": "updated",
            "bot_id": bot_id,
            "message_id": row.id,
            "feedback": row.feedback,
            "feedback_at": row.feedback_at.isoformat() if row.feedback_at else None,
        }