import json
import uuid
from datetime import timedelta
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import func
from sqlmodel import Session, select

from models.platform import BotRequestUsage
from schemas.platform import PlatformUsageItem, PlatformUsageResponse, PlatformUsageSummary
from services.platform_common import estimate_tokens, utcnow

# Key aliases accepted for token counts inside a packet's "usage" dict.
# The first alias whose value is not None is the one used.
_INPUT_TOKEN_KEYS: Tuple[str, ...] = ("input_tokens", "prompt_tokens", "promptTokens")
_OUTPUT_TOKEN_KEYS: Tuple[str, ...] = ("output_tokens", "completion_tokens", "completionTokens")


def create_usage_request(
    session: Session,
    bot_id: str,
    command: str,
    attachments: Optional[List[str]] = None,
    channel: str = "dashboard",
    metadata: Optional[Dict[str, Any]] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None,
) -> str:
    """Insert a PENDING usage row for a new request and return its request id.

    Input tokens are estimated from ``command`` with ``estimate_tokens``;
    output tokens start at 0 and are filled in when the request is finalized.
    The row is flushed (not committed) so later queries in ``session`` see it.

    Args:
        session: Active database session.
        bot_id: Bot the request belongs to.
        command: Request text; its first 400 characters are stored as preview.
        attachments: Optional attachment identifiers; blank entries are dropped.
        channel: Origin channel label stored on the row.
        metadata: Extra data stored as JSON (``{}`` when omitted).
        provider: Optional provider name; blank values are stored as NULL.
        model: Optional model name; blank values are stored as NULL.

    Returns:
        The newly generated hex request id.
    """
    request_id = uuid.uuid4().hex
    rows = [str(item).strip() for item in (attachments or []) if str(item).strip()]
    input_tokens = estimate_tokens(command)
    # A single timestamp keeps started_at / created_at / updated_at identical.
    now = utcnow()
    usage = BotRequestUsage(
        bot_id=bot_id,
        request_id=request_id,
        channel=channel,
        status="PENDING",
        provider=(str(provider or "").strip() or None),
        model=(str(model or "").strip() or None),
        token_source="estimated",
        input_tokens=input_tokens,
        output_tokens=0,
        total_tokens=input_tokens,
        input_text_preview=str(command or "")[:400],
        attachments_json=json.dumps(rows, ensure_ascii=False) if rows else None,
        metadata_json=json.dumps(metadata or {}, ensure_ascii=False),
        started_at=now,
        created_at=now,
        updated_at=now,
    )
    session.add(usage)
    session.flush()
    return request_id


def bind_usage_message(
    session: Session,
    bot_id: str,
    request_id: str,
    message_id: Optional[int],
) -> Optional[BotRequestUsage]:
    """Attach ``message_id`` to the PENDING usage row for ``request_id``.

    Returns the updated row, or None when ``request_id``/``message_id`` is
    falsy (so a ``message_id`` of 0 is ignored) or no matching PENDING row
    exists. The row is added to the session but not flushed or committed.
    """
    if not request_id or not message_id:
        return None
    usage_row = find_pending_usage_by_request_id(session, bot_id, request_id)
    if not usage_row:
        return None
    usage_row.message_id = int(message_id)
    usage_row.updated_at = utcnow()
    session.add(usage_row)
    return usage_row


def find_latest_pending_usage(session: Session, bot_id: str) -> Optional[BotRequestUsage]:
    """Return the most recently started PENDING usage row for ``bot_id``, if any."""
    stmt = (
        select(BotRequestUsage)
        .where(BotRequestUsage.bot_id == bot_id)
        .where(BotRequestUsage.status == "PENDING")
        .order_by(BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
        .limit(1)
    )
    return session.exec(stmt).first()


def find_pending_usage_by_request_id(session: Session, bot_id: str, request_id: str) -> Optional[BotRequestUsage]:
    """Return the PENDING usage row of ``bot_id`` matching ``request_id``.

    Returns None for an empty ``request_id`` without querying. If several
    PENDING rows share the id, the most recently started one is returned.
    """
    if not request_id:
        return None
    stmt = (
        select(BotRequestUsage)
        .where(BotRequestUsage.bot_id == bot_id)
        .where(BotRequestUsage.request_id == request_id)
        .where(BotRequestUsage.status == "PENDING")
        .order_by(BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
        .limit(1)
    )
    return session.exec(stmt).first()


def _read_token_count(raw_usage: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[int]:
    """Return the token count under the first alias in ``keys`` that is not None.

    Only that first present alias is consulted: if its value cannot be
    converted with ``int()`` the count is treated as unknown (None) rather
    than falling through to later aliases. Returns None if no alias is present.
    """
    for key in keys:
        value = raw_usage.get(key)
        if value is None:
            continue
        try:
            return int(value or 0)
        except (TypeError, ValueError, OverflowError):
            return None
    return None


def finalize_usage_from_packet(session: Session, bot_id: str, packet: Dict[str, Any]) -> Optional[BotRequestUsage]:
    """Mark a PENDING usage row COMPLETED using data from a response packet.

    The row is located by ``packet["request_id"]``; if that is missing or has
    no PENDING match, the latest PENDING row of ``bot_id`` is used instead.

    Token counts come from ``packet["usage"]`` when it is a dict (see
    ``_INPUT_TOKEN_KEYS`` / ``_OUTPUT_TOKEN_KEYS``). A missing input count
    keeps the row's stored estimate; a missing output count is estimated from
    the packet text. ``token_source`` records the provenance:
    ``"exact"`` when both counts came from the packet, ``"mixed"`` when only
    one did, ``"estimated"`` when neither did.

    Returns the updated row (added to the session, not committed), or None
    when no PENDING row is found.
    """
    request_id = str(packet.get("request_id") or "").strip()
    usage_row = find_pending_usage_by_request_id(session, bot_id, request_id) or find_latest_pending_usage(session, bot_id)
    if not usage_row:
        return None

    raw_usage = packet.get("usage")
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    if isinstance(raw_usage, dict):
        input_tokens = _read_token_count(raw_usage, _INPUT_TOKEN_KEYS)
        output_tokens = _read_token_count(raw_usage, _OUTPUT_TOKEN_KEYS)

    # Provenance must reflect both counts: one exact + one estimated is "mixed",
    # regardless of which side was missing.
    exact_input = input_tokens is not None
    exact_output = output_tokens is not None
    if exact_input and exact_output:
        source = "exact"
    elif exact_input or exact_output:
        source = "mixed"
    else:
        source = "estimated"

    text = str(packet.get("text") or packet.get("content") or "").strip()
    provider = str(packet.get("provider") or "").strip()
    model = str(packet.get("model") or "").strip()
    message_id = packet.get("message_id")

    if input_tokens is None:
        input_tokens = usage_row.input_tokens
    if output_tokens is None:
        output_tokens = estimate_tokens(text)

    # Truncate to what the columns are expected to hold (120 / 255 chars).
    if provider:
        usage_row.provider = provider[:120]
    if model:
        usage_row.model = model[:255]
    if message_id is not None:
        try:
            usage_row.message_id = int(message_id)
        except (TypeError, ValueError, OverflowError):
            # Best effort: an unparseable message id leaves the stored one untouched.
            pass

    usage_row.output_tokens = max(0, int(output_tokens or 0))
    usage_row.input_tokens = max(0, int(input_tokens or 0))
    usage_row.total_tokens = usage_row.input_tokens + usage_row.output_tokens
    usage_row.output_text_preview = text[:400] if text else usage_row.output_text_preview
    usage_row.status = "COMPLETED"
    usage_row.token_source = source
    now = utcnow()
    usage_row.completed_at = now
    usage_row.updated_at = now
    session.add(usage_row)
    return usage_row


def fail_latest_usage(
    session: Session,
    bot_id: str,
    detail: str,
    request_id: Optional[str] = None,
) -> Optional[BotRequestUsage]:
    """Mark a PENDING usage row of ``bot_id`` as ERROR.

    When ``request_id`` is given and matches a PENDING row, that row is failed;
    otherwise (including the default ``None``) the latest PENDING row is used,
    matching the lookup order of ``finalize_usage_from_packet``.

    Args:
        session: Active database session.
        bot_id: Bot whose request failed.
        detail: Error description; stored truncated to 500 characters.
        request_id: Optional id of the failed request.

    Returns:
        The updated row (added to the session, not committed), or None when no
        PENDING row is found.
    """
    normalized_id = str(request_id or "").strip()
    usage_row = find_pending_usage_by_request_id(session, bot_id, normalized_id) or find_latest_pending_usage(session, bot_id)
    if not usage_row:
        return None
    usage_row.status = "ERROR"
    usage_row.error_text = str(detail or "")[:500]
    now = utcnow()
    usage_row.completed_at = now
    usage_row.updated_at = now
    session.add(usage_row)
    return usage_row


def list_usage(
    session: Session,
    bot_id: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> Dict[str, Any]:
    """Return a page of usage rows plus aggregate token totals as a plain dict.

    Rows are ordered newest first (``started_at`` desc, then ``id`` desc).
    ``limit`` is clamped to [1, 500] and ``offset`` to >= 0.

    Note the summary scope differs by mode: with a ``bot_id`` the summary
    covers all of that bot's rows; without one it covers only rows created in
    the last 24 hours, while ``items``/``total`` still span all rows.

    Returns:
        ``PlatformUsageResponse(...).model_dump()``.
    """
    safe_limit = max(1, min(int(limit), 500))
    safe_offset = max(0, int(offset or 0))

    stmt = (
        select(BotRequestUsage)
        .order_by(BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
        .offset(safe_offset)
        .limit(safe_limit)
    )
    summary_stmt = select(
        func.count(BotRequestUsage.id),
        func.coalesce(func.sum(BotRequestUsage.input_tokens), 0),
        func.coalesce(func.sum(BotRequestUsage.output_tokens), 0),
        func.coalesce(func.sum(BotRequestUsage.total_tokens), 0),
    )
    total_stmt = select(func.count(BotRequestUsage.id))
    if bot_id:
        stmt = stmt.where(BotRequestUsage.bot_id == bot_id)
        summary_stmt = summary_stmt.where(BotRequestUsage.bot_id == bot_id)
        total_stmt = total_stmt.where(BotRequestUsage.bot_id == bot_id)
    else:
        since = utcnow() - timedelta(days=1)
        summary_stmt = summary_stmt.where(BotRequestUsage.created_at >= since)

    rows = session.exec(stmt).all()
    count, input_sum, output_sum, total_sum = session.exec(summary_stmt).one()
    total = int(session.exec(total_stmt).one() or 0)

    # The "Z" suffix presumes utcnow() yields naive UTC datetimes — confirm in
    # services.platform_common.
    items = [
        PlatformUsageItem(
            id=int(row.id or 0),
            bot_id=row.bot_id,
            message_id=int(row.message_id) if row.message_id is not None else None,
            request_id=row.request_id,
            channel=row.channel,
            status=row.status,
            provider=row.provider,
            model=row.model,
            token_source=row.token_source,
            content=row.input_text_preview or row.output_text_preview,
            input_tokens=int(row.input_tokens or 0),
            output_tokens=int(row.output_tokens or 0),
            total_tokens=int(row.total_tokens or 0),
            input_text_preview=row.input_text_preview,
            output_text_preview=row.output_text_preview,
            started_at=row.started_at.isoformat() + "Z",
            completed_at=row.completed_at.isoformat() + "Z" if row.completed_at else None,
        )
        for row in rows
    ]
    return PlatformUsageResponse(
        summary=PlatformUsageSummary(
            request_count=int(count or 0),
            input_tokens=int(input_sum or 0),
            output_tokens=int(output_sum or 0),
            total_tokens=int(total_sum or 0),
        ),
        items=items,
        total=total,
        limit=safe_limit,
        offset=safe_offset,
        has_more=safe_offset + len(items) < total,
    ).model_dump()