from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional

from fastapi import HTTPException
from sqlmodel import Session, select

from models.bot import BotInstance, BotMessage

# Signatures of the collaborators injected into BotMessageService.
CacheKeyMessages = Callable[[str, int], str]
CacheKeyMessagesPage = Callable[[str, int, Optional[int]], str]
SerializeMessageRow = Callable[[str, BotMessage], Dict[str, Any]]
ResolveLocalDayRange = Callable[[str, Optional[int]], tuple[datetime, datetime]]
InvalidateMessagesCache = Callable[[str], None]
GetChatPullPageSize = Callable[[], int]


class BotMessageService:
    """Read and update the stored chat messages of a bot.

    All collaborators (JSON cache, cache-key builders, row serializer,
    local-day resolver, cache invalidator and page-size provider) are
    injected through the constructor.
    """

    def __init__(
        self,
        *,
        cache: Any,
        cache_key_bot_messages: CacheKeyMessages,
        cache_key_bot_messages_page: CacheKeyMessagesPage,
        serialize_bot_message_row: SerializeMessageRow,
        resolve_local_day_range: ResolveLocalDayRange,
        invalidate_bot_messages_cache: InvalidateMessagesCache,
        get_chat_pull_page_size: GetChatPullPageSize,
    ) -> None:
        self._cache = cache
        self._cache_key_bot_messages = cache_key_bot_messages
        self._cache_key_bot_messages_page = cache_key_bot_messages_page
        self._serialize_bot_message_row = serialize_bot_message_row
        self._resolve_local_day_range = resolve_local_day_range
        self._invalidate_bot_messages_cache = invalidate_bot_messages_cache
        self._get_chat_pull_page_size = get_chat_pull_page_size

    def _require_bot(self, *, session: Session, bot_id: str) -> BotInstance:
        """Return the bot row for ``bot_id``.

        Raises:
            HTTPException: 404 if no such bot exists.
        """
        bot = session.get(BotInstance, bot_id)
        if not bot:
            raise HTTPException(status_code=404, detail="Bot not found")
        return bot

    @staticmethod
    def _clamp_limit(raw: Any, *, default: Any, lower: int, upper: int) -> int:
        """Coerce a caller-supplied limit to an int within ``[lower, upper]``.

        ``None`` selects ``default``.

        Raises:
            HTTPException: 400 if ``raw`` cannot be converted with ``int()``.
                Without this, the failure would surface as an unhandled 500.
        """
        if raw is None:
            number = int(default)
        else:
            try:
                number = int(raw)
            except (TypeError, ValueError):
                raise HTTPException(status_code=400, detail="limit must be an integer") from None
        return max(lower, min(number, upper))

    @staticmethod
    def _epoch_millis(value: datetime) -> int:
        """Convert a stored timestamp to Unix epoch milliseconds.

        A naive datetime is interpreted as UTC. Calling ``.timestamp()`` on it
        directly would apply the server's *local* offset instead.
        NOTE(review): this assumes ``created_at`` is stored as naive UTC. That
        is consistent with the UTC day bounds it is compared against and with
        ``feedback_at`` being written in UTC. Confirm against the model.
        """
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return int(value.timestamp() * 1000)

    def list_messages(self, *, session: Session, bot_id: str, limit: int = 200) -> list[Dict[str, Any]]:
        """Return the newest ``limit`` messages of a bot, oldest first.

        ``limit`` is clamped to ``[1, 500]``. The serialized list is cached
        for 30 seconds under a key built from ``(bot_id, limit)``.

        Raises:
            HTTPException: 404 if the bot does not exist; 400 if ``limit`` is
                not an integer.
        """
        self._require_bot(session=session, bot_id=bot_id)
        safe_limit = self._clamp_limit(limit, default=200, lower=1, upper=500)
        cache_key = self._cache_key_bot_messages(bot_id, safe_limit)
        cached = self._cache.get_json(cache_key)
        if isinstance(cached, list):
            return cached

        # Query newest-first so LIMIT keeps the latest rows, then flip to chronological order.
        rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(safe_limit)
        ).all()
        payload = [self._serialize_bot_message_row(bot_id, row) for row in reversed(rows)]
        self._cache.set_json(cache_key, payload, ttl=30)
        return payload

    def list_messages_page(
        self,
        *,
        session: Session,
        bot_id: str,
        limit: Optional[int] = None,
        before_id: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return one page of history, walking backwards from ``before_id``.

        ``limit`` defaults to the configured chat page size and is clamped to
        ``[1, 500]``. A non-positive or non-int ``before_id`` means "start
        from the newest message". The result has these keys:

        - ``items``: messages in chronological order.
        - ``has_more``: whether older rows exist beyond this page.
        - ``next_before_id``: the id of the oldest row on the page.
        - ``limit``: the clamped page size.

        The result is cached for 30 seconds.

        Raises:
            HTTPException: 404 if the bot does not exist; 400 if ``limit`` is
                not an integer.
        """
        self._require_bot(session=session, bot_id=bot_id)
        configured_limit = self._get_chat_pull_page_size()
        safe_limit = self._clamp_limit(limit, default=configured_limit, lower=1, upper=500)
        # bool is a subclass of int; reject it so ``True`` is not read as cursor id 1.
        is_valid_cursor = (
            isinstance(before_id, int) and not isinstance(before_id, bool) and before_id > 0
        )
        safe_before_id = int(before_id) if is_valid_cursor else None

        cache_key = self._cache_key_bot_messages_page(bot_id, safe_limit, safe_before_id)
        cached = self._cache.get_json(cache_key)
        if isinstance(cached, dict) and isinstance(cached.get("items"), list):
            return cached

        # Fetch one extra row to learn whether an older page exists.
        stmt = (
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(safe_limit + 1)
        )
        if safe_before_id is not None:
            # NOTE(review): rows are ordered by (created_at, id) but the cursor
            # filters on id alone. Pages are only gap-free if ids increase
            # together with created_at. Confirm.
            stmt = stmt.where(BotMessage.id < safe_before_id)
        rows = session.exec(stmt).all()
        has_more = len(rows) > safe_limit
        if has_more:
            rows = rows[:safe_limit]

        payload = {
            "items": [self._serialize_bot_message_row(bot_id, row) for row in reversed(rows)],
            "has_more": bool(has_more),
            # rows are newest-first here, so the last one is the oldest on the page.
            "next_before_id": rows[-1].id if rows else None,
            "limit": safe_limit,
        }
        self._cache.set_json(cache_key, payload, ttl=30)
        return payload

    def _find_date_anchor(
        self,
        *,
        session: Session,
        bot_id: str,
        utc_start: datetime,
        utc_end: datetime,
    ) -> tuple[Optional[BotMessage], bool]:
        """Pick the message that represents the requested day.

        Returns ``(anchor, matched_exact_date)``. The anchor is chosen as
        follows:

        - If any message falls in ``[utc_start, utc_end)``, the anchor is the
          earliest of them.
        - Otherwise, the anchor is whichever of the nearest later and nearest
          earlier messages is closer to the range. Ties favour the later one.
        - If the bot has no messages at all, the anchor is ``None``.
        """
        exact_anchor = session.exec(
            select(BotMessage)
            .where(
                BotMessage.bot_id == bot_id,
                BotMessage.created_at >= utc_start,
                BotMessage.created_at < utc_end,
            )
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(1)
        ).first()
        if exact_anchor is not None:
            return exact_anchor, True

        next_row = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.created_at >= utc_end)
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(1)
        ).first()
        prev_row = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.created_at < utc_start)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(1)
        ).first()
        if next_row and prev_row:
            gap_after = next_row.created_at - utc_end
            gap_before = utc_start - prev_row.created_at
            return (next_row if gap_after <= gap_before else prev_row), False
        return (next_row or prev_row), False

    @staticmethod
    def _has_row_beyond(
        *,
        session: Session,
        bot_id: str,
        row_id: Optional[int],
        older: bool,
    ) -> bool:
        """Return True if the bot has a message with a smaller (``older=True``)
        or larger (``older=False``) id than ``row_id``.

        A ``row_id`` of ``None`` yields False.
        """
        if row_id is None:
            return False
        if older:
            stmt = (
                select(BotMessage.id)
                .where(BotMessage.bot_id == bot_id, BotMessage.id < row_id)
                .order_by(BotMessage.id.desc())
            )
        else:
            stmt = (
                select(BotMessage.id)
                .where(BotMessage.bot_id == bot_id, BotMessage.id > row_id)
                .order_by(BotMessage.id.asc())
            )
        return session.exec(stmt.limit(1)).first() is not None

    def list_messages_by_date(
        self,
        *,
        session: Session,
        bot_id: str,
        date: str,
        tz_offset_minutes: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return a window of messages around the first message of a local day.

        ``date`` and ``tz_offset_minutes`` are turned into UTC bounds by the
        injected ``resolve_local_day_range``. The window size defaults to at
        least 60 and is clamped to ``[12, 240]``. The result holds:

        - ``items``: messages in chronological order, containing the anchor.
        - ``anchor_id``: the anchor's id.
        - ``resolved_ts``: the anchor's ``created_at`` in epoch milliseconds.
        - ``matched_exact_date``: whether the anchor falls on the requested day.
        - ``has_more_before`` / ``has_more_after``: whether messages exist
          outside the window.

        Raises:
            HTTPException: 404 if the bot does not exist; 400 if ``limit`` is
                not an integer.
        """
        self._require_bot(session=session, bot_id=bot_id)
        utc_start, utc_end = self._resolve_local_day_range(date, tz_offset_minutes)
        configured_limit = max(60, self._get_chat_pull_page_size())
        safe_limit = self._clamp_limit(limit, default=configured_limit, lower=12, upper=240)
        # About a quarter of the window (3..18 rows) precedes the anchor.
        # The anchor takes one slot; the remainder follows it.
        before_limit = max(3, min(18, safe_limit // 4))
        after_limit = max(0, safe_limit - before_limit - 1)

        anchor, matched_exact_date = self._find_date_anchor(
            session=session, bot_id=bot_id, utc_start=utc_start, utc_end=utc_end
        )
        if anchor is None or anchor.id is None:
            return {
                "items": [],
                "anchor_id": None,
                "resolved_ts": None,
                "matched_exact_date": False,
                "has_more_before": False,
                "has_more_after": False,
            }

        # NOTE(review): neighbours are selected by id but ordered by created_at;
        # this assumes ids increase together with created_at. Confirm.
        before_rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.id < anchor.id)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(before_limit)
        ).all()
        after_rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.id > anchor.id)
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(after_limit)
        ).all()
        ordered = list(reversed(before_rows)) + [anchor] + list(after_rows)

        # ``ordered`` always contains the anchor, so both ends exist.
        has_more_before = self._has_row_beyond(
            session=session, bot_id=bot_id, row_id=ordered[0].id, older=True
        )
        has_more_after = self._has_row_beyond(
            session=session, bot_id=bot_id, row_id=ordered[-1].id, older=False
        )
        return {
            "items": [self._serialize_bot_message_row(bot_id, row) for row in ordered],
            "anchor_id": anchor.id,
            "resolved_ts": self._epoch_millis(anchor.created_at),
            "matched_exact_date": matched_exact_date,
            "has_more_before": has_more_before,
            "has_more_after": has_more_after,
        }

    def update_feedback(
        self,
        *,
        session: Session,
        bot_id: str,
        message_id: int,
        feedback: Optional[str],
    ) -> Dict[str, Any]:
        """Set or clear the up/down feedback on an assistant message.

        ``feedback`` is matched case-insensitively after trimming whitespace.
        ``"up"`` and ``"down"`` set the feedback and stamp ``feedback_at``.
        ``None``, ``""``, ``"none"`` and ``"null"`` clear both fields. On
        success the bot's message caches are invalidated.

        Raises:
            HTTPException:
                - 404 if the bot or message is missing, or the message belongs
                  to another bot.
                - 400 if the message is not from the assistant.
                - 400 if ``feedback`` is not one of the accepted values.
        """
        self._require_bot(session=session, bot_id=bot_id)
        row = session.get(BotMessage, message_id)
        if not row or row.bot_id != bot_id:
            raise HTTPException(status_code=404, detail="Message not found")
        if row.role != "assistant":
            raise HTTPException(status_code=400, detail="Only assistant messages support feedback")

        raw = str(feedback or "").strip().lower()
        if raw in {"", "none", "null"}:
            row.feedback = None
            row.feedback_at = None
        elif raw in {"up", "down"}:
            row.feedback = raw
            # Naive UTC, identical to the deprecated datetime.utcnow().
            row.feedback_at = datetime.now(timezone.utc).replace(tzinfo=None)
        else:
            raise HTTPException(status_code=400, detail="feedback must be 'up' or 'down'")

        session.add(row)
        session.commit()
        self._invalidate_bot_messages_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "message_id": row.id,
            "feedback": row.feedback,
            "feedback_at": row.feedback_at.isoformat() if row.feedback_at else None,
        }