from typing import Any, Callable, List, Mapping
from pydantic import BaseModel
from typing_extensions import Self
from ..base import ChatAgent, TerminationCondition
from ..messages import AgentEvent, ChatMessage
from ..state import RoundRobinManagerState
from .base_group_chat import BaseGroupChat
from .base_group_chat_manager import BaseGroupChatManager
class RoundRobinGroupChatManager(BaseGroupChatManager):
    """Group chat manager that hands the floor to participants in cyclic order.

    The only state this class adds is ``_next_speaker_index``, a cursor into
    ``_participant_topic_types`` naming the entry that the next
    :meth:`select_speaker` call will return. The other attributes used here
    (``_current_turn``, ``_message_thread``, ``_termination_condition``,
    ``_participant_topic_types``) are not assigned in this class and are
    expected to be set up by the base manager.
    """

    def __init__(
        self,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_descriptions: List[str],
        termination_condition: TerminationCondition | None,
        max_turns: int | None = None,
    ) -> None:
        """Forward every argument to the base manager and start the rotation at index 0."""
        super().__init__(
            group_topic_type,
            output_topic_type,
            participant_topic_types,
            participant_descriptions,
            termination_condition,
            max_turns,
        )
        # The rotation always begins with the first participant.
        self._next_speaker_index = 0

    async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
        """Accept any group state; this manager performs no validation."""
        return None

    async def reset(self) -> None:
        """Return the manager to its initial state.

        Zeroes the turn counter, empties the message thread, resets the
        termination condition when one is configured, and finally rewinds the
        speaker cursor to the first participant.
        """
        self._current_turn = 0
        self._message_thread.clear()
        condition = self._termination_condition
        if condition is not None:
            await condition.reset()
        # Rewound last, after the termination condition has been reset.
        self._next_speaker_index = 0

    async def save_state(self) -> Mapping[str, Any]:
        """Return the manager state as a mapping.

        The mapping is the ``model_dump()`` of a ``RoundRobinManagerState``
        holding a copy of the message thread, the current turn and the
        speaker cursor.
        """
        return RoundRobinManagerState(
            message_thread=list(self._message_thread),
            current_turn=self._current_turn,
            next_speaker_index=self._next_speaker_index,
        ).model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the manager from a mapping shaped like the output of :meth:`save_state`.

        The mapping is validated through ``RoundRobinManagerState`` before any
        attribute is overwritten.
        """
        restored = RoundRobinManagerState.model_validate(state)
        self._message_thread = list(restored.message_thread)
        self._current_turn = restored.current_turn
        self._next_speaker_index = restored.next_speaker_index

    async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
        """Return the next participant's topic type and advance the rotation.

        ``thread`` is not consulted; the choice depends only on the cursor.
        """
        topic_types = self._participant_topic_types
        chosen_index = self._next_speaker_index
        # Advance the cursor first, wrapping around after the last participant.
        self._next_speaker_index = (chosen_index + 1) % len(topic_types)
        return topic_types[chosen_index]
# [docs]  (link marker left over from the rendered documentation page; not part of the code)
class RoundRobinGroupChat(BaseGroupChat):
    """A team whose participants take turns, in round-robin order, publishing a
    message to everyone in the group chat.

    When the team has a single participant, that participant is the only speaker.

    Args:
        participants (List[ChatAgent]): The participants in the group chat.
        termination_condition (TerminationCondition, optional): The termination
            condition for the group chat. Defaults to None. Without a
            termination condition, the group chat will run indefinitely.
        max_turns (int, optional): The maximum number of turns in the group
            chat before stopping. Defaults to None, meaning no limit.

    Raises:
        ValueError: If no participants are provided or if participant names
            are not unique.

    Examples:

    A team with one participant with tools:

        .. code-block:: python

            import asyncio
            from agentopera.models.openai import OpenAIChatCompletionClient
            from agentopera.chatflow.agents import AssistantAgent
            from agentopera.chatflow.teams import RoundRobinGroupChat
            from agentopera.chatflow.conditions import TextMentionTermination
            from agentopera.chatflow.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                async def get_weather(location: str) -> str:
                    return f"The weather in {location} is sunny."

                assistant = AssistantAgent(
                    "Assistant",
                    model_client=model_client,
                    tools=[get_weather],
                )
                termination = TextMentionTermination("TERMINATE")
                team = RoundRobinGroupChat([assistant], termination_condition=termination)
                await Console(team.run_stream(task="What's the weather in New York?"))


            asyncio.run(main())

    A team with multiple participants:

        .. code-block:: python

            import asyncio
            from agentopera.models.openai import OpenAIChatCompletionClient
            from agentopera.chatflow.agents import AssistantAgent
            from agentopera.chatflow.teams import RoundRobinGroupChat
            from agentopera.chatflow.conditions import TextMentionTermination
            from agentopera.chatflow.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                agent1 = AssistantAgent("Assistant1", model_client=model_client)
                agent2 = AssistantAgent("Assistant2", model_client=model_client)
                termination = TextMentionTermination("TERMINATE")
                team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
                await Console(team.run_stream(task="Tell me some jokes."))


            asyncio.run(main())
    """

    def __init__(
        self,
        participants: List[ChatAgent],
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
    ) -> None:
        """Set up the team, registering ``RoundRobinGroupChatManager`` as its manager class."""
        super().__init__(
            participants,
            group_chat_manager_class=RoundRobinGroupChatManager,
            termination_condition=termination_condition,
            max_turns=max_turns,
        )

    def _create_group_chat_manager_factory(
        self,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_descriptions: List[str],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
    ) -> Callable[[], RoundRobinGroupChatManager]:
        """Return a zero-argument callable that builds the team's manager.

        The arguments are captured by the returned closure; each call to it
        constructs a new ``RoundRobinGroupChatManager`` from those same values.
        """

        def _build_manager() -> RoundRobinGroupChatManager:
            # Construction is deferred until the caller invokes the factory.
            return RoundRobinGroupChatManager(
                group_topic_type,
                output_topic_type,
                participant_topic_types,
                participant_descriptions,
                termination_condition,
                max_turns,
            )

        return _build_manager