import os
import re
import aiohttp
import json
from typing import Any, Dict, AsyncGenerator
from agentopera.engine.types.msg_context import MessageContext
from agentopera.engine.agent import RoutedAgent, message_handler
from ..router.constants import READ_BUFFER_SIZE
from agentopera.chatflow.messages import TextMessage, StopMessage, ModelClientStreamingChunkEvent
from agentopera.engine.agent import RoutedAgent
from ..utils.logger import logger
class BaseFlowAgent(RoutedAgent):
    """Base class for flow agents that handle streaming API responses.

    The agent POSTs incoming user text to its flow API endpoint and re-publishes
    each streamed event as a message on the ``"response"`` topic.
    """

    def __init__(self, name: str, api_url: str, meta_info: str) -> None:
        """
        Initialize the flow agent.

        Args:
            name: The name of the agent; it is published to the response topic
                when the user's message is echoed back.
            api_url: The API endpoint URL. It may or may not already contain a
                query string; the API key is added as the ``x-api-key`` query
                parameter in both cases.
            meta_info: Description of the agent's functionality.

        Raises:
            ValueError: If the ``FLOW_API_KEY`` environment variable is unset
                or empty.
        """
        super().__init__("Base Agent for chainopera flow")
        self.name = name

        api_key = os.getenv("FLOW_API_KEY")
        if not api_key:
            raise ValueError("FLOW_API_KEY environment variable is missing!")
        # Only continue an existing query string with "&"; otherwise start one
        # with "?" so the key is a real query parameter and not part of the path.
        separator = "&" if "?" in api_url else "?"
        self.api_url = f"{api_url}{separator}x-api-key={api_key}"
        self.meta_info = meta_info
        logger.info(f"[{self.__class__.__name__}] Initialized with functionality: {meta_info}")

    def _extract_agent_id(self) -> str:
        """
        Extract the agent ID (UUID) from ``self.api_url``.

        The ID is the path segment between ``/build/`` and ``/flow``.

        Returns:
            str: The extracted agent ID.

        Raises:
            ValueError: If no agent ID can be found in ``self.api_url``.
        """
        match = re.search(r"/build/([0-9a-fA-F-]+)/flow", self.api_url)
        if match:
            return match.group(1)

        # Fallback for IDs the hex/dash pattern above does not match: take what
        # follows "/build/" up to "/flow", and never let the query string (which
        # carries the API key) leak into the returned ID.
        _, found, tail = self.api_url.partition("/build/")
        agent_id = tail.split("/flow", 1)[0].split("?", 1)[0] if found else ""
        if not agent_id:
            logger.error("Agent ID not found in API URL")
            raise ValueError("Agent ID not found in API URL")
        return agent_id

    @message_handler
    async def my_message_handler(self, message: TextMessage, ctx: MessageContext) -> None:
        """Forward a user message to the flow API and re-publish the streamed events.

        Each chunk from ``fetch_data`` is dispatched on its ``event`` key:

        * ``add_message`` from ``User``: echo the user text, then publish this
          agent's name. Messages from other senders (e.g. ``Machine``) are ignored.
        * ``token``: publish the chunk as a ``ModelClientStreamingChunkEvent``.
        * ``end``: publish an empty ``StopMessage``.
        * ``end_vertex`` of a ``ChatOutput`` vertex: publish the final message
          as a ``TextMessage``, but only if no tokens were streamed before.
        * anything else: log it as an unknown event.

        A chunk of the form ``{"error": ...}`` (how ``fetch_data`` reports
        failures) is logged, a ``StopMessage`` is published, and handling stops.

        Raises:
            ValueError: If ``message.metadata`` has no ``session_id``.
        """
        # NOTE(review): DefaultMessageChannel does not appear to be imported in
        # this module; add the correct agentopera import, otherwise every
        # publish below raises NameError.
        assert ctx.message_channel is not None
        logger.debug(f"[{self.__class__.__name__}] Received message from {message.source}: {message.content}")
        # Set once any "token" event has been relayed; used to avoid re-sending
        # the complete ChatOutput text when it already went out token by token.
        streamed_tokens = False
        session_id = message.metadata.get("session_id")
        if not session_id:
            raise ValueError("Message metadata is missing `session_id` field")
        source = ctx.message_channel.source

        async for chunk in self.fetch_data(message.content, session_id):
            if "event" not in chunk and "error" in chunk:
                # fetch_data yields {"error": ...} instead of raising; such a
                # chunk has no "event" key and is always the last one.
                logger.error(f"[{self.__class__.__name__}] Aborting response stream: {chunk['error']}")
                # Close the response stream the same way the "end" event does,
                # presumably so consumers waiting for a StopMessage are released.
                await self.publish_message(
                    StopMessage(content="", source=source, metadata={"session_id": session_id}),
                    message_channel=DefaultMessageChannel(topic="response", source=source),
                    message_id=ctx.message_id,
                )
                return

            event = chunk['event']
            if event == 'add_message':
                if chunk['data']['sender'] == 'User':
                    await self.publish_message(
                        TextMessage(content=chunk['data']['text'], source="user", metadata={"session_id": session_id}),
                        message_channel=DefaultMessageChannel(topic="response", source="user"),
                        message_id=ctx.message_id,
                    )
                    # Send the name of the agent.
                    # NOTE(review): this uses the channel *topic* as source while
                    # the other publishes use the channel source — confirm intent.
                    await self.publish_message(
                        TextMessage(content=self.name, source="agent_id", metadata={"session_id": session_id}),
                        message_channel=DefaultMessageChannel(topic="response", source=ctx.message_channel.topic),
                        message_id=ctx.message_id,
                    )
                # Messages from 'Machine' (or any other sender) are not forwarded.
            elif event == 'token':
                streamed_tokens = True
                await self.publish_message(
                    ModelClientStreamingChunkEvent(content=chunk['data']['chunk'], source=source, metadata={"session_id": session_id}),
                    message_channel=DefaultMessageChannel(topic="response", source=source),
                    message_id=ctx.message_id,
                )
            elif event == 'end':
                await self.publish_message(
                    StopMessage(content="", source=source, metadata={"session_id": session_id}),
                    message_channel=DefaultMessageChannel(topic="response", source=source),
                    message_id=ctx.message_id,
                )
            elif event == 'end_vertex':
                build_data = chunk['data']['build_data']
                if "ChatOutput" in build_data['id'] and not streamed_tokens:
                    await self.publish_message(
                        TextMessage(content=build_data['data']['artifacts']['message'], source="output", metadata={"session_id": session_id}),
                        message_channel=DefaultMessageChannel(topic="response", source="output"),
                        message_id=ctx.message_id,
                    )
            else:
                logger.error(f"[{self.__class__.__name__}] Unknown event: {event}")

    async def fetch_data(self, user_query: str, session_id: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream responses from an external API.

        POSTs ``{"inputs": {"input_value": user_query, "session": session_id}}``
        as JSON to ``self.api_url``. It then yields every non-blank line of the
        response body, parsed with ``json.loads``.

        Failures are not raised. A non-200 status, or any exception while
        sending the request or reading the stream, is logged and reported as a
        single final ``{"error": <message>}`` item.

        Args:
            user_query: Text sent as ``inputs.input_value``.
            session_id: Value sent as ``inputs.session``.

        Yields:
            The parsed JSON object for each streamed line, or one error dict.
        """
        headers = {
            "Content-Type": "application/json",
        }
        payload = {
            "inputs": {
                "input_value": user_query,
                "session": session_id
            }
        }
        # aiohttp documents ClientTimeout as the timeout type; a bare number is
        # only kept for backward compatibility and means the same total timeout.
        request_timeout = aiohttp.ClientTimeout(total=1000)
        async with aiohttp.ClientSession(read_bufsize=READ_BUFFER_SIZE) as session:
            try:
                logger.debug(f"[{self.__class__.__name__}] Sending API request: {json.dumps(payload, indent=2)}")

                async with session.post(
                    self.api_url,
                    headers=headers,
                    json=payload,
                    timeout=request_timeout,
                ) as response:
                    logger.debug(f"[{self.__class__.__name__}] API Response Status: {response.status}")
                    if response.status == 200:
                        async for line in response.content:
                            if line.strip():
                                yield json.loads(line)
                    else:
                        error_msg = f"Error: API returned status {response.status} - {await response.text()}"
                        logger.error(f"[{self.__class__.__name__}] {error_msg}")
                        yield {"error": error_msg}
            except Exception as e:
                # Some exceptions (e.g. asyncio.TimeoutError) have an empty str();
                # fall back to the exception type so the error is never blank.
                error_msg = str(e) or type(e).__name__
                logger.error(f"[{self.__class__.__name__}] Exception occurred: {error_msg}")
                yield {"error": error_msg}