import asyncio
from typing import Dict, Optional
from datetime import datetime, timedelta
from .session import Session
from .session_repo import SessionRepository
from agentopera.engine.types.agent.agent_id import AgentId
from agentopera.engine.agent import RoutedAgent 
from agentopera.engine.runtime import DistAgentEngine
from agentopera.utils.logger import logger
class SessionManager:
    """Track sessions in memory, persist them through a repository, and expire idle ones.

    Sessions are cached in ``active_sessions`` (keyed by session id) and written
    back to ``repository`` whenever their state changes. Once :meth:`start` has
    been awaited, a background task periodically ends sessions whose
    ``last_updated`` is older than ``session_timeout``.
    """

    def __init__(
        self,
        repository: SessionRepository,
        agent_runtime: DistAgentEngine,
        session_timeout_minutes: int = 30,
        cleanup_interval_sec: int = 60,
    ):
        """
        Args:
            repository: Backend used to load (``get_session``) and persist
                (``save_session``) sessions.
            agent_runtime: Runtime used to look up agent instances for
                cancellation and to remove agents on session teardown.
            session_timeout_minutes: Idle time after which a session is expired.
            cleanup_interval_sec: Seconds to sleep between expiry sweeps.
        """
        self.repository = repository
        self.agent_runtime = agent_runtime
        # In-memory cache of sessions, keyed by session_id.
        self.active_sessions: Dict[str, Session] = {}
        self.session_timeout = timedelta(minutes=session_timeout_minutes)
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_interval = cleanup_interval_sec

    async def start(self):
        """Start the background cleanup task.

        No-op while a cleanup task is still running. A task that has already
        finished (or was stopped by :meth:`shutdown`) is replaced by a new one.
        """
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._run_cleanup_loop())

    async def shutdown(self):
        """Cancel the background cleanup task and wait for it to stop.

        Active sessions are left as they are. The task reference is cleared so
        that :meth:`start` can launch the cleanup loop again later.
        """
        task = self._cleanup_task
        if task is None:
            return
        # Clear first so a later start() is not blocked by the cancelled task.
        self._cleanup_task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            logger.info("SessionManager cleanup task cancelled.")

    async def _run_cleanup_loop(self):
        """Sleep for the cleanup interval, then expire idle sessions; repeat forever.

        Errors raised by a sweep are logged and the loop continues. Cancellation
        is always propagated so :meth:`shutdown` can stop the task.
        """
        logger.info("SessionManager cleanup task started.")
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self.expire_old_sessions()
            except asyncio.CancelledError:
                # Before Python 3.8 CancelledError subclassed Exception; without
                # this clause the handler below would swallow it and loop forever.
                raise
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")

    async def _lookup_session(self, session_id: str) -> Optional[Session]:
        """Return the cached session, falling back to the repository; None if absent.

        A session loaded from the repository is added to ``active_sessions``.
        The session is NOT touched or saved here.
        """
        session = self.active_sessions.get(session_id)
        if not session:
            session = await self.repository.get_session(session_id)
            if session:
                self.active_sessions[session_id] = session
        return session

    async def start_session(self, session_id: str) -> Session:
        """Return the session for ``session_id``, creating a new one if none exists."""
        session = await self.get_session(session_id)
        if session:
            return session
        return await self.start_new_session(session_id)

    async def start_new_session(self, session_id: str) -> Session:
        """Create a brand-new session, cache it, and persist it."""
        session = Session(session_id=session_id)
        self.active_sessions[session_id] = session
        await self.repository.save_session(session)
        return session

    async def get_session(self, session_id: str) -> Optional[Session]:
        """Return the session for ``session_id`` (or None) and record it as used.

        A found session has ``touch()`` called on it and is saved back to the
        repository.
        """
        session = await self._lookup_session(session_id)
        if session:
            session.touch()
            await self.repository.save_session(session)
        return session

    async def register_agent(self, session_id: str, agent_id: AgentId):
        """Mark ``agent_id`` as streaming in the session, creating the session if needed."""
        session = await self.start_session(session_id)
        session.mark_agent_streaming(agent_id)
        await self.repository.save_session(session)

    async def interrupt_streaming_agents(self, session_id: str):
        """
        Sends a transient cancel signal to all currently streaming agents in a session.
        This does not end the session or mark it as inactive.

        The streaming set is cleared and persisted before any agent is contacted.
        Agents that cannot be resolved as ``RoutedAgent``, or whose cancellation
        raises, are logged and skipped.
        """
        session = await self.get_session(session_id)
        if not session:
            logger.warning(f"No active session found for session_id={session_id}")
            return
        # Snapshot the streaming agents so clearing the live container below is safe.
        active_streaming_agents = list(session.agents_streaming)
        logger.info(f"current streaming agents {active_streaming_agents}")
        # Clear the tracking state and persist it before the awaits below.
        session.agents_streaming.clear()
        await self.repository.save_session(session)
        # No per-agent unmark afterwards: the entries are already cleared, and
        # unmarking after the awaits could drop a registration that a concurrent
        # register_agent() made for the same agent in the meantime.
        for agent_id in active_streaming_agents:
            logger.info(f"[Session {session_id}] Interrupting streaming agent: {agent_id}")
            try:
                agent_instance = await self.agent_runtime.try_get_underlying_agent_instance(id=agent_id)
                if not isinstance(agent_instance, RoutedAgent):
                    logger.error(f"[Session {session_id}] Unsupported agent type {type(agent_instance)} for {agent_id}")
                    continue
                agent_instance.trigger_transient_cancel()
            except Exception as e:
                logger.error(f"[Session {session_id}] Failed to cancel agent {agent_id}: {e}")

    async def unmark_agent_streaming(self, session_id: str, agent_id: AgentId):
        """Mark ``agent_id`` as no longer streaming, if the session exists, and persist it."""
        session = await self.get_session(session_id)
        if session:
            session.unmark_agent_streaming(agent_id)
            await self.repository.save_session(session)

    async def expire_old_sessions(self):
        """End every cached session that has been idle longer than ``session_timeout``.

        A failure while ending one session is logged and does not stop the
        remaining expired sessions from being ended. The failed session stays in
        ``active_sessions``, so a later sweep can try again.
        """
        # NOTE(review): naive UTC timestamp; assumes Session.last_updated is
        # also naive UTC (mixing naive and aware datetimes raises) — confirm.
        now = datetime.utcnow()
        expired = [
            sid
            for sid, sess in self.active_sessions.items()
            if now - sess.last_updated > self.session_timeout
        ]
        for sid in expired:
            try:
                await self.end_session(sid)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f"Failed to expire session {sid}: {e}")
                continue
            logger.info(f"Expired session: {sid}")

    async def end_session(self, session_id: str):
        """Tear down a session: interrupt and remove its streaming agents, then end it.

        The session is looked up without being touched, so teardown does not
        count as activity. Once teardown completes (or if no session exists),
        ``session_id`` is removed from ``active_sessions``.
        """
        session = await self._lookup_session(session_id)
        if session:
            # Cancel any streaming tasks.
            for agent_id in list(session.agents_streaming):
                try:
                    logger.info(f"Interrupting agent {agent_id} before teardown.")
                    agent_instance = await self.agent_runtime.try_get_underlying_agent_instance(id=agent_id)
                    if not isinstance(agent_instance, RoutedAgent):
                        logger.error(f"Not supported agent type {type(agent_instance)} for termination")
                        continue
                    agent_instance.trigger_transient_cancel()
                except Exception as e:
                    logger.warning(f"Failed to interrupt agent {agent_id}: {e}")
            # Remove agents from runtime (including any that could not be cancelled above).
            for agent_id in list(session.agents_streaming):
                try:
                    logger.info(f"Removing agent {agent_id} from runtime.")
                    await self.agent_runtime.remove_agent(agent_id)
                except Exception as e:
                    logger.warning(f"Failed to remove agent {agent_id}: {e}")
            # Finalize session teardown.
            session.end_session()
            await self.repository.save_session(session)
        self.active_sessions.pop(session_id, None)