# Source code for agentopera.router.session.session_manager

import asyncio
from typing import Dict, Optional
from datetime import datetime, timedelta

from .session import Session
from .session_repo import SessionRepository
from agentopera.engine.types.agent.agent_id import AgentId
from agentopera.engine.agent import RoutedAgent 
from agentopera.engine.runtime import DistAgentEngine
from agentopera.utils.logger import logger

class SessionManager:
    """Keep sessions in memory, persist them, and expire idle ones.

    Sessions are cached in ``active_sessions`` and written through to the
    given repository. A background task, started with :meth:`start` and
    stopped with :meth:`shutdown`, periodically ends sessions whose
    ``last_updated`` is older than the configured timeout.
    """

    def __init__(
        self,
        repository: SessionRepository,
        agent_runtime: DistAgentEngine,
        session_timeout_minutes: int = 30,
        cleanup_interval_sec: int = 60,
    ):
        """Create a manager.

        Args:
            repository: Store used to load and save sessions.
            agent_runtime: Runtime used to look up, cancel and remove agents.
            session_timeout_minutes: Idle time after which a session expires.
            cleanup_interval_sec: Seconds between two expiry sweeps.
        """
        self.repository = repository
        self.agent_runtime = agent_runtime
        # In-memory cache of sessions, keyed by session ID.
        self.active_sessions: Dict[str, Session] = {}
        self.session_timeout = timedelta(minutes=session_timeout_minutes)
        # Handle of the background sweep; None while no sweep is scheduled.
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_interval = cleanup_interval_sec

    async def start(self):
        """Start the background cleanup task unless one is still running.

        A previous task that has already finished (cancelled or crashed) is
        replaced, so the manager can be started again.
        """
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._run_cleanup_loop())

    async def shutdown(self):
        """Cancel the background cleanup task and wait for it to stop."""
        task = self._cleanup_task
        if task is None:
            return
        # Drop the handle first so a later start() schedules a fresh task
        # instead of seeing the stale, cancelled one.
        self._cleanup_task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            logger.info("SessionManager cleanup task cancelled.")

    async def _run_cleanup_loop(self):
        """Sleep for the cleanup interval, expire old sessions, repeat."""
        logger.info("SessionManager cleanup task started.")
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self.expire_old_sessions()
            except asyncio.CancelledError:
                # Cancellation must end the loop. Before Python 3.8
                # CancelledError derived from Exception and would otherwise
                # be swallowed by the handler below.
                raise
            except Exception as e:
                # Keep the loop alive; the next sweep retries.
                logger.error(f"Error in cleanup loop: {e}")

    async def start_session(self, session_id: str) -> Session:
        """Return the session with this ID, creating it if it does not exist."""
        session = await self.get_session(session_id)
        if session:
            return session
        return await self.start_new_session(session_id)

    async def start_new_session(self, session_id: str) -> Session:
        """Create a brand-new session, cache it and persist it.

        Any cached session with the same ID is replaced.
        """
        session = Session(session_id=session_id)
        self.active_sessions[session_id] = session
        await self.repository.save_session(session)
        return session

    async def get_session(self, session_id: str) -> Optional[Session]:
        """Look up a session and refresh its last-used time.

        The in-memory cache is consulted first, then the repository; a
        session loaded from the repository is added to the cache. A found
        session is touched and saved before being returned.

        Returns:
            The session, or whatever falsy value the repository returned
            (normally ``None``) when no session exists.
        """
        session = self.active_sessions.get(session_id)
        if not session:
            session = await self.repository.get_session(session_id)
            if not session:
                return session
            self.active_sessions[session_id] = session
        session.touch()
        await self.repository.save_session(session)
        return session

    async def register_agent(self, session_id: str, agent_id: AgentId):
        """Mark an agent as streaming in a session, creating the session if needed."""
        session = await self.start_session(session_id)
        session.mark_agent_streaming(agent_id)
        await self.repository.save_session(session)

    async def interrupt_streaming_agents(self, session_id: str):
        """Send a transient cancel signal to every streaming agent of a session.

        This does not end the session or mark it as inactive. A failure to
        cancel one agent is logged and does not stop the remaining agents
        from being cancelled.
        """
        session = await self.get_session(session_id)
        if not session:
            logger.warning(f"No active session found for session_id={session_id}")
            return

        # Step 1: snapshot the agents so the collection can be cleared safely.
        active_streaming_agents = list(session.agents_streaming)
        logger.info(f"current streaming agents {active_streaming_agents}")

        # Step 2: clear the tracking state and persist it as early as possible.
        session.agents_streaming.clear()
        await self.repository.save_session(session)

        # Step 3: cancel each agent from the snapshot.
        for agent_id in active_streaming_agents:
            logger.info(f"[Session {session_id}] Interrupting streaming agent: {agent_id}")
            try:
                agent_instance = await self.agent_runtime.try_get_underlying_agent_instance(id=agent_id)
                if not isinstance(agent_instance, RoutedAgent):
                    logger.error(
                        f"[Session {session_id}] Unsupported agent type {type(agent_instance)} for {agent_id}"
                    )
                    continue
                agent_instance.trigger_transient_cancel()
            except Exception as e:
                logger.error(f"[Session {session_id}] Failed to cancel agent {agent_id}: {e}")
                # NOTE(review): the tracking state was already cleared in
                # step 2, so this looks redundant; kept from the original.
                # Confirm whether it was meant to run on every iteration.
                session.unmark_agent_streaming(agent_id)

    async def unmark_agent_streaming(self, session_id: str, agent_id: AgentId):
        """Remove an agent's streaming mark from a session, if the session exists."""
        session = await self.get_session(session_id)
        if session:
            session.unmark_agent_streaming(agent_id)
            await self.repository.save_session(session)

    async def expire_old_sessions(self):
        """End every cached session that has been idle longer than the timeout.

        Each session is ended independently: a failure is logged and the
        sweep moves on to the next expired session.
        """
        # NOTE(review): naive UTC timestamp; presumably Session.last_updated
        # is naive UTC as well - confirm before switching to aware datetimes.
        now = datetime.utcnow()
        # Collect IDs first: end_session() mutates active_sessions.
        expired = [
            sid
            for sid, sess in self.active_sessions.items()
            if now - sess.last_updated > self.session_timeout
        ]
        for sid in expired:
            try:
                await self.end_session(sid)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f"Failed to expire session {sid}: {e}")
                continue
            logger.info(f"Expired session: {sid}")

    async def end_session(self, session_id: str):
        """Tear a session down and drop it from the in-memory cache.

        Streaming agents are first sent a transient cancel, then removed
        from the runtime; failures in either step are logged and skipped.
        The session is then ended and saved. The cache entry is removed
        even when no session was found.
        """
        session = await self.get_session(session_id)
        if session:
            # Cancel any streaming tasks.
            for agent_id in list(session.agents_streaming):
                try:
                    logger.info(f"Interrupting agent {agent_id} before teardown.")
                    agent_instance = await self.agent_runtime.try_get_underlying_agent_instance(agent_id)
                    if not isinstance(agent_instance, RoutedAgent):
                        logger.error(f"Not supported agent type {type(agent_instance)} for termination")
                        continue
                    agent_instance.trigger_transient_cancel()
                except Exception as e:
                    logger.warning(f"Failed to interrupt agent {agent_id}: {e}")

            # Remove agents from the runtime.
            for agent_id in list(session.agents_streaming):
                try:
                    logger.info(f"Removing agent {agent_id} from runtime.")
                    await self.agent_runtime.remove_agent(agent_id)
                except Exception as e:
                    logger.warning(f"Failed to remove agent {agent_id}: {e}")

            # Finalize session teardown.
            session.end_session()
            await self.repository.save_session(session)

        self.active_sessions.pop(session_id, None)