import json
import math
import re
import uuid
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Optional

from sqlalchemy import func
from sqlmodel import Session, select

from models.platform import BotRequestUsage
from schemas.platform import (
    PlatformUsageAnalytics,
    PlatformUsageAnalyticsSeries,
    PlatformUsageItem,
    PlatformUsageResponse,
    PlatformUsageSummary,
)

# Tokenizer used by ``estimate_tokens``: one CJK ideograph, one run of ASCII
# word characters, or any other single non-whitespace character. Named groups
# let the estimator weight each match without re-running a regex per piece.
_TOKEN_PATTERN = re.compile(
    r"(?P<cjk>[\u4e00-\u9fff])|(?P<word>[A-Za-z0-9_]+)|(?P<other>[^\s])"
)

# Maximum number of characters kept in the input/output text previews.
_PREVIEW_CHARS = 400

# Keys accepted for exact token counts in a packet's "usage" dict, in priority order.
_INPUT_TOKEN_KEYS = ("input_tokens", "prompt_tokens", "promptTokens")
_OUTPUT_TOKEN_KEYS = ("output_tokens", "completion_tokens", "completionTokens")


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Timestamps written by this module are naive and rendered with a trailing
    "Z" in ``list_usage``. NOTE(review): switching to an aware datetime would
    likely break comparisons against stored values — confirm the column type
    before changing this.
    """
    return datetime.utcnow()


def estimate_tokens(text: str) -> int:
    """Return a rough token estimate for *text*.

    Each CJK ideograph counts as one token, each run of ASCII letters, digits
    or underscores counts as ``ceil(len / 4)`` tokens, and every other
    non-whitespace character counts as one token. Empty or whitespace-only
    input yields 0; any other input yields at least 1.
    """
    content = str(text or "").strip()
    if not content:
        return 0
    total = 0
    for match in _TOKEN_PATTERN.finditer(content):
        if match.lastgroup == "word":
            # A non-empty run always contributes at least one token.
            total += math.ceil(len(match.group()) / 4)
        else:
            total += 1
    return max(1, total)


def create_usage_request(
    session: Session,
    bot_id: str,
    command: str,
    attachments: Optional[List[str]] = None,
    channel: str = "dashboard",
    metadata: Optional[Dict[str, Any]] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None,
) -> str:
    """Insert a PENDING usage row for a new request and return its request id.

    Input tokens are estimated from *command*; output tokens start at 0. Blank
    attachment entries are dropped, and ``attachments_json`` is None when none
    remain. The row is flushed (not committed) so it is visible to later
    queries within the same session.

    Returns:
        The newly generated hex request id.
    """
    request_id = uuid.uuid4().hex
    attachment_list = [str(item).strip() for item in (attachments or []) if str(item).strip()]
    input_tokens = estimate_tokens(command)
    # One timestamp so started/created/updated agree exactly on creation.
    now = _utcnow()
    usage = BotRequestUsage(
        bot_id=bot_id,
        request_id=request_id,
        channel=channel,
        status="PENDING",
        provider=(str(provider or "").strip() or None),
        model=(str(model or "").strip() or None),
        token_source="estimated",
        input_tokens=input_tokens,
        output_tokens=0,
        total_tokens=input_tokens,
        input_text_preview=str(command or "")[:_PREVIEW_CHARS],
        attachments_json=json.dumps(attachment_list, ensure_ascii=False) if attachment_list else None,
        metadata_json=json.dumps(metadata or {}, ensure_ascii=False),
        started_at=now,
        created_at=now,
        updated_at=now,
    )
    session.add(usage)
    session.flush()
    return request_id


def _pending_usage_query(bot_id: str):
    """Build a query for PENDING usage rows of *bot_id*, newest first."""
    return (
        select(BotRequestUsage)
        .where(BotRequestUsage.bot_id == bot_id)
        .where(BotRequestUsage.status == "PENDING")
        .order_by(BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
    )


def _find_latest_pending_usage(session: Session, bot_id: str) -> Optional[BotRequestUsage]:
    """Return the most recently started PENDING usage row for *bot_id*, or None."""
    return session.exec(_pending_usage_query(bot_id).limit(1)).first()


def _find_pending_usage_by_request_id(session: Session, bot_id: str, request_id: str) -> Optional[BotRequestUsage]:
    """Return the PENDING usage row for *bot_id* with *request_id*, or None.

    An empty *request_id* short-circuits to None without querying.
    """
    if not request_id:
        return None
    stmt = _pending_usage_query(bot_id).where(BotRequestUsage.request_id == request_id).limit(1)
    return session.exec(stmt).first()


def bind_usage_message(
    session: Session,
    bot_id: str,
    request_id: str,
    message_id: Optional[int],
) -> Optional[BotRequestUsage]:
    """Attach *message_id* to the PENDING usage row identified by *request_id*.

    Returns the updated row, or None when either id is falsy or no matching
    PENDING row exists. The change is added to the session but not committed.
    """
    if not request_id or not message_id:
        return None
    usage_row = _find_pending_usage_by_request_id(session, bot_id, request_id)
    if not usage_row:
        return None
    usage_row.message_id = int(message_id)
    usage_row.updated_at = _utcnow()
    session.add(usage_row)
    return usage_row


def _parse_token_count(raw_usage: Dict[str, Any], keys: Iterable[str]) -> Optional[int]:
    """Return the first integer-convertible value among *keys* in *raw_usage*.

    Keys whose value is missing/None or cannot be converted with ``int()`` are
    skipped, so a malformed primary key does not hide a valid alias. Falsy
    values (e.g. ``""``, ``0``) count as 0. Returns None when no key yields a
    number.
    """
    for key in keys:
        value = raw_usage.get(key)
        if value is None:
            continue
        try:
            return int(value or 0)
        except (TypeError, ValueError, OverflowError):
            continue
    return None


def finalize_usage_from_packet(session: Session, bot_id: str, packet: Dict[str, Any]) -> Optional[BotRequestUsage]:
    """Mark a PENDING usage row COMPLETED using data from a response *packet*.

    The row is looked up by ``packet["request_id"]``, falling back to the
    latest PENDING row for *bot_id*. Exact token counts are read from
    ``packet["usage"]`` when present. A missing input count keeps the
    creation-time estimate, and a missing output count is estimated from the
    packet text. ``token_source`` becomes "exact" when both counts were
    reported, "mixed" when only one was, and "estimated" otherwise.

    Returns:
        The updated row (added to the session, not committed), or None if no
        PENDING row was found.
    """
    request_id = str(packet.get("request_id") or "").strip()
    usage_row = _find_pending_usage_by_request_id(session, bot_id, request_id) or _find_latest_pending_usage(session, bot_id)
    if not usage_row:
        return None

    raw_usage = packet.get("usage")
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    if isinstance(raw_usage, dict):
        input_tokens = _parse_token_count(raw_usage, _INPUT_TOKEN_KEYS)
        output_tokens = _parse_token_count(raw_usage, _OUTPUT_TOKEN_KEYS)

    # Decide provenance before filling gaps with estimates.
    if input_tokens is not None and output_tokens is not None:
        source = "exact"
    elif input_tokens is not None or output_tokens is not None:
        source = "mixed"
    else:
        source = "estimated"

    text = str(packet.get("text") or packet.get("content") or "").strip()
    provider = str(packet.get("provider") or "").strip()
    model = str(packet.get("model") or "").strip()
    message_id = packet.get("message_id")

    if input_tokens is None:
        # Keep the estimate recorded when the request was created.
        input_tokens = usage_row.input_tokens
    if output_tokens is None:
        output_tokens = estimate_tokens(text)

    if provider:
        usage_row.provider = provider[:120]
    if model:
        usage_row.model = model[:255]
    if message_id is not None:
        try:
            usage_row.message_id = int(message_id)
        except (TypeError, ValueError, OverflowError):
            # Best-effort: a malformed id leaves any previously bound id in place.
            pass

    usage_row.input_tokens = max(0, int(input_tokens or 0))
    usage_row.output_tokens = max(0, int(output_tokens or 0))
    usage_row.total_tokens = usage_row.input_tokens + usage_row.output_tokens
    if text:
        usage_row.output_text_preview = text[:_PREVIEW_CHARS]
    usage_row.status = "COMPLETED"
    usage_row.token_source = source
    now = _utcnow()
    usage_row.completed_at = now
    usage_row.updated_at = now
    session.add(usage_row)
    return usage_row


def fail_latest_usage(
    session: Session,
    bot_id: str,
    detail: str,
    request_id: Optional[str] = None,
) -> Optional[BotRequestUsage]:
    """Mark a PENDING usage row for *bot_id* as ERROR with *detail*.

    When *request_id* is given, the matching PENDING row is preferred. Otherwise,
    or when no row matches, the most recently started PENDING row is used,
    mirroring ``finalize_usage_from_packet``. *detail* is truncated to 500
    characters.

    Returns:
        The updated row (added to the session, not committed), or None if no
        PENDING row exists.
    """
    normalized_request_id = str(request_id or "").strip()
    usage_row = _find_pending_usage_by_request_id(session, bot_id, normalized_request_id) or _find_latest_pending_usage(session, bot_id)
    if not usage_row:
        return None
    usage_row.status = "ERROR"
    usage_row.error_text = str(detail or "")[:500]
    now = _utcnow()
    usage_row.completed_at = now
    usage_row.updated_at = now
    session.add(usage_row)
    return usage_row


def _build_usage_analytics(
    session: Session,
    bot_id: Optional[str] = None,
    window_days: int = 7,
) -> PlatformUsageAnalytics:
    """Count requests per model per day over the last *window_days* UTC days.

    The window ends today (inclusive). Rows with no model are grouped under
    "Unknown". Series are sorted by descending total and then by model name
    (case-insensitive). When *bot_id* is given, only that bot's rows count.
    """
    safe_window_days = max(1, int(window_days or 0))
    today = _utcnow().date()
    days = [today - timedelta(days=offset) for offset in range(safe_window_days - 1, -1, -1)]
    day_keys = [day.isoformat() for day in days]
    day_key_set = set(day_keys)  # O(1) membership test per row
    day_labels = [day.strftime("%m-%d") for day in days]
    first_started_at = datetime.combine(days[0], datetime.min.time())

    stmt = select(BotRequestUsage.model, BotRequestUsage.started_at).where(BotRequestUsage.started_at >= first_started_at)
    if bot_id:
        stmt = stmt.where(BotRequestUsage.bot_id == bot_id)

    counts_by_model: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    total_requests = 0
    for model_name, started_at in session.exec(stmt).all():
        if not started_at:
            continue
        day_key = started_at.date().isoformat()
        # Skip rows dated after today (outside the window).
        if day_key not in day_key_set:
            continue
        normalized_model = str(model_name or "").strip() or "Unknown"
        counts_by_model[normalized_model][day_key] += 1
        total_requests += 1

    series = [
        PlatformUsageAnalyticsSeries(
            model=model_name,
            total_requests=sum(day_counts.values()),
            daily_counts=[int(day_counts.get(day_key, 0)) for day_key in day_keys],
        )
        for model_name, day_counts in counts_by_model.items()
    ]
    series.sort(key=lambda item: (-item.total_requests, item.model.lower()))
    return PlatformUsageAnalytics(
        window_days=safe_window_days,
        days=day_labels,
        total_requests=total_requests,
        series=series,
    )


def _iso_utc(value: Optional[datetime]) -> Optional[str]:
    """Render a naive-UTC datetime as ISO-8601 with a trailing "Z"; falsy values map to None."""
    return value.isoformat() + "Z" if value else None


def list_usage(
    session: Session,
    bot_id: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> Dict[str, Any]:
    """Return a page of usage rows plus summary totals and 7-day analytics.

    Args:
        session: Database session.
        bot_id: Restrict everything to this bot when given.
        limit: Page size, clamped to [1, 500]; None means 100.
        offset: Rows to skip; None or negative means 0.

    Returns:
        ``PlatformUsageResponse`` dumped to a plain dict. With *bot_id*, the
        summary covers all of that bot's rows. Without it, the summary covers
        only rows created in the last 24 hours, while ``total`` and ``items``
        cover all rows.
    """
    safe_limit = max(1, min(int(limit if limit is not None else 100), 500))
    safe_offset = max(0, int(offset or 0))

    stmt = (
        select(BotRequestUsage)
        .order_by(BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
        .offset(safe_offset)
        .limit(safe_limit)
    )
    summary_stmt = select(
        func.count(BotRequestUsage.id),
        func.coalesce(func.sum(BotRequestUsage.input_tokens), 0),
        func.coalesce(func.sum(BotRequestUsage.output_tokens), 0),
        func.coalesce(func.sum(BotRequestUsage.total_tokens), 0),
    )
    total_stmt = select(func.count(BotRequestUsage.id))
    if bot_id:
        stmt = stmt.where(BotRequestUsage.bot_id == bot_id)
        summary_stmt = summary_stmt.where(BotRequestUsage.bot_id == bot_id)
        total_stmt = total_stmt.where(BotRequestUsage.bot_id == bot_id)
    else:
        # Platform-wide summary is limited to the last day; pagination is not.
        since = _utcnow() - timedelta(days=1)
        summary_stmt = summary_stmt.where(BotRequestUsage.created_at >= since)

    rows = session.exec(stmt).all()
    count, input_sum, output_sum, total_sum = session.exec(summary_stmt).one()
    total = int(session.exec(total_stmt).one() or 0)

    items = [
        PlatformUsageItem(
            id=int(row.id or 0),
            bot_id=row.bot_id,
            message_id=int(row.message_id) if row.message_id is not None else None,
            request_id=row.request_id,
            channel=row.channel,
            status=row.status,
            provider=row.provider,
            model=row.model,
            token_source=row.token_source,
            content=row.input_text_preview or row.output_text_preview,
            input_tokens=int(row.input_tokens or 0),
            output_tokens=int(row.output_tokens or 0),
            total_tokens=int(row.total_tokens or 0),
            input_text_preview=row.input_text_preview,
            output_text_preview=row.output_text_preview,
            started_at=_iso_utc(row.started_at),
            completed_at=_iso_utc(row.completed_at),
        )
        for row in rows
    ]

    return PlatformUsageResponse(
        summary=PlatformUsageSummary(
            request_count=int(count or 0),
            input_tokens=int(input_sum or 0),
            output_tokens=int(output_sum or 0),
            total_tokens=int(total_sum or 0),
        ),
        items=items,
        total=total,
        limit=safe_limit,
        offset=safe_offset,
        has_more=safe_offset + len(items) < total,
        analytics=_build_usage_analytics(session, bot_id=bot_id),
    ).model_dump()