import os
import re
import aiohttp
import json
from typing import Any, Dict, AsyncGenerator
from agentopera.engine.types.msg_context import MessageContext
from agentopera.engine.agent import RoutedAgent, message_handler
from ..router.constants import READ_BUFFER_SIZE
from agentopera.chatflow.messages import TextMessage, StopMessage, ModelClientStreamingChunkEvent
from agentopera.engine.agent import RoutedAgent
from ..utils.logger import logger
# NOTE: a stray "[docs]" documentation-link marker (not Python source) was removed here.
class BaseFlowAgent(RoutedAgent):
    """Base class for flow agents that handle streaming API responses.

    The agent posts the incoming user query to a flow API endpoint, reads the
    response as a stream of newline-delimited JSON events, and republishes
    those events as chat messages on the ``"response"`` topic.
    """

    def __init__(self, name: str, api_url: str, meta_info: str) -> None:
        """
        Initialize the flow agent.

        Args:
            name: The name of the agent.
            api_url: The API endpoint URL. It may or may not already carry a
                query string; the API key is appended as an extra query
                parameter either way.
            meta_info: Description of the agent's functionality.

        Raises:
            ValueError: If the ``FLOW_API_KEY`` environment variable is unset
                or empty.
        """
        super().__init__("Base Agent for chainopera flow")
        self.name = name
        api_key = os.getenv("FLOW_API_KEY")
        if not api_key:
            raise ValueError("FLOW_API_KEY environment variable is missing!")
        # Use "?" when the URL has no query string yet; unconditionally using
        # "&" would produce a malformed URL for a bare endpoint.
        separator = "&" if "?" in api_url else "?"
        self.api_url = f"{api_url}{separator}x-api-key={api_key}"
        self.meta_info = meta_info
        logger.info(f"[{self.__class__.__name__}] Initialized with functionality: {meta_info}")

    def _extract_agent_id(self) -> str:
        """
        Extract the agent ID from ``self.api_url``.

        The ID is the path segment between ``/build/`` and ``/flow``. A
        hex/dash (UUID-like) segment is matched first; otherwise whatever text
        sits between the two markers is returned.

        Returns:
            str: The extracted agent ID.

        Raises:
            ValueError: If the URL contains no ``/build/`` segment.
        """
        match = re.search(r"/build/([0-9a-fA-F-]+)/flow", self.api_url)
        if match:
            return match.group(1)
        try:
            # Fallback for IDs that are not purely hex digits and dashes.
            return self.api_url.split("/build/")[1].split("/flow")[0]
        except IndexError:
            logger.error("Agent ID not found in API URL")
            # Previously this fell through and returned None despite the
            # declared ``str`` return type and the documented ValueError.
            raise ValueError("Agent ID not found in API URL") from None

    @message_handler
    async def my_message_handler(self, message: TextMessage, ctx: MessageContext) -> None:
        """Handle an incoming text message and stream the flow's response chunks.

        Each event yielded by :meth:`fetch_data` is translated as follows:

        * ``add_message`` from sender ``"User"``: the user's text is echoed,
          followed by a message carrying this agent's name. Other senders
          (e.g. ``"Machine"``) are ignored.
        * ``token``: published as a ``ModelClientStreamingChunkEvent``.
        * ``end``: published as an empty ``StopMessage``.
        * ``end_vertex`` for a ``ChatOutput`` vertex: the final message is
          published as a ``TextMessage``, but only if no ``token`` event was
          seen (i.e. the answer was not already streamed).
        * Anything else is logged as an unknown event.

        Args:
            message: The user message; ``message.metadata`` must contain a
                ``session_id`` entry.
            ctx: The message context; its ``message_channel`` must be set.

        Raises:
            ValueError: If the message metadata has no ``session_id``.
            RuntimeError: If :meth:`fetch_data` reports a request failure.
        """
        assert ctx.message_channel is not None
        logger.debug(f"[{self.__class__.__name__}] Received message from {message.source}: {message.content}")

        # Becomes True once any token has been streamed, so the final
        # ChatOutput message is not published a second time.
        streamed_tokens = False

        # Tolerate ``metadata`` being None so the error below is the one raised.
        session_id = (message.metadata or {}).get("session_id")
        if not session_id:
            raise ValueError("Message metadata is missing `session_id` field")

        # NOTE(review): DefaultMessageChannel is not among the imports visible
        # at the top of this module -- confirm it is in scope.
        async for chunk in self.fetch_data(message.content, session_id):
            event = chunk.get("event")

            # fetch_data reports failures as {"error": ...} without an "event"
            # key; surface them explicitly instead of as an opaque KeyError.
            if event is None and "error" in chunk:
                raise RuntimeError(
                    f"[{self.__class__.__name__}] Flow API request failed: {chunk['error']}"
                )

            if event == 'add_message':
                sender = chunk['data']['sender']
                if sender == 'User':
                    await self.publish_message(
                        TextMessage(content=chunk['data']['text'], source="user", metadata={"session_id": session_id}),
                        message_channel=DefaultMessageChannel(topic="response", source="user"),
                        message_id=ctx.message_id,
                    )
                    # Send the name of the agent
                    await self.publish_message(
                        TextMessage(content=self.name, source="agent_id", metadata={"session_id": session_id}),
                        message_channel=DefaultMessageChannel(topic="response", source=ctx.message_channel.topic),
                        message_id=ctx.message_id,
                    )
                # Messages from 'Machine' (or any other sender) are not forwarded.
            elif event == 'token':
                streamed_tokens = True
                await self.publish_message(
                    ModelClientStreamingChunkEvent(content=chunk['data']['chunk'], source=ctx.message_channel.source, metadata={"session_id": session_id}),
                    message_channel=DefaultMessageChannel(topic="response", source=ctx.message_channel.source),
                    message_id=ctx.message_id,
                )
            elif event == 'end':
                await self.publish_message(
                    StopMessage(content="", source=ctx.message_channel.source, metadata={"session_id": session_id}),
                    message_channel=DefaultMessageChannel(topic="response", source=ctx.message_channel.source),
                    message_id=ctx.message_id,
                )
            elif event == 'end_vertex':
                build_data = chunk['data']['build_data']
                if "ChatOutput" in build_data['id'] and not streamed_tokens:
                    await self.publish_message(
                        TextMessage(content=build_data['data']['artifacts']['message'], source="output", metadata={"session_id": session_id}),
                        message_channel=DefaultMessageChannel(topic="response", source="output"),
                        message_id=ctx.message_id,
                    )
            else:
                logger.error(f"[{self.__class__.__name__}] Unknown event: {event}")

    async def fetch_data(self, user_query: str, session_id: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream responses from the external flow API.

        Posts the query to ``self.api_url`` and yields one parsed JSON object
        per non-blank line of the response body.

        Args:
            user_query: The text sent as the flow's ``input_value``.
            session_id: The session identifier sent as ``session``.

        Yields:
            The decoded JSON event for each response line. On a non-200 status
            or on any exception, a single ``{"error": <message>}`` dict is
            yielded instead and the stream ends.
        """
        headers = {
            "Content-Type": "application/json",
        }
        payload = {
            "inputs": {
                "input_value": user_query,
                "session": session_id,
            }
        }
        # A bare number for ``timeout`` is the deprecated aiohttp form; the
        # equivalent explicit form is a ClientTimeout with the same total.
        request_timeout = aiohttp.ClientTimeout(total=1000)

        async with aiohttp.ClientSession(read_bufsize=READ_BUFFER_SIZE) as session:
            try:
                logger.debug(f"[{self.__class__.__name__}] Sending API request: {json.dumps(payload, indent=2)}")
                async with session.post(
                    self.api_url,
                    headers=headers,
                    json=payload,
                    timeout=request_timeout,
                ) as response:
                    logger.debug(f"[{self.__class__.__name__}] API Response Status: {response.status}")
                    if response.status == 200:
                        # The body is consumed line by line; blank lines are skipped.
                        async for line in response.content:
                            if line.strip():
                                yield json.loads(line)
                    else:
                        error_msg = f"Error: API returned status {response.status} - {await response.text()}"
                        logger.error(f"[{self.__class__.__name__}] {error_msg}")
                        yield {"error": error_msg}
            except Exception as e:
                # Boundary handler: network and JSON-decoding failures are
                # converted into an error event for the consumer.
                logger.error(f"[{self.__class__.__name__}] Exception occurred: {e}")
                yield {"error": str(e)}