import asyncio
from typing import Dict, Optional
from datetime import datetime, timedelta
from .session import Session
from .session_repo import SessionRepository
from agentopera.engine.types.agent.agent_id import AgentId
from agentopera.engine.agent import RoutedAgent
from agentopera.engine.runtime import DistAgentEngine
from agentopera.utils.logger import logger
[docs]
class SessionManager:
    """Track active sessions in memory, persist them, and expire idle ones.

    Sessions are cached in ``active_sessions`` (keyed by session id) and
    written through to ``repository`` via ``save_session``. A background task,
    started with :meth:`start`, wakes every ``cleanup_interval_sec`` seconds
    and ends sessions whose ``last_updated`` is older than
    ``session_timeout_minutes``.
    """

    def __init__(
        self,
        repository: SessionRepository,
        agent_runtime: DistAgentEngine,
        session_timeout_minutes: int = 30,
        cleanup_interval_sec: int = 60,
    ):
        """
        Args:
            repository: Persistence backend used to load and save sessions.
            agent_runtime: Runtime used to look up and remove agent instances.
            session_timeout_minutes: Idle time after which a cached session
                is ended by the cleanup task.
            cleanup_interval_sec: Seconds to sleep between cleanup passes.
        """
        self.repository = repository
        self.agent_runtime = agent_runtime
        # In-memory cache of sessions, keyed by session id.
        self.active_sessions: Dict[str, Session] = {}
        self.session_timeout = timedelta(minutes=session_timeout_minutes)
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_interval = cleanup_interval_sec

    async def start(self) -> None:
        """Starts the session manager background task.

        Does nothing if a cleanup task is already running. A task that has
        already finished is replaced, so the manager can be restarted.
        """
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._run_cleanup_loop())

    async def shutdown(self) -> None:
        """Gracefully stops the background task.

        Cancels the cleanup task and waits for it to finish. The task handle
        is cleared so a later :meth:`start` can launch a new one.
        """
        task = self._cleanup_task
        if task is None:
            return
        # Clear the handle first so start() can relaunch after shutdown.
        self._cleanup_task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            logger.info("SessionManager cleanup task cancelled.")

    async def _run_cleanup_loop(self) -> None:
        """Periodically clears expired sessions until cancelled."""
        logger.info("SessionManager cleanup task started.")
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self.expire_old_sessions()
            except Exception as e:
                # Best-effort: log and keep the loop alive. CancelledError is
                # not an Exception subclass (3.8+), so cancellation still
                # propagates and ends the loop.
                logger.error(f"Error in cleanup loop: {e}")

    async def start_session(self, session_id: str) -> Session:
        """Gets or creates session by ID.

        Returns the existing session (cached or loaded from the repository)
        if there is one, otherwise creates and persists a new one.
        """
        session = await self.get_session(session_id)
        if session:
            return session
        return await self.start_new_session(session_id)

    async def start_new_session(self, session_id: str) -> Session:
        """Creates a brand-new session, caches it, and persists it."""
        session = Session(session_id=session_id)
        self.active_sessions[session_id] = session
        await self.repository.save_session(session)
        return session

    async def get_session(self, session_id: str) -> Optional[Session]:
        """Gets session and updates last-used time.

        Looks in the in-memory cache first, then falls back to the
        repository (caching any hit). A found session has ``touch()`` called
        on it and is saved back. Returns ``None`` if no session exists.
        """
        session = self.active_sessions.get(session_id)
        if not session:
            session = await self.repository.get_session(session_id)
            if session:
                self.active_sessions[session_id] = session
        if session:
            session.touch()
            await self.repository.save_session(session)
        return session

    async def register_agent(self, session_id: str, agent_id: AgentId) -> None:
        """Mark ``agent_id`` as streaming in the session, creating it if needed."""
        session = await self.start_session(session_id)
        session.mark_agent_streaming(agent_id)
        await self.repository.save_session(session)

    async def interrupt_streaming_agents(self, session_id: str) -> None:
        """
        Sends a transient cancel signal to all currently streaming agents in a session.
        This does not end the session or mark it as inactive.
        """
        session = await self.get_session(session_id)
        if not session:
            logger.warning(f"No active session found for session_id={session_id}")
            return
        # Step 1: Snapshot the streaming agents so the collection can be
        # cleared before the (awaiting) cancel loop runs.
        active_streaming_agents = list(session.agents_streaming)
        logger.info(f"current streaming agents {active_streaming_agents}")
        # Step 2: Clear the tracking state and persist it before cancelling.
        session.agents_streaming.clear()
        await self.repository.save_session(session)
        # Step 3: Cancel each agent; a failure on one does not stop the rest.
        for agent_id in active_streaming_agents:
            logger.info(f"[Session {session_id}] Interrupting streaming agent: {agent_id}")
            try:
                agent_instance = await self.agent_runtime.try_get_underlying_agent_instance(id=agent_id)
                if not isinstance(agent_instance, RoutedAgent):
                    logger.error(f"[Session {session_id}] Unsupported agent type {type(agent_instance)} for {agent_id}")
                    continue
                agent_instance.trigger_transient_cancel()
            except Exception as e:
                logger.error(f"[Session {session_id}] Failed to cancel agent {agent_id}: {e}")
            # NOTE(review): agents_streaming was already cleared in Step 2, so
            # this looks redundant; kept in case Session.unmark_agent_streaming
            # has other effects — confirm against Session.
            session.unmark_agent_streaming(agent_id)

    async def unmark_agent_streaming(self, session_id: str, agent_id: AgentId) -> None:
        """Remove ``agent_id`` from the session's streaming set and persist it."""
        session = await self.get_session(session_id)
        if session:
            session.unmark_agent_streaming(agent_id)
            await self.repository.save_session(session)

    async def expire_old_sessions(self) -> None:
        """End every cached session idle for longer than ``session_timeout``.

        Each session is ended independently: a failure while ending one is
        logged and does not prevent the remaining sessions from expiring.
        """
        # NOTE(review): naive UTC timestamp (datetime.utcnow is deprecated in
        # 3.12+); presumably Session.last_updated is also naive UTC — confirm
        # before switching to an aware datetime, or the subtraction will fail.
        now = datetime.utcnow()
        # Materialize ids first: end_session mutates active_sessions.
        expired = [
            sid
            for sid, sess in self.active_sessions.items()
            if now - sess.last_updated > self.session_timeout
        ]
        for sid in expired:
            try:
                await self.end_session(sid)
            except Exception as e:
                logger.error(f"Failed to expire session {sid}: {e}")
                continue
            logger.info(f"Expired session: {sid}")

    async def end_session(self, session_id: str) -> None:
        """Cancel and remove the session's streaming agents, then end it.

        Streaming agents are first sent a transient cancel, then removed from
        the runtime (best-effort; failures are logged). Finally the session
        is ended, persisted, and dropped from the in-memory cache.
        """
        session = await self.get_session(session_id)
        if session:
            # Cancel any streaming tasks
            for agent_id in list(session.agents_streaming):
                try:
                    logger.info(f"Interrupting agent {agent_id} before teardown.")
                    agent_instance = await self.agent_runtime.try_get_underlying_agent_instance(agent_id)
                    if not isinstance(agent_instance, RoutedAgent):
                        logger.error(f"Not supported agent type {type(agent_instance)} for termination")
                        continue
                    agent_instance.trigger_transient_cancel()
                except Exception as e:
                    logger.warning(f"Failed to interrupt agent {agent_id}: {e}")
            # Remove agents from runtime
            for agent_id in list(session.agents_streaming):
                try:
                    logger.info(f"Removing agent {agent_id} from runtime.")
                    await self.agent_runtime.remove_agent(agent_id)
                except Exception as e:
                    logger.warning(f"Failed to remove agent {agent_id}: {e}")
            # Finalize session teardown
            session.end_session()
            await self.repository.save_session(session)
            self.active_sessions.pop(session_id, None)