dashboard-nanobot/backend/services/bot_message_service.py

247 lines
9.4 KiB
Python

from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional

from fastapi import HTTPException
from sqlmodel import Session, select

from models.bot import BotInstance, BotMessage
# Callable signatures for the collaborators injected into BotMessageService.
# Builds the list_messages cache key from (bot_id, clamped limit).
CacheKeyMessages = Callable[[str, int], str]
# Builds the list_messages_page cache key from (bot_id, clamped limit, before_id or None).
CacheKeyMessagesPage = Callable[[str, int, Optional[int]], str]
# Converts (bot_id, BotMessage row) into the dict returned to API callers.
SerializeMessageRow = Callable[[str, BotMessage], Dict[str, Any]]
# Maps (date string, tz offset in minutes or None) to the (start, end) datetime
# bounds of that local day; the service uses them as a half-open [start, end) range.
ResolveLocalDayRange = Callable[[str, Optional[int]], tuple[datetime, datetime]]
# Invoked with a bot_id after a write so cached message payloads can be invalidated.
InvalidateMessagesCache = Callable[[str], None]
# Returns the configured default page size for chat pulls.
GetChatPullPageSize = Callable[[], int]
class BotMessageService:
    """Query and update a bot's chat history (``BotMessage`` rows).

    Infrastructure concerns are injected as callables, so this class only owns
    the query logic. They are: the cache client, the cache-key builders, row
    serialization, local-day resolution, cache invalidation and the configured
    page size.

    Every message window is ordered by the composite key ``(created_at, id)``.
    Cursors and "has more" probes compare against that same key, so windows
    stay contiguous even when ``id`` order and ``created_at`` order disagree.
    """

    def __init__(
        self,
        *,
        cache: Any,
        cache_key_bot_messages: CacheKeyMessages,
        cache_key_bot_messages_page: CacheKeyMessagesPage,
        serialize_bot_message_row: SerializeMessageRow,
        resolve_local_day_range: ResolveLocalDayRange,
        invalidate_bot_messages_cache: InvalidateMessagesCache,
        get_chat_pull_page_size: GetChatPullPageSize,
    ) -> None:
        """Store the injected collaborators.

        Args:
            cache: Object providing ``get_json(key)`` and
                ``set_json(key, value, ttl=...)``.
            cache_key_bot_messages: Builds the ``list_messages`` cache key
                from ``(bot_id, limit)``.
            cache_key_bot_messages_page: Builds the ``list_messages_page``
                cache key from ``(bot_id, limit, before_id)``.
            serialize_bot_message_row: Converts ``(bot_id, row)`` into the
                dict returned to callers.
            resolve_local_day_range: Maps ``(date, tz_offset_minutes)`` to the
                ``(utc_start, utc_end)`` bounds of that local day.
            invalidate_bot_messages_cache: Called with ``bot_id`` after a write.
            get_chat_pull_page_size: Returns the configured default page size.
        """
        self._cache = cache
        self._cache_key_bot_messages = cache_key_bot_messages
        self._cache_key_bot_messages_page = cache_key_bot_messages_page
        self._serialize_bot_message_row = serialize_bot_message_row
        self._resolve_local_day_range = resolve_local_day_range
        self._invalidate_bot_messages_cache = invalidate_bot_messages_cache
        self._get_chat_pull_page_size = get_chat_pull_page_size

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _require_bot(self, *, session: Session, bot_id: str) -> BotInstance:
        """Return the bot with ``bot_id``.

        Raises:
            HTTPException: 404 if no such bot exists.
        """
        bot = session.get(BotInstance, bot_id)
        if not bot:
            raise HTTPException(status_code=404, detail="Bot not found")
        return bot

    @staticmethod
    def _sorts_before(row: BotMessage) -> Any:
        """SQL condition: a message sorts strictly before ``row`` by ``(created_at, id)``."""
        # ``|`` / ``&`` on SQLAlchemy column expressions build OR / AND clauses.
        return (BotMessage.created_at < row.created_at) | (
            (BotMessage.created_at == row.created_at) & (BotMessage.id < row.id)
        )

    @staticmethod
    def _sorts_after(row: BotMessage) -> Any:
        """SQL condition: a message sorts strictly after ``row`` by ``(created_at, id)``."""
        return (BotMessage.created_at > row.created_at) | (
            (BotMessage.created_at == row.created_at) & (BotMessage.id > row.id)
        )

    @staticmethod
    def _has_any(session: Session, *conditions: Any) -> bool:
        """Return True if at least one ``BotMessage`` matches all ``conditions``."""
        probe = select(BotMessage.id).where(*conditions).limit(1)
        return session.exec(probe).first() is not None

    @staticmethod
    def _to_epoch_ms(value: datetime) -> int:
        """Return ``value`` as Unix epoch milliseconds.

        Naive datetimes are assumed to be UTC. This service compares
        ``created_at`` directly against ``utc_start``/``utc_end`` and writes
        ``feedback_at`` as naive UTC. A plain ``datetime.timestamp()`` would
        interpret a naive value as server-local time and shift the result by
        the host's UTC offset.
        """
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return int(value.timestamp() * 1000)

    def _page_cursor_condition(self, session: Session, bot_id: str, before_id: int) -> Any:
        """Build the WHERE condition selecting messages older than cursor ``before_id``.

        Uses the cursor row's ``(created_at, id)`` so the bound matches the page
        ordering. Falls back to a plain ``id < before_id`` bound (the historical
        behavior) when the cursor row is missing, belongs to another bot, or has
        no timestamp.
        """
        cursor_row = session.get(BotMessage, before_id)
        if cursor_row is None or cursor_row.bot_id != bot_id or cursor_row.created_at is None:
            return BotMessage.id < before_id
        return self._sorts_before(cursor_row)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def list_messages(
        self, *, session: Session, bot_id: str, limit: int = 200
    ) -> list[Dict[str, Any]]:
        """Return the newest ``limit`` messages of a bot, oldest first.

        ``limit`` is clamped to ``1..500``. The serialized list is cached for
        30 seconds under ``cache_key_bot_messages(bot_id, limit)``.

        Raises:
            HTTPException: 404 if the bot does not exist.
        """
        self._require_bot(session=session, bot_id=bot_id)
        safe_limit = max(1, min(int(limit), 500))
        cache_key = self._cache_key_bot_messages(bot_id, safe_limit)
        cached = self._cache.get_json(cache_key)
        if isinstance(cached, list):
            return cached
        newest_first = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(safe_limit)
        ).all()
        payload = [self._serialize_bot_message_row(bot_id, row) for row in reversed(newest_first)]
        self._cache.set_json(cache_key, payload, ttl=30)
        return payload

    def list_messages_page(
        self,
        *,
        session: Session,
        bot_id: str,
        limit: Optional[int] = None,
        before_id: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return one page of messages, walking backwards in time from ``before_id``.

        Args:
            session: Database session.
            bot_id: Bot whose messages are listed.
            limit: Page size, clamped to ``1..500``. Defaults to the configured
                chat pull page size.
            before_id: Id of the oldest message the client already has. Only
                messages sorting strictly before it by ``(created_at, id)`` are
                returned. Values that are not a positive ``int`` mean "start
                from the newest message".

        Returns:
            A dict with these keys:

            * ``items``: the page, oldest first.
            * ``has_more``: whether older messages exist.
            * ``next_before_id``: the oldest item's id, i.e. the cursor for
              the next page, or ``None`` when the page is empty.
            * ``limit``: the clamped page size.

            The result is cached for 30 seconds.

        Raises:
            HTTPException: 404 if the bot does not exist.
        """
        self._require_bot(session=session, bot_id=bot_id)
        configured_limit = self._get_chat_pull_page_size()
        safe_limit = max(1, min(int(limit if limit is not None else configured_limit), 500))
        safe_before_id = int(before_id) if isinstance(before_id, int) and before_id > 0 else None
        cache_key = self._cache_key_bot_messages_page(bot_id, safe_limit, safe_before_id)
        cached = self._cache.get_json(cache_key)
        if isinstance(cached, dict) and isinstance(cached.get("items"), list):
            return cached

        stmt = select(BotMessage).where(BotMessage.bot_id == bot_id)
        if safe_before_id is not None:
            stmt = stmt.where(self._page_cursor_condition(session, bot_id, safe_before_id))
        # Fetch one extra row purely to learn whether an older page exists.
        rows = session.exec(
            stmt.order_by(BotMessage.created_at.desc(), BotMessage.id.desc()).limit(safe_limit + 1)
        ).all()
        has_more = len(rows) > safe_limit
        page = rows[:safe_limit]  # newest first
        payload = {
            "items": [self._serialize_bot_message_row(bot_id, row) for row in reversed(page)],
            "has_more": has_more,
            "next_before_id": page[-1].id if page else None,
            "limit": safe_limit,
        }
        self._cache.set_json(cache_key, payload, ttl=30)
        return payload

    def list_messages_by_date(
        self,
        *,
        session: Session,
        bot_id: str,
        date: str,
        tz_offset_minutes: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return a window of messages anchored at the first message of ``date``.

        The local day is resolved with
        ``resolve_local_day_range(date, tz_offset_minutes)``. The anchor is the
        earliest message inside ``[utc_start, utc_end)``. If that day has no
        messages, the anchor is the nearest message after the day or before
        it, whichever is closer; ties go to the later message.

        The window holds, in chronological order:

        * up to ``before_limit`` older messages,
        * the anchor,
        * up to ``after_limit`` newer messages.

        ``limit`` is clamped to ``12..240`` and defaults to
        ``max(60, page size)``. ``before_limit`` is ``limit // 4``, clamped
        to ``3..18``. This method does not use the cache.

        Returns:
            A dict with keys ``items``, ``anchor_id``, ``resolved_ts`` (the
            anchor's epoch milliseconds), ``matched_exact_date``,
            ``has_more_before`` and ``has_more_after``. If the bot has no
            messages at all, ``items`` is empty and the other keys are
            ``None``/``False``.

        Raises:
            HTTPException: 404 if the bot does not exist.
        """
        self._require_bot(session=session, bot_id=bot_id)
        utc_start, utc_end = self._resolve_local_day_range(date, tz_offset_minutes)
        configured_limit = max(60, self._get_chat_pull_page_size())
        safe_limit = max(12, min(int(limit if limit is not None else configured_limit), 240))
        before_limit = max(3, min(18, safe_limit // 4))
        after_limit = max(0, safe_limit - before_limit - 1)

        anchor = session.exec(
            select(BotMessage)
            .where(
                BotMessage.bot_id == bot_id,
                BotMessage.created_at >= utc_start,
                BotMessage.created_at < utc_end,
            )
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(1)
        ).first()
        matched_exact_date = anchor is not None
        if anchor is None:
            next_row = session.exec(
                select(BotMessage)
                .where(BotMessage.bot_id == bot_id, BotMessage.created_at >= utc_end)
                .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
                .limit(1)
            ).first()
            prev_row = session.exec(
                select(BotMessage)
                .where(BotMessage.bot_id == bot_id, BotMessage.created_at < utc_start)
                .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
                .limit(1)
            ).first()
            if next_row and prev_row:
                gap_after = next_row.created_at - utc_end
                gap_before = utc_start - prev_row.created_at
                anchor = next_row if gap_after <= gap_before else prev_row
            else:
                anchor = next_row or prev_row

        if anchor is None or anchor.id is None:
            return {
                "items": [],
                "anchor_id": None,
                "resolved_ts": None,
                "matched_exact_date": False,
                "has_more_before": False,
                "has_more_after": False,
            }

        # Neighbours are selected with the same (created_at, id) key used for
        # ordering, so the window is contiguous around the anchor.
        before_rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, self._sorts_before(anchor))
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(before_limit)
        ).all()
        after_rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, self._sorts_after(anchor))
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(after_limit)
        ).all()
        ordered = [*reversed(before_rows), anchor, *after_rows]
        # ``ordered`` always contains the anchor, so both ends exist.
        first_row, last_row = ordered[0], ordered[-1]
        has_more_before = self._has_any(
            session, BotMessage.bot_id == bot_id, self._sorts_before(first_row)
        )
        has_more_after = self._has_any(
            session, BotMessage.bot_id == bot_id, self._sorts_after(last_row)
        )
        return {
            "items": [self._serialize_bot_message_row(bot_id, row) for row in ordered],
            "anchor_id": anchor.id,
            "resolved_ts": self._to_epoch_ms(anchor.created_at),
            "matched_exact_date": matched_exact_date,
            "has_more_before": has_more_before,
            "has_more_after": has_more_after,
        }

    def update_feedback(
        self,
        *,
        session: Session,
        bot_id: str,
        message_id: int,
        feedback: Optional[str],
    ) -> Dict[str, Any]:
        """Set or clear the thumbs-up/down feedback on an assistant message.

        ``feedback`` is stripped and lower-cased. The value ``"up"`` or
        ``"down"`` sets the feedback and stamps ``feedback_at`` with the
        current naive UTC time. ``None``, ``""``, ``"none"`` or ``"null"``
        clears both fields. The change is committed and the bot's message
        cache is invalidated.

        Returns:
            ``{"status": "updated", "bot_id", "message_id", "feedback",
            "feedback_at"}`` where ``feedback_at`` is ISO-8601 or ``None``.

        Raises:
            HTTPException: 404 if the bot or message does not exist (or the
                message belongs to another bot). 400 if the message is not
                from the assistant, or if ``feedback`` is unrecognized.
        """
        self._require_bot(session=session, bot_id=bot_id)
        row = session.get(BotMessage, message_id)
        if not row or row.bot_id != bot_id:
            raise HTTPException(status_code=404, detail="Message not found")
        if row.role != "assistant":
            raise HTTPException(status_code=400, detail="Only assistant messages support feedback")
        raw = str(feedback or "").strip().lower()
        if raw in {"", "none", "null"}:
            row.feedback = None
            row.feedback_at = None
        elif raw in {"up", "down"}:
            row.feedback = raw
            # Naive UTC, the same value the deprecated datetime.utcnow() returned.
            row.feedback_at = datetime.now(timezone.utc).replace(tzinfo=None)
        else:
            raise HTTPException(status_code=400, detail="feedback must be 'up' or 'down'")
        session.add(row)
        session.commit()
        self._invalidate_bot_messages_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "message_id": row.id,
            "feedback": row.feedback,
            "feedback_at": row.feedback_at.isoformat() if row.feedback_at else None,
        }