my_agent/providers/base.py

51 lines
1.4 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional
@dataclass(slots=True)
class ToolCall:
id: str
name: str
arguments: Dict[str, Any]
@dataclass(slots=True)
class AssistantTurn:
content: str = ""
reasoning: str = ""
tool_calls: List[ToolCall] = field(default_factory=list)
raw: Any = None
@dataclass(slots=True)
class StreamEvent:
type: str
delta: str = ""
turn: Optional[AssistantTurn] = None
tool_call: Optional[ToolCall] = None
raw: Any = None
class AgentProvider(ABC):
    """LLM provider interface used by the core agent loop.

    Concrete providers must implement :meth:`generate`. Providers that can
    stream natively may override :meth:`stream_generate`. The inherited
    version just wraps one blocking ``generate`` call.
    """

    @abstractmethod
    def generate(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
    ) -> AssistantTurn:
        """Produce one complete assistant turn.

        Args:
            messages: Conversation history. The dict schema is up to the
                concrete provider (presumably chat-style message entries --
                confirm against the agent loop).
            tools: Tool definitions offered to the model; schema is likewise
                provider-defined.

        Returns:
            The model's reply as an ``AssistantTurn``.
        """
        raise NotImplementedError

    def stream_generate(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
    ) -> Iterator[StreamEvent]:
        """Default streaming fallback for providers without native streaming.

        This is a generator, so ``generate`` only runs once iteration starts.
        It yields:

        - a ``"reasoning"`` event if the turn has reasoning text;
        - a ``"content"`` event if it has content text;
        - a final ``"turn"`` event carrying the whole ``AssistantTurn``.

        Every event carries the turn's ``raw`` payload. Tool calls are not
        emitted as separate events; they are reachable via the final turn.
        """
        full_turn = self.generate(messages, tools)
        # Each text event's type name matches the AssistantTurn attribute it
        # reports, so one loop covers both. Each attribute is read just before
        # its own event, and empty text is skipped.
        for channel in ("reasoning", "content"):
            text = getattr(full_turn, channel)
            if text:
                yield StreamEvent(type=channel, delta=text, raw=full_turn.raw)
        yield StreamEvent(type="turn", turn=full_turn, raw=full_turn.raw)