Source code for agentopera.engine.agent.base_agent

from __future__ import annotations

import inspect
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, final

from typing_extensions import Self

from ..types.agent.agent import Agent
from ..types.agent import AgentId, AgentType, AgentMetadata
from ..types.agent.agent_instantiation_context import AgentInstantiationContext
from ..runtime.agent_engine import AgentEngine
from ..types.agent.cancellation_token import CancellationToken
from ..types.msg_context.message_context import MessageContext
from ..serialization import MessageSerializer, try_get_known_serializers_for_type
from ..subscription.subscription import Subscription, UnboundSubscription
from ..subscription.subscription_context import SubscriptionInstantiationContext
from ..types.msg_channel import MessageChannel
from ..subscription.topic_prefix_subscription import TopicPrefixSubscription

# Generic type variable bounded by Agent.
T = TypeVar("T", bound=Agent)

# Type variable bounded by BaseAgent. The bound is given as a string (forward
# reference) because BaseAgent is defined further down in this module. The class
# decorators below use it in their annotations so that the decorated class type
# is the same type that comes back out.
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")


# Class-decorator factory: records an unbound subscription on the decorated agent class
def subscription_factory(subscription: UnboundSubscription) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
    """Build a class decorator that records ``subscription`` on an agent class.

    The returned decorator appends ``subscription`` to the decorated class's
    ``internal_unbound_subscriptions_list`` and hands the very same class back.

    :meta private:
    """

    def register_on(agent_cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
        pending = agent_cls.internal_unbound_subscriptions_list
        pending.append(subscription)
        return agent_cls

    return register_on


def handles(
    type: Type[Any], serializer: MessageSerializer[Any] | List[MessageSerializer[Any]] | None = None
) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
    """Build a class decorator that declares an additional message type on an agent class.

    Args:
        type: The message type being declared.
        serializer: A single serializer, a sequence of serializers, or ``None``.
            With ``None`` the serializers are looked up through
            ``try_get_known_serializers_for_type(type)``.

    Returns:
        A decorator that appends ``(type, serializers)`` to the class's
        ``internal_extra_handles_types`` list and returns the class itself.

    Raises:
        ValueError: Raised when the decorator is applied and the resolved
            collection of serializers is empty.
    """

    def attach(agent_cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
        # Resolution happens when the decorator is applied, not when handles() is called.
        if serializer is None:
            resolved = try_get_known_serializers_for_type(type)
        elif isinstance(serializer, Sequence):
            # A sequence is stored as given (no copy is made).
            resolved = serializer
        else:
            resolved = [serializer]

        if not resolved:
            raise ValueError(f"No serializers found for type {type}. Please provide an explicit serializer.")

        agent_cls.internal_extra_handles_types.append((type, resolved))
        return agent_cls

    return attach


class BaseAgent(ABC, Agent):
    """Abstract base for agents that are constructed inside an engine instantiation context.

    Subclasses implement :meth:`on_message_impl`. Every subclass receives its own
    ``internal_unbound_subscriptions_list`` and ``internal_extra_handles_types``
    lists (see ``__init_subclass__``); :meth:`register` reads both of them.
    """

    internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
    """:meta private:"""
    internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
    """:meta private:"""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Give each subclass fresh, empty registries instead of sharing the parent's lists."""
        super().__init_subclass__(**kwargs)
        # Rebinding (rather than clearing) keeps the parent class's lists untouched.
        cls.internal_unbound_subscriptions_list = []
        cls.internal_extra_handles_types = []

    @classmethod
    def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
        """Return this class's list of ``(message type, serializers)`` entries."""
        return cls.internal_extra_handles_types

    @classmethod
    def _unbound_subscriptions(cls) -> List[UnboundSubscription]:
        """Return this class's list of unbound subscriptions."""
        return cls.internal_unbound_subscriptions_list

    @property
    def metadata(self) -> AgentMetadata:
        """An ``AgentMetadata`` built from this agent's id key, id type and description."""
        agent_id = self._id
        assert agent_id is not None
        return AgentMetadata(key=agent_id.key, type=agent_id.type, description=self._description)

    def __init__(self, description: str) -> None:
        """Bind the agent to the engine and id supplied by ``AgentInstantiationContext``.

        Args:
            description: Description of the agent; must be a ``str``.

        Raises:
            RuntimeError: If the instantiation context lookup raises ``LookupError``.
            ValueError: If ``description`` is not a string.
        """
        try:
            current_engine = AgentInstantiationContext.current_engine()
            current_id = AgentInstantiationContext.current_agent_id()
        except LookupError as err:
            raise RuntimeError(
                "BaseAgent must be instantiated within the context of an AgentEngine. It cannot be directly instantiated."
            ) from err

        self._engine: AgentEngine = current_engine
        self._id: AgentId = current_id

        # The description is validated only after the context lookup has succeeded.
        if not isinstance(description, str):
            raise ValueError("Agent description must be a string")
        self._description = description

    @property
    def type(self) -> str:
        """The ``type`` component of this agent's id."""
        return self.id.type

    @property
    def id(self) -> AgentId:
        """The id captured from the instantiation context at construction."""
        return self._id

    @property
    def engine(self) -> AgentEngine:
        """The engine captured from the instantiation context at construction."""
        return self._engine

    @final
    async def on_message(self, message: Any, ctx: MessageContext) -> Any:
        """Forward ``message`` and ``ctx`` to :meth:`on_message_impl` and return its result.

        Marked ``@final``: subclasses customise handling through ``on_message_impl``.
        """
        result = await self.on_message_impl(message, ctx)
        return result

    @abstractmethod
    async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any:
        """Handle one message; must be provided by concrete subclasses."""
        ...

    async def send_message(
        self,
        message: Any,
        recipient: AgentId,
        *,
        cancellation_token: CancellationToken | None = None,
        message_id: str | None = None,
    ) -> Any:
        """Send ``message`` to ``recipient`` through the engine, with this agent as sender.

        A new ``CancellationToken`` is created when none is supplied. The value
        returned by the engine's ``send_message`` is passed back to the caller.
        """
        token = CancellationToken() if cancellation_token is None else cancellation_token
        return await self._engine.send_message(
            message,
            sender=self.id,
            recipient=recipient,
            cancellation_token=token,
            message_id=message_id,
        )

    async def publish_message(
        self,
        message: Any,
        message_channel: MessageChannel,
        *,
        cancellation_token: CancellationToken | None = None,
        message_id: str | None = None,
    ) -> None:
        """Publish ``message`` on ``message_channel`` through the engine, with this agent as sender.

        ``cancellation_token`` is forwarded exactly as received, including ``None``.
        """
        await self._engine.publish_message(
            message,
            message_channel,
            sender=self.id,
            cancellation_token=cancellation_token,
            message_id=message_id,
        )

    async def save_state(self) -> Mapping[str, Any]:
        """Default implementation: emit a warning and return an empty mapping."""
        warnings.warn("save_state not implemented", stacklevel=2)
        return {}

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Default implementation: emit a warning and ignore ``state``."""
        warnings.warn("load_state not implemented", stacklevel=2)

    async def close(self) -> None:
        """Default implementation: do nothing."""
        return None

    @classmethod
    async def register(
        cls,
        engine: AgentEngine,
        type: str,
        agent_builder: Callable[[], Self | Awaitable[Self]],
        *,
        skip_class_subscriptions: bool = False,
        skip_direct_message_subscription: bool = False,
    ) -> AgentType:
        """Register this agent class with ``engine`` under the agent type name ``type``.

        Steps, in order:

        1. Register ``agent_builder`` with the engine for the agent type.
        2. Unless ``skip_class_subscriptions`` is set, call every unbound
           subscription recorded on the class and subscribe each result.
        3. Unless ``skip_direct_message_subscription`` is set, subscribe a
           ``TopicPrefixSubscription`` whose prefix is ``"<agent type>:"``.
        4. Register the serializers recorded on the class with the engine.

        Args:
            engine: Engine to register with.
            type: Name of the agent type.
            agent_builder: Callable producing an agent instance, or an awaitable of one.
            skip_class_subscriptions: Skip step 2 when true.
            skip_direct_message_subscription: Skip step 3 when true.

        Returns:
            The ``AgentType`` returned by the engine's ``register_agent_builder``.
        """
        registered_type = await engine.register_agent_builder(
            type=AgentType(type), agent_builder=agent_builder, expected_class=cls
        )

        if not skip_class_subscriptions:
            bound_subscriptions: List[Subscription] = []
            with SubscriptionInstantiationContext.populate_context(registered_type):
                for make_subscriptions in cls._unbound_subscriptions():
                    produced = make_subscriptions()
                    # An unbound subscription may hand back its result directly or as an awaitable.
                    bound_subscriptions.extend(await produced if inspect.isawaitable(produced) else produced)
            for bound in bound_subscriptions:
                await engine.subscribe(bound)

        if not skip_direct_message_subscription:
            # Additionally adds a special prefix subscription for this agent to receive direct messages
            direct_subscription = TopicPrefixSubscription(
                # The prefix MUST include ":" to avoid collisions with other agents
                topic_prefix=registered_type.type + ":",
                agent_type=registered_type.type,
            )
            await engine.subscribe(direct_subscription)

        # TODO: deduplication
        for _, serializers in cls._handles_types():
            engine.register_msg_serializer(serializers)

        return registered_type