51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, Iterator, List, Optional
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ToolCall:
|
|
id: str
|
|
name: str
|
|
arguments: Dict[str, Any]
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class AssistantTurn:
|
|
content: str = ""
|
|
reasoning: str = ""
|
|
tool_calls: List[ToolCall] = field(default_factory=list)
|
|
raw: Any = None
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class StreamEvent:
|
|
type: str
|
|
delta: str = ""
|
|
turn: Optional[AssistantTurn] = None
|
|
tool_call: Optional[ToolCall] = None
|
|
raw: Any = None
|
|
|
|
|
|
class AgentProvider(ABC):
    """Abstract LLM backend consumed by the core agent loop.

    Subclasses must implement ``generate``. ``stream_generate`` has a default
    that wraps ``generate``; providers with native streaming may override it.
    """

    @abstractmethod
    def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn:
        """Return one complete assistant turn for the given conversation.

        Args:
            messages: Conversation history as a list of dicts (presumably
                chat-style role/content messages — confirm with callers).
            tools: Tool definitions as a list of dicts (schema is provider
                specific — confirm with callers).

        Raises:
            NotImplementedError: if a subclass delegates here via ``super()``.
        """
        raise NotImplementedError

    def stream_generate(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
    ) -> Iterator[StreamEvent]:
        """Yield ``StreamEvent`` items for one turn, built from ``generate``.

        This is a fallback for providers without native streaming. Nothing
        runs until the first event is requested. At that point ``generate``
        is called once, and its result is replayed as at most one "reasoning"
        event and at most one "content" event, each sent only when its text
        is non-empty. A final "turn" event holding the whole turn always
        follows. Every event carries the turn's ``raw`` payload.
        """
        turn = self.generate(messages, tools)
        # Each event type doubles as the AssistantTurn attribute holding that
        # channel's text; the tuple fixes the emission order.
        for channel in ("reasoning", "content"):
            text = getattr(turn, channel)
            if text:
                yield StreamEvent(type=channel, delta=text, raw=turn.raw)
        yield StreamEvent(type="turn", turn=turn, raw=turn.raw)
|