231 lines
8.2 KiB
Python
231 lines
8.2 KiB
Python
import json
|
|
import uuid
|
|
from datetime import timedelta
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from sqlalchemy import func
|
|
from sqlmodel import Session, select
|
|
|
|
from models.platform import BotRequestUsage
|
|
from schemas.platform import PlatformUsageItem, PlatformUsageResponse, PlatformUsageSummary
|
|
|
|
from services.platform_common import estimate_tokens, utcnow
|
|
|
|
|
|
def create_usage_request(
    session: Session,
    bot_id: str,
    command: str,
    attachments: Optional[List[str]] = None,
    channel: str = "dashboard",
    metadata: Optional[Dict[str, Any]] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None,
) -> str:
    """Record a new PENDING usage row for a bot request and return its request id.

    Input tokens are estimated from ``command`` (``token_source="estimated"``);
    output tokens start at zero. The row is added to ``session`` and flushed,
    but not committed.

    Args:
        session: Open database session the row is added to.
        bot_id: Bot the request belongs to.
        command: Raw request text; used for the token estimate and stored as a
            preview truncated to 400 characters.
        attachments: Optional attachment references. Entries that are blank after
            ``str(...).strip()`` are dropped; the rest are stored as a JSON array,
            or ``None`` when nothing remains.
        channel: Origin channel label stored on the row.
        metadata: Extra JSON-serialisable data stored in ``metadata_json``
            (``{}`` when omitted).
        provider: Optional provider name; stripped, truncated to 120 characters,
            and stored as ``None`` when blank.
        model: Optional model name; stripped, truncated to 255 characters, and
            stored as ``None`` when blank.

    Returns:
        The generated request id (``uuid4().hex``).
    """
    request_id = uuid.uuid4().hex
    # Normalise once per item and drop entries that are blank after stripping.
    cleaned_attachments = [
        text for text in (str(item).strip() for item in (attachments or [])) if text
    ]
    input_tokens = estimate_tokens(command)
    # A single timestamp keeps started/created/updated identical on a fresh row.
    now = utcnow()
    usage = BotRequestUsage(
        bot_id=bot_id,
        request_id=request_id,
        channel=channel,
        status="PENDING",
        # Same length limits finalize_usage_from_packet applies to these columns.
        provider=(str(provider or "").strip()[:120] or None),
        model=(str(model or "").strip()[:255] or None),
        token_source="estimated",
        input_tokens=input_tokens,
        output_tokens=0,
        total_tokens=input_tokens,
        input_text_preview=str(command or "")[:400],
        attachments_json=(
            json.dumps(cleaned_attachments, ensure_ascii=False) if cleaned_attachments else None
        ),
        metadata_json=json.dumps(metadata or {}, ensure_ascii=False),
        started_at=now,
        created_at=now,
        updated_at=now,
    )
    session.add(usage)
    session.flush()
    return request_id
|
|
|
|
|
|
def bind_usage_message(
    session: Session,
    bot_id: str,
    request_id: str,
    message_id: Optional[int],
) -> Optional[BotRequestUsage]:
    """Attach ``message_id`` to the still-PENDING usage row identified by ``request_id``.

    Returns the updated row (re-added to ``session``, not flushed), or ``None``
    when either id is missing/falsy or no PENDING row matches.
    """
    if request_id and message_id:
        pending = find_pending_usage_by_request_id(session, bot_id, request_id)
        if pending:
            pending.message_id = int(message_id)
            pending.updated_at = utcnow()
            session.add(pending)
            return pending
    return None
|
|
|
|
|
|
def find_latest_pending_usage(session: Session, bot_id: str) -> Optional[BotRequestUsage]:
    """Return the most recently started PENDING usage row for ``bot_id``, or ``None``.

    Rows with the same ``started_at`` are ordered by descending primary key.
    """
    newest_first = (BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
    query = (
        select(BotRequestUsage)
        .where(
            BotRequestUsage.bot_id == bot_id,
            BotRequestUsage.status == "PENDING",
        )
        .order_by(*newest_first)
        .limit(1)
    )
    return session.exec(query).first()
|
|
|
|
|
|
def find_pending_usage_by_request_id(session: Session, bot_id: str, request_id: str) -> Optional[BotRequestUsage]:
    """Look up the PENDING usage row of ``bot_id`` whose request id is ``request_id``.

    A falsy ``request_id`` short-circuits to ``None`` without touching the
    database. If several PENDING rows share the id, the newest one (by
    ``started_at``, then ``id``) is returned.
    """
    if request_id:
        criteria = (
            BotRequestUsage.bot_id == bot_id,
            BotRequestUsage.request_id == request_id,
            BotRequestUsage.status == "PENDING",
        )
        query = (
            select(BotRequestUsage)
            .where(*criteria)
            .order_by(BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
            .limit(1)
        )
        return session.exec(query).first()
    return None
|
|
|
|
|
|
# Alias keys under which a packet's ``usage`` dict may report token counts, in priority order.
_INPUT_TOKEN_KEYS = ("input_tokens", "prompt_tokens", "promptTokens")
_OUTPUT_TOKEN_KEYS = ("output_tokens", "completion_tokens", "completionTokens")


def _first_token_count(raw_usage: Dict[str, Any], keys: tuple) -> Optional[int]:
    """Return the first value in ``raw_usage`` under ``keys`` that converts to ``int``.

    Keys are tried in order. ``None`` values and values ``int()`` rejects are
    skipped so a malformed alias does not hide a valid one. Falsy values such as
    ``""`` count as 0. Returns ``None`` when no key yields a usable number.
    """
    for key in keys:
        value = raw_usage.get(key)
        if value is None:
            continue
        try:
            return int(value or 0)
        except (TypeError, ValueError, OverflowError):
            continue
    return None


def finalize_usage_from_packet(session: Session, bot_id: str, packet: Dict[str, Any]) -> Optional[BotRequestUsage]:
    """Mark a PENDING usage row COMPLETED using the data in a bot response ``packet``.

    The row is matched by ``packet["request_id"]``. If that finds nothing, the
    bot's most recently started PENDING row is used instead.

    Token counts are read from ``packet["usage"]`` (a dict) under the keys
    ``input_tokens``/``prompt_tokens``/``promptTokens`` and
    ``output_tokens``/``completion_tokens``/``completionTokens``. A missing input
    count falls back to the value already stored on the row. A missing output
    count is estimated from the packet's ``text`` (or ``content``).
    ``token_source`` becomes ``"exact"`` when both counts were reported,
    ``"mixed"`` when only one was, and ``"estimated"`` when neither was.

    ``provider``/``model`` (truncated to 120/255 chars) and ``message_id`` are
    copied onto the row when present. A ``message_id`` that is not
    int-convertible is ignored.

    Returns:
        The updated row (re-added to ``session``, not flushed), or ``None`` when
        the bot has no PENDING row.
    """
    request_id = str(packet.get("request_id") or "").strip()
    # NOTE(review): when request_id is present but matches no PENDING row (e.g. a
    # duplicate packet), this still falls back to the latest PENDING row, which may
    # belong to a different request — confirm this leniency is intended.
    usage_row = find_pending_usage_by_request_id(session, bot_id, request_id) or find_latest_pending_usage(session, bot_id)
    if not usage_row:
        return None

    raw_usage = packet.get("usage")
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    if isinstance(raw_usage, dict):
        input_tokens = _first_token_count(raw_usage, _INPUT_TOKEN_KEYS)
        output_tokens = _first_token_count(raw_usage, _OUTPUT_TOKEN_KEYS)

    reported_input = input_tokens is not None
    reported_output = output_tokens is not None
    if reported_input and reported_output:
        source = "exact"
    elif reported_input or reported_output:
        # One side will be filled from an estimate below.
        source = "mixed"
    else:
        source = "estimated"

    text = str(packet.get("text") or packet.get("content") or "").strip()
    provider = str(packet.get("provider") or "").strip()
    model = str(packet.get("model") or "").strip()
    message_id = packet.get("message_id")

    if input_tokens is None:
        # Keep the estimate recorded when the request was created.
        input_tokens = usage_row.input_tokens
    if output_tokens is None:
        output_tokens = estimate_tokens(text)

    if provider:
        usage_row.provider = provider[:120]
    if model:
        usage_row.model = model[:255]
    if message_id is not None:
        try:
            usage_row.message_id = int(message_id)
        except (TypeError, ValueError, OverflowError):
            # Best-effort: keep any previously bound message id.
            pass

    usage_row.output_tokens = max(0, int(output_tokens or 0))
    usage_row.input_tokens = max(0, int(input_tokens or 0))
    usage_row.total_tokens = usage_row.input_tokens + usage_row.output_tokens
    usage_row.output_text_preview = text[:400] if text else usage_row.output_text_preview
    usage_row.status = "COMPLETED"
    usage_row.token_source = source
    now = utcnow()
    usage_row.completed_at = now
    usage_row.updated_at = now
    session.add(usage_row)
    return usage_row
|
|
|
|
|
|
def fail_latest_usage(
    session: Session,
    bot_id: str,
    detail: str,
    request_id: Optional[str] = None,
) -> Optional[BotRequestUsage]:
    """Mark a PENDING usage row of ``bot_id`` as ERROR and record ``detail``.

    Args:
        session: Open database session; the row is re-added but not flushed.
        bot_id: Bot whose pending request failed.
        detail: Error description, stored truncated to 500 characters.
        request_id: Optional id of the failed request. When given, the matching
            PENDING row is targeted. When omitted, blank, or unmatched, the bot's
            most recently started PENDING row is used, the same matching rule as
            ``finalize_usage_from_packet``.

    Returns:
        The updated row, or ``None`` when the bot has no PENDING row.
    """
    usage_row = None
    normalized_id = str(request_id or "").strip()
    if normalized_id:
        usage_row = find_pending_usage_by_request_id(session, bot_id, normalized_id)
    if not usage_row:
        usage_row = find_latest_pending_usage(session, bot_id)
    if not usage_row:
        return None

    usage_row.status = "ERROR"
    usage_row.error_text = str(detail or "")[:500]
    # One timestamp so completed_at and updated_at agree exactly.
    now = utcnow()
    usage_row.completed_at = now
    usage_row.updated_at = now
    session.add(usage_row)
    return usage_row
|
|
|
|
|
|
def _isoformat_utc(value: Any) -> Optional[str]:
    """Render a datetime as ISO-8601 with a trailing ``Z``; ``None`` passes through.

    Naive values are treated as already-UTC and formatted as before. Aware
    values are converted to UTC and made naive first, so the output never ends
    in an invalid ``+00:00Z``.
    """
    if value is None:
        return None
    if value.utcoffset() is not None:
        from datetime import timezone

        value = value.astimezone(timezone.utc).replace(tzinfo=None)
    return value.isoformat() + "Z"


def list_usage(
    session: Session,
    bot_id: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> Dict[str, Any]:
    """Return one page of usage rows plus aggregate token totals, as a plain dict.

    Rows are ordered newest first (``started_at``, then ``id``). ``limit`` is
    clamped to ``[1, 500]`` (``None`` means the default 100) and ``offset`` to
    ``>= 0``.

    When ``bot_id`` is given, the page, ``total`` and the summary are all scoped
    to that bot. Without it, the page and ``total`` span every row, but the
    summary only covers rows created in the last 24 hours.

    Returns:
        ``PlatformUsageResponse(...).model_dump()`` with ``summary``, ``items``,
        ``total``, ``limit``, ``offset`` and ``has_more``.
    """
    safe_limit = max(1, min(int(limit if limit is not None else 100), 500))
    safe_offset = max(0, int(offset or 0))

    stmt = (
        select(BotRequestUsage)
        .order_by(BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
        .offset(safe_offset)
        .limit(safe_limit)
    )
    summary_stmt = select(
        func.count(BotRequestUsage.id),
        func.coalesce(func.sum(BotRequestUsage.input_tokens), 0),
        func.coalesce(func.sum(BotRequestUsage.output_tokens), 0),
        func.coalesce(func.sum(BotRequestUsage.total_tokens), 0),
    )
    total_stmt = select(func.count(BotRequestUsage.id))
    if bot_id:
        stmt = stmt.where(BotRequestUsage.bot_id == bot_id)
        summary_stmt = summary_stmt.where(BotRequestUsage.bot_id == bot_id)
        total_stmt = total_stmt.where(BotRequestUsage.bot_id == bot_id)
    else:
        # Platform-wide view: summarise only the last day; listing stays unbounded.
        since = utcnow() - timedelta(days=1)
        summary_stmt = summary_stmt.where(BotRequestUsage.created_at >= since)

    rows = session.exec(stmt).all()
    count, input_sum, output_sum, total_sum = session.exec(summary_stmt).one()
    total = int(session.exec(total_stmt).one() or 0)

    # Build the item models directly; the response model accepts them as-is.
    items = [
        PlatformUsageItem(
            id=int(row.id or 0),
            bot_id=row.bot_id,
            message_id=int(row.message_id) if row.message_id is not None else None,
            request_id=row.request_id,
            channel=row.channel,
            status=row.status,
            provider=row.provider,
            model=row.model,
            token_source=row.token_source,
            content=row.input_text_preview or row.output_text_preview,
            input_tokens=int(row.input_tokens or 0),
            output_tokens=int(row.output_tokens or 0),
            total_tokens=int(row.total_tokens or 0),
            input_text_preview=row.input_text_preview,
            output_text_preview=row.output_text_preview,
            started_at=_isoformat_utc(row.started_at),
            completed_at=_isoformat_utc(row.completed_at),
        )
        for row in rows
    ]

    return PlatformUsageResponse(
        summary=PlatformUsageSummary(
            request_count=int(count or 0),
            input_tokens=int(input_sum or 0),
            output_tokens=int(output_sum or 0),
            total_tokens=int(total_sum or 0),
        ),
        items=items,
        total=total,
        limit=safe_limit,
        offset=safe_offset,
        has_more=safe_offset + len(items) < total,
    ).model_dump()
|