247 lines
9.4 KiB
Python
247 lines
9.4 KiB
Python
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional

from fastapi import HTTPException
from sqlmodel import Session, select

from models.bot import BotInstance, BotMessage
|
|
|
|
# Call signatures of the collaborators injected into BotMessageService.

# (bot_id, limit) -> cache key for the flat, most-recent message list.
CacheKeyMessages = Callable[[str, int], str]
# (bot_id, limit, before_id) -> cache key for one page of the paginated list.
CacheKeyMessagesPage = Callable[[str, int, Optional[int]], str]
# (bot_id, row) -> JSON-ready dict describing a single message.
SerializeMessageRow = Callable[[str, BotMessage], dict[str, Any]]
# (date, tz_offset_minutes) -> (utc_start, utc_end) bounds of that local day.
ResolveLocalDayRange = Callable[[str, Optional[int]], tuple[datetime, datetime]]
# (bot_id) -> None; invoked after a write so cached payloads can be dropped.
InvalidateMessagesCache = Callable[[str], None]
# () -> configured default page size for chat pulls.
GetChatPullPageSize = Callable[[], int]
|
|
|
|
|
|
class BotMessageService:
    """Query and update a bot's chat messages (``BotMessage`` rows).

    Cache-key construction, row serialization, local-day resolution, cache
    invalidation and the configured page size are injected as callables, so
    this class only owns the query and validation logic.
    """

    def __init__(
        self,
        *,
        cache: Any,
        cache_key_bot_messages: CacheKeyMessages,
        cache_key_bot_messages_page: CacheKeyMessagesPage,
        serialize_bot_message_row: SerializeMessageRow,
        resolve_local_day_range: ResolveLocalDayRange,
        invalidate_bot_messages_cache: InvalidateMessagesCache,
        get_chat_pull_page_size: GetChatPullPageSize,
    ) -> None:
        """Store the injected collaborators.

        Args:
            cache: JSON cache exposing ``get_json(key)`` and
                ``set_json(key, value, ttl=...)``.
            cache_key_bot_messages: Builds the ``list_messages`` cache key
                from ``(bot_id, limit)``.
            cache_key_bot_messages_page: Builds the ``list_messages_page``
                cache key from ``(bot_id, limit, before_id)``.
            serialize_bot_message_row: Turns ``(bot_id, row)`` into the
                response dict for one message.
            resolve_local_day_range: Maps ``(date, tz_offset_minutes)`` to the
                ``(utc_start, utc_end)`` bounds of that local day; rows are
                matched with ``utc_start <= created_at < utc_end``.
            invalidate_bot_messages_cache: Called with ``bot_id`` after a
                write so cached message payloads can be invalidated.
            get_chat_pull_page_size: Returns the configured default page size.
        """
        self._cache = cache
        self._cache_key_bot_messages = cache_key_bot_messages
        self._cache_key_bot_messages_page = cache_key_bot_messages_page
        self._serialize_bot_message_row = serialize_bot_message_row
        self._resolve_local_day_range = resolve_local_day_range
        self._invalidate_bot_messages_cache = invalidate_bot_messages_cache
        self._get_chat_pull_page_size = get_chat_pull_page_size

    def _require_bot(self, *, session: Session, bot_id: str) -> BotInstance:
        """Return the bot with ``bot_id``.

        Raises:
            HTTPException: 404 if no such bot exists.
        """
        bot = session.get(BotInstance, bot_id)
        if not bot:
            raise HTTPException(status_code=404, detail="Bot not found")
        return bot

    def list_messages(self, *, session: Session, bot_id: str, limit: int = 200) -> list[Dict[str, Any]]:
        """Return the newest ``limit`` messages of a bot in chronological order.

        ``limit`` is clamped to ``[1, 500]``. The serialized result is cached
        under ``cache_key_bot_messages(bot_id, limit)`` with ``ttl=30``.

        Raises:
            HTTPException: 404 if the bot does not exist.
        """
        self._require_bot(session=session, bot_id=bot_id)
        safe_limit = max(1, min(int(limit), 500))
        cache_key = self._cache_key_bot_messages(bot_id, safe_limit)
        cached = self._cache.get_json(cache_key)
        if isinstance(cached, list):
            return cached

        # Fetch newest-first so LIMIT keeps the most recent rows, then flip
        # to oldest-first for the response.
        rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(safe_limit)
        ).all()
        payload = [self._serialize_bot_message_row(bot_id, row) for row in reversed(rows)]
        self._cache.set_json(cache_key, payload, ttl=30)
        return payload

    def list_messages_page(
        self,
        *,
        session: Session,
        bot_id: str,
        limit: Optional[int] = None,
        before_id: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return one page of messages older than ``before_id``, oldest first.

        Args:
            session: Database session.
            bot_id: Bot whose messages are listed.
            limit: Page size; defaults to the configured page size and is
                clamped to ``[1, 500]``.
            before_id: Only rows with ``id < before_id`` are returned; values
                that are not positive ints mean "start from the newest".

        Returns:
            ``{"items", "has_more", "next_before_id", "limit"}``, where
            ``next_before_id`` is the id of the oldest row on this page (or
            ``None`` for an empty page). The payload is cached with ``ttl=30``.

        Raises:
            HTTPException: 404 if the bot does not exist.
        """
        self._require_bot(session=session, bot_id=bot_id)
        configured_limit = self._get_chat_pull_page_size()
        safe_limit = max(1, min(int(limit if limit is not None else configured_limit), 500))
        safe_before_id = int(before_id) if isinstance(before_id, int) and before_id > 0 else None
        cache_key = self._cache_key_bot_messages_page(bot_id, safe_limit, safe_before_id)
        cached = self._cache.get_json(cache_key)
        if isinstance(cached, dict) and isinstance(cached.get("items"), list):
            return cached

        # Fetch one extra row to learn whether an older page exists.
        # NOTE(review): the cursor filters on ``id`` while rows are ordered by
        # ``created_at``; this assumes ids grow with creation time.
        stmt = (
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(safe_limit + 1)
        )
        if safe_before_id is not None:
            stmt = stmt.where(BotMessage.id < safe_before_id)

        rows = session.exec(stmt).all()
        has_more = len(rows) > safe_limit
        if has_more:
            rows = rows[:safe_limit]
        items = [self._serialize_bot_message_row(bot_id, row) for row in reversed(rows)]
        # rows is newest-first, so the last element is the oldest on this page.
        next_before_id = rows[-1].id if rows else None
        payload = {
            "items": items,
            "has_more": bool(has_more),
            "next_before_id": next_before_id,
            "limit": safe_limit,
        }
        self._cache.set_json(cache_key, payload, ttl=30)
        return payload

    def list_messages_by_date(
        self,
        *,
        session: Session,
        bot_id: str,
        date: str,
        tz_offset_minutes: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return a window of messages centred on a local calendar date.

        The anchor is the earliest message inside the local day. If the day
        has no messages, the anchor is the nearest message after or before it,
        preferring the later one on a tie. The window holds up to
        ``before_limit`` rows before the anchor, the anchor itself, and up to
        ``after_limit`` rows after it.

        Args:
            session: Database session.
            bot_id: Bot whose messages are listed.
            date: Local date string passed to ``resolve_local_day_range``.
            tz_offset_minutes: Client offset passed to ``resolve_local_day_range``.
            limit: Total window size; defaults to ``max(60, page size)`` and
                is clamped to ``[12, 240]``.

        Returns:
            ``{"items", "anchor_id", "resolved_ts", "matched_exact_date",
            "has_more_before", "has_more_after"}``. ``resolved_ts`` is the
            anchor's ``created_at`` in Unix epoch milliseconds.

        Raises:
            HTTPException: 404 if the bot does not exist.
        """
        self._require_bot(session=session, bot_id=bot_id)
        utc_start, utc_end = self._resolve_local_day_range(date, tz_offset_minutes)
        configured_limit = max(60, self._get_chat_pull_page_size())
        safe_limit = max(12, min(int(limit if limit is not None else configured_limit), 240))
        # A little context before the anchor (3..18 rows); the rest follows it.
        before_limit = max(3, min(18, safe_limit // 4))
        after_limit = max(0, safe_limit - before_limit - 1)

        anchor, matched_exact_date = self._find_date_anchor(
            session=session, bot_id=bot_id, utc_start=utc_start, utc_end=utc_end
        )
        if anchor is None or anchor.id is None:
            return {
                "items": [],
                "anchor_id": None,
                "resolved_ts": None,
                "matched_exact_date": False,
                "has_more_before": False,
                "has_more_after": False,
            }

        before_rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.id < anchor.id)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(before_limit)
        ).all()
        after_rows = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.id > anchor.id)
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(after_limit)
        ).all()

        # Always non-empty: the anchor is included.
        ordered = list(reversed(before_rows)) + [anchor] + list(after_rows)
        first_row, last_row = ordered[0], ordered[-1]
        has_more_before = first_row.id is not None and self._has_message_beyond(
            session=session, bot_id=bot_id, row_id=first_row.id, older=True
        )
        has_more_after = last_row.id is not None and self._has_message_beyond(
            session=session, bot_id=bot_id, row_id=last_row.id, older=False
        )

        return {
            "items": [self._serialize_bot_message_row(bot_id, row) for row in ordered],
            "anchor_id": anchor.id,
            "resolved_ts": self._epoch_millis(anchor.created_at),
            "matched_exact_date": matched_exact_date,
            "has_more_before": has_more_before,
            "has_more_after": has_more_after,
        }

    def _find_date_anchor(
        self,
        *,
        session: Session,
        bot_id: str,
        utc_start: datetime,
        utc_end: datetime,
    ) -> tuple[Optional[BotMessage], bool]:
        """Pick the anchor row for a date window.

        Returns:
            ``(row, matched_exact_date)``. ``row`` is the earliest message in
            ``[utc_start, utc_end)``, or else the closer of the first message
            after the range and the last message before it (ties go to the
            later one). ``row`` is ``None`` when the bot has no messages.
        """
        exact = session.exec(
            select(BotMessage)
            .where(
                BotMessage.bot_id == bot_id,
                BotMessage.created_at >= utc_start,
                BotMessage.created_at < utc_end,
            )
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(1)
        ).first()
        if exact is not None:
            return exact, True

        next_row = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.created_at >= utc_end)
            .order_by(BotMessage.created_at.asc(), BotMessage.id.asc())
            .limit(1)
        ).first()
        prev_row = session.exec(
            select(BotMessage)
            .where(BotMessage.bot_id == bot_id, BotMessage.created_at < utc_start)
            .order_by(BotMessage.created_at.desc(), BotMessage.id.desc())
            .limit(1)
        ).first()

        if next_row and prev_row:
            # Distance from the day's edges: end of day for the later row,
            # start of day for the earlier one.
            gap_after = next_row.created_at - utc_end
            gap_before = utc_start - prev_row.created_at
            return (next_row if gap_after <= gap_before else prev_row), False
        return next_row or prev_row, False

    def _has_message_beyond(self, *, session: Session, bot_id: str, row_id: int, older: bool) -> bool:
        """Return whether the bot has any message with ``id`` below (``older``) or above ``row_id``."""
        id_filter = (BotMessage.id < row_id) if older else (BotMessage.id > row_id)
        id_order = BotMessage.id.desc() if older else BotMessage.id.asc()
        hit = session.exec(
            select(BotMessage.id)
            .where(BotMessage.bot_id == bot_id, id_filter)
            .order_by(id_order)
            .limit(1)
        ).first()
        return hit is not None

    @staticmethod
    def _epoch_millis(value: datetime) -> int:
        """Convert ``value`` to Unix epoch milliseconds, treating naive values as UTC."""
        if value.tzinfo is None:
            # datetime.timestamp() interprets naive values as *local* time.
            # created_at is compared against UTC bounds and feedback_at is
            # written as naive UTC, so created_at is assumed to be naive UTC
            # too. NOTE(review): confirm against the model.
            value = value.replace(tzinfo=timezone.utc)
        return int(value.timestamp() * 1000)

    def update_feedback(
        self,
        *,
        session: Session,
        bot_id: str,
        message_id: int,
        feedback: Optional[str],
    ) -> Dict[str, Any]:
        """Set or clear the thumbs-up/down feedback on an assistant message.

        ``feedback`` is trimmed and lower-cased. ``"up"``/``"down"`` set it and
        stamp ``feedback_at`` (naive UTC). ``None``, ``""``, ``"none"`` and
        ``"null"`` clear both fields. The change is committed and the bot's
        message cache is invalidated.

        Returns:
            ``{"status", "bot_id", "message_id", "feedback", "feedback_at"}``,
            with ``feedback_at`` as an ISO-8601 string or ``None``.

        Raises:
            HTTPException: 404 if the bot or message is missing or the message
                belongs to another bot. 400 if the message is not from the
                assistant or ``feedback`` is not a recognised value.
        """
        self._require_bot(session=session, bot_id=bot_id)
        row = session.get(BotMessage, message_id)
        if not row or row.bot_id != bot_id:
            raise HTTPException(status_code=404, detail="Message not found")
        if row.role != "assistant":
            raise HTTPException(status_code=400, detail="Only assistant messages support feedback")

        raw = str(feedback or "").strip().lower()
        if raw in {"", "none", "null"}:
            row.feedback = None
            row.feedback_at = None
        elif raw in {"up", "down"}:
            row.feedback = raw
            # Same naive-UTC value as the deprecated datetime.utcnow().
            row.feedback_at = datetime.now(timezone.utc).replace(tzinfo=None)
        else:
            raise HTTPException(status_code=400, detail="feedback must be 'up' or 'down'")

        session.add(row)
        session.commit()
        self._invalidate_bot_messages_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "message_id": row.id,
            "feedback": row.feedback,
            "feedback_at": row.feedback_at.isoformat() if row.feedback_at else None,
        }
|