my_agent/providers/scripted.py

27 lines
951 B
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Iterator, List
from .base import AgentProvider, AssistantTurn, StreamEvent
@dataclass
class ScriptedProvider(AgentProvider):
    """Deterministic provider for tests and demos.

    Replays a fixed script of pre-built assistant turns, one per call, in the
    order given. The ``messages`` and ``tools`` arguments of each call are
    ignored, so the output depends only on the script.

    Attributes:
        turns: The scripted turns that have not been consumed yet. Each call
            to ``generate`` (directly or via ``stream_generate``) removes and
            returns the first one. The sequence is copied at construction time,
            so consuming turns never drains a list the caller still holds
            (for example, one script shared by several providers or tests).
    """

    turns: List[AssistantTurn]

    def __post_init__(self) -> None:
        """Take a private copy of the script after dataclass initialisation."""
        # Defining __post_init__ here would shadow a hook on a dataclass base
        # class (the base is not visible from this module), so chain to it
        # explicitly when one exists.
        parent_hook = getattr(super(), "__post_init__", None)
        if parent_hook is not None:
            parent_hook()
        # Copy instead of aliasing: generate() pops from this list, and popping
        # from the caller's own list would silently empty it for later reuse.
        self.turns = list(self.turns)

    def generate(self, messages: List[Dict[str, str]], tools: List[Dict[str, str]]) -> AssistantTurn:
        """Return the next scripted turn.

        Args:
            messages: Conversation history; ignored.
            tools: Tool descriptions; ignored.

        Returns:
            The first remaining scripted turn. It is removed from ``turns``.

        Raises:
            RuntimeError: If every scripted turn has already been consumed.
        """
        if not self.turns:
            raise RuntimeError("ScriptedProvider ran out of scripted turns")
        return self.turns.pop(0)

    def stream_generate(self, messages: List[Dict[str, str]], tools: List[Dict[str, str]]) -> Iterator[StreamEvent]:
        """Emit the next scripted turn as a short sequence of stream events.

        Events are yielded in this order:

        1. A ``"reasoning"`` event whose delta is the turn's whole reasoning
           text. It is emitted only when that text is truthy.
        2. A ``"content"`` event whose delta is the turn's whole content. It is
           emitted only when that content is truthy.
        3. A final ``"turn"`` event carrying the turn object itself.

        The text is delivered as one delta per event and is not split into
        chunks. Because this is a generator, the turn is consumed only when
        iteration starts, not when the method is called.

        Args:
            messages: Conversation history; passed to ``generate``, which ignores it.
            tools: Tool descriptions; passed to ``generate``, which ignores them.

        Raises:
            RuntimeError: On first iteration, if the script is exhausted.
        """
        turn = self.generate(messages, tools)
        if turn.reasoning:
            yield StreamEvent(type="reasoning", delta=turn.reasoning)
        if turn.content:
            yield StreamEvent(type="content", delta=turn.content)
        yield StreamEvent(type="turn", turn=turn)