27 lines
951 B
Python
27 lines
951 B
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Dict, Iterator, List
|
||
|
|
|
||
|
|
from .base import AgentProvider, AssistantTurn, StreamEvent
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class ScriptedProvider(AgentProvider):
    """Replay a fixed, pre-recorded sequence of assistant turns.

    Deterministic stand-in for a real provider, meant for tests and demos.
    Every call to ``generate`` (directly or via ``stream_generate``) serves
    the next turn from the front of ``turns``.
    """

    # Remaining scripted turns, served front-first. The dataclass-generated
    # __init__ stores the list as given (no copy), and turns are removed from
    # it in place as they are served.
    turns: List[AssistantTurn]

    def generate(self, messages: List[Dict[str, str]], tools: List[Dict[str, str]]) -> AssistantTurn:
        """Remove and return the next scripted turn.

        ``messages`` and ``tools`` are accepted only to satisfy the provider
        interface; this implementation does not inspect them.

        Raises:
            RuntimeError: if every scripted turn has already been served.
        """
        if self.turns:
            return self.turns.pop(0)
        raise RuntimeError("ScriptedProvider ran out of scripted turns")

    def stream_generate(self, messages: List[Dict[str, str]], tools: List[Dict[str, str]]) -> Iterator[StreamEvent]:
        """Emit the next scripted turn as a sequence of stream events.

        Because this is a generator, the turn is taken from ``turns`` only
        once iteration starts. The events are:

        * a ``"reasoning"`` delta, if the turn's ``reasoning`` is truthy;
        * a ``"content"`` delta, if the turn's ``content`` is truthy;
        * a final ``"turn"`` event that carries the whole turn.

        Raises:
            RuntimeError: propagated from ``generate`` when no turns remain.
        """
        next_turn = self.generate(messages, tools)
        # Each event type matches the turn attribute it streams. The attribute
        # is read inside the loop, so each field is evaluated lazily and in
        # order, only after the previous event has been consumed.
        for field_name in ("reasoning", "content"):
            text = getattr(next_turn, field_name)
            if text:
                yield StreamEvent(type=field_name, delta=text)
        yield StreamEvent(type="turn", turn=next_turn)
|