from typing import Any, Callable, List, Mapping
from pydantic import BaseModel
from typing_extensions import Self
from ..base import ChatAgent, TerminationCondition
from ..messages import AgentEvent, ChatMessage
from ..state import RoundRobinManagerState
from .base_group_chat import BaseGroupChat
from .base_group_chat_manager import BaseGroupChatManager
class RoundRobinGroupChatManager(BaseGroupChatManager):
"""A group chat manager that selects the next speaker in a round-robin fashion."""
def __init__(
self,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_descriptions: List[str],
termination_condition: TerminationCondition | None,
max_turns: int | None = None,
) -> None:
super().__init__(
group_topic_type,
output_topic_type,
participant_topic_types,
participant_descriptions,
termination_condition,
max_turns,
)
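        # Index of the participant that will speak next; advanced modulo the number of participants.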
self._next_speaker_index = 0
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
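        # Round-robin selection does not depend on prior messages, so there is nothing to validate here.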
pass
async def reset(self) -> None:
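        # Restore the manager to its initial state: turn counter, message thread, termination condition, and speaker pointer.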
self._current_turn = 0
self._message_thread.clear()
if self._termination_condition is not None:
await self._termination_condition.reset()
self._next_speaker_index = 0
async def save_state(self) -> Mapping[str, Any]:
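        # Capture the message thread, turn counter, and round-robin pointer so the run can be resumed later.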
state = RoundRobinManagerState(
message_thread=list(self._message_thread),
current_turn=self._current_turn,
next_speaker_index=self._next_speaker_index,
)
return state.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
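        # Rebuild the manager's state from a mapping previously produced by save_state().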
round_robin_state = RoundRobinManagerState.model_validate(state)
self._message_thread = list(round_robin_state.message_thread)
self._current_turn = round_robin_state.current_turn
self._next_speaker_index = round_robin_state.next_speaker_index
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
"""Select a speaker from the participants in a round-robin fashion."""
current_speaker_index = self._next_speaker_index
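        # Advance the pointer with wrap-around so the next call selects the following participant.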
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types)
current_speaker = self._participant_topic_types[current_speaker_index]
return current_speaker
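

# A minimal sketch of the rotation implemented by RoundRobinGroupChatManager.select_speaker,
# using hypothetical participant topic types (illustrative only, not executed by the library):
#
#     participants = ["writer", "critic"]
#     index = 0
#     for _ in range(4):
#         speaker, index = participants[index], (index + 1) % len(participants)
#         # yields "writer", "critic", "writer", "critic"
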
class RoundRobinGroupChat(BaseGroupChat):
"""A team that runs a group chat with participants taking turns in a round-robin fashion
to publish a message to all.
If a single participant is in the team, the participant will be the only speaker.
Args:
participants (List[BaseChatAgent]): The participants in the group chat.
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
Without a termination condition, the group chat will run indefinitely.
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
Raises:
ValueError: If no participants are provided or if participant names are not unique.
Examples:
A team with one participant with tools:
.. code-block:: python
import asyncio
from agentopera.models.openai import OpenAIChatCompletionClient
from agentopera.chatflow.agents import AssistantAgent
from agentopera.chatflow.teams import RoundRobinGroupChat
from agentopera.chatflow.conditions import TextMentionTermination
from agentopera.chatflow.ui import Console
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def get_weather(location: str) -> str:
return f"The weather in {location} is sunny."
assistant = AssistantAgent(
"Assistant",
model_client=model_client,
tools=[get_weather],
)
termination = TextMentionTermination("TERMINATE")
team = RoundRobinGroupChat([assistant], termination_condition=termination)
await Console(team.run_stream(task="What's the weather in New York?"))
asyncio.run(main())
A team with multiple participants:
.. code-block:: python
import asyncio
from agentopera.models.openai import OpenAIChatCompletionClient
from agentopera.chatflow.agents import AssistantAgent
from agentopera.chatflow.teams import RoundRobinGroupChat
from agentopera.chatflow.conditions import TextMentionTermination
from agentopera.chatflow.ui import Console
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent("Assistant1", model_client=model_client)
agent2 = AssistantAgent("Assistant2", model_client=model_client)
termination = TextMentionTermination("TERMINATE")
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
await Console(team.run_stream(task="Tell me some jokes."))
asyncio.run(main())
"""
def __init__(
self,
participants: List[ChatAgent],
termination_condition: TerminationCondition | None = None,
max_turns: int | None = None,
) -> None:
super().__init__(
participants,
group_chat_manager_class=RoundRobinGroupChatManager,
termination_condition=termination_condition,
max_turns=max_turns,
)
def _create_group_chat_manager_factory(
self,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_descriptions: List[str],
termination_condition: TerminationCondition | None,
max_turns: int | None,
) -> Callable[[], RoundRobinGroupChatManager]:
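        # Return a closure so each call constructs a fresh manager with the captured configuration.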
def _factory() -> RoundRobinGroupChatManager:
return RoundRobinGroupChatManager(
group_topic_type,
output_topic_type,
participant_topic_types,
participant_descriptions,
termination_condition,
max_turns,
)
return _factory