my_agent/providers/openai_compatible.py

142 lines
5.0 KiB
Python

from __future__ import annotations
import json
import uuid
from typing import Any, Dict, Iterator, List, Optional
from .base import AgentProvider, AssistantTurn, StreamEvent, ToolCall
class OpenAICompatibleProvider(AgentProvider):
    """Thin adapter for OpenAI-compatible chat-completions endpoints.

    Calls the ``openai`` SDK's ``chat.completions.create`` and converts the
    result into the project's ``AssistantTurn`` / ``StreamEvent`` /
    ``ToolCall`` types. Pointing ``base_url`` at another server lets any
    endpoint that speaks the chat-completions format back an agent.
    """

    def __init__(
        self,
        *,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: float = 0.2,
        max_tokens: Optional[int] = None,
        timeout: float = 120.0,
    ) -> None:
        """Store request configuration; no client is created here.

        Args:
            model: Model identifier sent as ``model`` in every request.
            api_key: Passed to the SDK client when truthy; otherwise not
                passed, leaving the SDK to use its own default.
            base_url: Passed to the SDK client when truthy; otherwise not
                passed, leaving the SDK's default endpoint in effect.
            temperature: Sampling temperature sent with every request.
            max_tokens: Completion-length cap; omitted from the request when
                ``None``.
            timeout: Timeout handed to the SDK client.
        """
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout

    def _client(self):
        """Build a new ``openai.OpenAI`` client from the stored settings.

        The import is local so that ``openai`` is only needed once a request
        is actually made. A fresh client is created on every call.
        """
        from openai import OpenAI

        kwargs: Dict[str, Any] = {"timeout": self.timeout}
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.base_url:
            kwargs["base_url"] = self.base_url
        return OpenAI(**kwargs)

    def _build_request(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``.

        Tool fields are only included when ``tools`` is non-empty. Parallel
        tool calls are disabled, so the model emits at most one call per turn.
        """
        request: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
        }
        if tools:
            request["tools"] = tools
            request["tool_choice"] = "auto"
            request["parallel_tool_calls"] = False
        if self.max_tokens is not None:
            request["max_tokens"] = self.max_tokens
        return request

    @staticmethod
    def _reasoning_of(obj: Any) -> str:
        """Return reasoning text from a message or delta, or ``""`` if absent.

        ``reasoning`` is checked first. ``reasoning_content`` is the fallback,
        because that is the field name some OpenAI-compatible servers use.
        Only one of the two is taken, so a server sending both never produces
        duplicated text.
        """
        return getattr(obj, "reasoning", None) or getattr(obj, "reasoning_content", None) or ""

    @staticmethod
    def _make_tool_call(call_id: Optional[str], name: Optional[str], raw_args: Optional[str]) -> ToolCall:
        """Build a ``ToolCall``, filling in a random id and defaulting empty fields.

        A missing id gets ``call_<uuid4 hex>``, a missing name becomes ``""``,
        and missing arguments are parsed as ``"{}"``.
        """
        return ToolCall(
            id=call_id or f"call_{uuid.uuid4().hex}",
            name=name or "",
            arguments=_parse_tool_arguments(raw_args or "{}"),
        )

    def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn:
        """Run one non-streaming completion and convert it to an ``AssistantTurn``.

        Args:
            messages: Chat messages, passed through to the endpoint unchanged.
            tools: Tool schemas; when empty, no tool fields are sent.

        Returns:
            An ``AssistantTurn`` holding the first choice's text (``""`` if
            none), its reasoning text (``""`` if none), its parsed function
            tool calls, and the raw SDK response in ``raw``.
        """
        client = self._client()
        response = client.chat.completions.create(**self._build_request(messages, tools))
        message = response.choices[0].message

        tool_calls: List[ToolCall] = []
        for item in message.tool_calls or []:
            fn = getattr(item, "function", None)
            if fn is None:
                # Only function tools are ever requested; ignore other call
                # kinds rather than failing on the missing ``function`` field.
                continue
            tool_calls.append(self._make_tool_call(item.id, fn.name, fn.arguments))

        return AssistantTurn(
            content=message.content or "",
            reasoning=self._reasoning_of(message),
            tool_calls=tool_calls,
            raw=response,
        )

    def stream_generate(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
    ) -> Iterator[StreamEvent]:
        """Stream one completion, yielding incremental events and then the final turn.

        Yields:
            ``StreamEvent(type="reasoning", ...)`` for each non-empty reasoning
            delta and ``StreamEvent(type="content", ...)`` for each non-empty
            text delta, each carrying its source chunk in ``raw``. After the
            stream ends, one ``StreamEvent(type="turn", turn=...)`` follows.
            Its ``AssistantTurn`` holds the concatenated text and reasoning
            and the assembled tool calls.

        The SDK stream is closed when iteration finishes, when an exception
        propagates, or when this generator is closed early.
        """
        client = self._client()
        request = self._build_request(messages, tools)
        request["stream"] = True
        stream = client.chat.completions.create(**request)

        content_parts: List[str] = []
        reasoning_parts: List[str] = []
        # Tool-call fragments keyed by stream index. The id and name are
        # overwritten when present; argument text is appended piece by piece.
        tool_buffers: Dict[int, Dict[str, str]] = {}
        try:
            for chunk in stream:
                # Chunks with no choices (e.g. trailing metadata) carry no deltas.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta

                reasoning_delta = self._reasoning_of(delta)
                if reasoning_delta:
                    reasoning_parts.append(reasoning_delta)
                    yield StreamEvent(type="reasoning", delta=reasoning_delta, raw=chunk)

                content_delta = getattr(delta, "content", None)
                if content_delta:
                    content_parts.append(content_delta)
                    yield StreamEvent(type="content", delta=content_delta, raw=chunk)

                for tool_delta in getattr(delta, "tool_calls", None) or []:
                    index = getattr(tool_delta, "index", 0) or 0
                    buffer = tool_buffers.setdefault(index, {"id": "", "name": "", "arguments": ""})
                    if getattr(tool_delta, "id", None):
                        buffer["id"] = tool_delta.id
                    fn = getattr(tool_delta, "function", None)
                    if fn is not None:
                        if getattr(fn, "name", None):
                            buffer["name"] = fn.name
                        if getattr(fn, "arguments", None):
                            buffer["arguments"] += fn.arguments
        finally:
            # Release the underlying HTTP response even when the consumer
            # stops iterating early or an exception propagates.
            close = getattr(stream, "close", None)
            if callable(close):
                close()

        tool_calls = [
            self._make_tool_call(buf["id"], buf["name"], buf["arguments"])
            for _, buf in sorted(tool_buffers.items())
        ]
        turn = AssistantTurn(
            content="".join(content_parts),
            reasoning="".join(reasoning_parts),
            tool_calls=tool_calls,
        )
        yield StreamEvent(type="turn", turn=turn)
def _parse_tool_arguments(raw_args: Optional[str]) -> Dict[str, Any]:
    """Decode a tool call's JSON argument string into a dict.

    Args:
        raw_args: Argument text emitted by the model. ``None`` or a blank
            string is treated as "no arguments".

    Returns:
        The decoded object when ``raw_args`` is a JSON object. Returns
        ``{}`` for ``None`` or blank input. Otherwise, meaning malformed
        JSON or valid JSON that is not an object (``null``, a list, a
        number, a bare string), returns ``{"raw_arguments": raw_args}``.
        The caller always receives a dict and the original text is kept.
    """
    if raw_args is None or not raw_args.strip():
        return {}
    try:
        parsed = json.loads(raw_args)
    except json.JSONDecodeError:
        return {"raw_arguments": raw_args}
    # Valid JSON is not enough: tool arguments must be an object for callers
    # that index them by parameter name.
    if isinstance(parsed, dict):
        return parsed
    return {"raw_arguments": raw_args}