87 lines
3.5 KiB
Python
87 lines
3.5 KiB
Python
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
import httpx
|
|
from fastapi import HTTPException
|
|
|
|
|
|
class ProviderTestService:
    """Connectivity check for OpenAI-compatible LLM provider endpoints.

    ``test_provider`` resolves a provider name to a canonical provider id and
    a default base URL, issues one authenticated ``GET {base}/models`` request,
    and summarises the response: HTTP status, a short preview of the listed
    model ids, and whether a requested model appears in the listing.
    """

    # Lower-cased alias -> (canonical provider id, default base URL).
    # An empty base URL means the caller must supply ``api_base`` explicitly.
    _PROVIDER_ALIASES: Dict[str, Tuple[str, str]] = {
        alias: target
        for aliases, target in (
            (("openrouter",), ("openrouter", "https://openrouter.ai/api/v1")),
            (
                ("dashscope", "aliyun", "qwen", "aliyun-qwen"),
                ("dashscope", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
            ),
            # Xunfei Spark exposes an OpenAI-compatible API, hence the "openai" id.
            (("xunfei", "iflytek", "xfyun"), ("openai", "https://spark-api-open.xf-yun.com/v1")),
            (("kimi", "moonshot"), ("kimi", "https://api.moonshot.cn/v1")),
            (("minimax",), ("minimax", "https://api.minimax.chat/v1")),
            (("vllm",), ("vllm", "")),
        )
        for alias in aliases
    }

    # Canonical provider ids that ``test_provider`` knows how to probe.
    _TESTABLE_PROVIDERS = frozenset(
        {"openrouter", "dashscope", "kimi", "minimax", "openai", "deepseek", "vllm"}
    )

    # Number of model ids echoed back in ``models_preview``.
    _PREVIEW_LIMIT = 8

    def __init__(self, *, provider_defaults: Optional[Callable[[str], Tuple[str, str]]] = None) -> None:
        """Create the service.

        Args:
            provider_defaults: Optional resolver mapping a raw provider name to
                ``(canonical_provider_id, default_base_url)``. When omitted,
                the built-in alias table (``provider_defaults``) is used.
        """
        self._provider_defaults = provider_defaults or self.provider_defaults

    @staticmethod
    def provider_defaults(provider: str) -> Tuple[str, str]:
        """Resolve a provider name to ``(canonical_id, default_base_url)``.

        Matching is case-insensitive and ignores surrounding whitespace.
        Unknown names map to ``(normalized_name, "")``: no default base URL.
        """
        normalized = provider.lower().strip()
        return ProviderTestService._PROVIDER_ALIASES.get(normalized, (normalized, ""))

    @staticmethod
    def _extract_model_ids(data: Any) -> List[str]:
        """Return the ``id`` of every well-formed entry in a ``/models`` body.

        Expects the OpenAI-style shape ``{"data": [{"id": ...}, ...]}``. Any
        other shape (non-dict body, missing or non-list ``data``) yields ``[]``
        rather than raising.
        """
        if not isinstance(data, dict):
            return []
        entries = data.get("data")
        if not isinstance(entries, list):
            return []
        return [str(entry["id"]) for entry in entries if isinstance(entry, dict) and entry.get("id")]

    @staticmethod
    def _model_hint(model: str, model_ids: List[str]) -> str:
        """Classify whether ``model`` appears in ``model_ids``.

        Returns ``""`` when no model was requested. Otherwise it returns
        ``"model_found"`` if ``model`` is a substring of any listed id, and
        ``"model_not_listed"`` if not.
        """
        if not model:
            return ""
        return "model_found" if any(model in model_id for model_id in model_ids) else "model_not_listed"

    async def test_provider(self, *, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Probe ``{api_base}/models`` with the supplied credentials.

        Args:
            payload: Request body. Reads ``provider`` and ``api_key`` (both
                required), plus optional ``model`` and ``api_base``. The
                explicit ``api_base`` overrides the provider's default base URL.

        Returns:
            One of three dicts:

            * success: ``{"ok": True, "provider", "endpoint",
              "models_preview", "model_hint"}``.
            * HTTP error (status >= 400): ``{"ok": False, "provider",
              "endpoint", "status_code", "detail"}``, where ``detail`` is the
              first 500 characters of the body.
            * transport or parse failure: ``{"ok": False, "provider",
              "endpoint", "detail"}``.

        Raises:
            HTTPException: 400 if ``provider`` or ``api_key`` is missing, the
                provider is not testable, or no base URL can be determined.
        """
        provider = str(payload.get("provider") or "").strip()
        api_key = str(payload.get("api_key") or "").strip()
        model = str(payload.get("model") or "").strip()
        api_base = str(payload.get("api_base") or "").strip()

        if not provider or not api_key:
            raise HTTPException(status_code=400, detail="provider and api_key are required")

        normalized_provider, default_base = self._provider_defaults(provider)
        base = (api_base or default_base).rstrip("/")

        if normalized_provider not in self._TESTABLE_PROVIDERS:
            raise HTTPException(status_code=400, detail=f"provider not supported for test: {provider}")

        if not base:
            raise HTTPException(status_code=400, detail=f"api_base is required for provider: {provider}")

        headers = {"Authorization": f"Bearer {api_key}"}
        timeout = httpx.Timeout(20.0, connect=10.0)
        url = f"{base}/models"

        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                resp = await client.get(url, headers=headers)

            if resp.status_code >= 400:
                return {
                    "ok": False,
                    "provider": normalized_provider,
                    "endpoint": url,
                    "status_code": resp.status_code,
                    "detail": resp.text[:500],
                }

            model_ids = self._extract_model_ids(resp.json())
        except Exception as exc:
            # Deliberate boundary catch: any transport or JSON-decoding failure
            # is reported to the caller as a failed probe rather than a 500.
            # Some transport errors stringify to "", so fall back to the type name.
            return {
                "ok": False,
                "provider": normalized_provider,
                "endpoint": url,
                "detail": str(exc) or type(exc).__name__,
            }

        return {
            "ok": True,
            "provider": normalized_provider,
            "endpoint": url,
            "models_preview": model_ids[: self._PREVIEW_LIMIT],
            # Scan the full listing: providers such as OpenRouter list hundreds
            # of models, so checking only a prefix misreports valid models.
            "model_hint": self._model_hint(model, model_ids),
        }
|