import asyncio
import json
import math
import os
from typing import Optional, Dict, Any

from agentopera.chatflow.messages import TextMessage, MultiModalMessage
from agentopera.engine.function_call import FunctionCall
from agentopera.engine.types.models import (
    SystemMessage
)
from agentopera.engine.types.models import ModelFamily
from agentopera.models.openai import OpenAIChatCompletionClient
from agentopera.utils.logger import logger
from posthog import Posthog

from .intent_registry import IntentRegistry
from .message_utils import parse_context_to_messages, construct_debug_str
from .semantic_router_components import IntentClassifierBase
class LLMIntentClassifier(IntentClassifierBase):
    """Classify a chat message's intent by having an LLM call one of the registered intent tools.

    The primary model is tried first (with retries). If every attempt fails, each model in
    ``backup_models`` is tried in turn. When nothing succeeds, ``"chat_intent"`` is returned.
    Failures are reported to PostHog on a best-effort basis.
    """

    def __init__(self, intent_registry: IntentRegistry, model="Llama-3.1-8B-Instruct"):
        """
        Args:
            intent_registry: Supplies the tool schemas offered to the model
                (via ``get_tool_schemas()``).
            model: Primary model name; must be one accepted by ``_set_model_client``.

        Raises:
            ValueError: If ``model`` is not supported.
            KeyError: If the ``LLM_GATEWAY_KEY`` environment variable is not set.
        """
        self.posthog = Posthog(project_api_key=os.getenv('NEXT_PUBLIC_POSTHOG_KEY'), host='https://us.i.posthog.com', enable_exception_autocapture=True)

        self.intent_registry = intent_registry
        self.intents_as_tools = self.intent_registry.get_tool_schemas()
        self.model_name = model
        self.client = self._set_model_client(model)
        self.backup_models = ["gpt-4o-mini"]
        # Backup clients are built lazily on first fallback and reused afterwards.
        self._backup_clients: Dict[str, Any] = {}

    def _set_model_client(self, model: str):
        """Build an OpenAI-compatible chat client for ``model`` with the intent tools attached.

        Both supported models are reached through the ``TENSOROPERA_API_GATEWAY`` base URL
        and authenticated with ``LLM_GATEWAY_KEY``.

        Raises:
            ValueError: If ``model`` is not supported.
            KeyError: If ``LLM_GATEWAY_KEY`` is not set.
        """
        if model == "gpt-4o-mini":
            model_specific = {"tool_choice": "required"}
        elif model == "Llama-3.1-8B-Instruct":
            model_specific = {
                # Capabilities are declared explicitly for this model only (presumably because
                # the client cannot infer them from the name — confirm).
                "model_info": {
                    "vision": False,
                    "function_calling": True,
                    "json_output": True,
                    "family": ModelFamily.UNKNOWN,
                },
                "tool_choice": "auto",
            }
        else:
            raise ValueError(f"Model '{model}' is not supported")
        return OpenAIChatCompletionClient(
            temperature=0.7,
            api_key=os.environ["LLM_GATEWAY_KEY"],
            base_url=os.getenv("TENSOROPERA_API_GATEWAY"),
            model=model,
            tools=self.intents_as_tools,
            **model_specific,
        )

    def _get_backup_client(self, model_name: str):
        """Return the client for a backup model, building and caching it on first use."""
        client = self._backup_clients.get(model_name)
        if client is None:
            client = self._set_model_client(model_name)
            self._backup_clients[model_name] = client
        return client

    @staticmethod
    def _normalize_score(raw_score: str | float | int) -> float:
        """Convert a model-reported confidence into a float in [0.1, 1.0], rounded to 3 places.

        Values above 1 are treated as percentages (``"87"`` -> 0.87). Values at or below 1 are
        taken as already-normalised fractions. Both forms are clamped to [0.1, 1.0].
        Unparseable or NaN input yields 0.5.
        """
        try:
            score = float(raw_score)
        except (TypeError, ValueError, OverflowError):
            score = math.nan
        if math.isnan(score):
            logger.warning(f"Invalid score format: {raw_score} — defaulting to 0.5")
            return 0.5
        if score > 1:  # percentage form, e.g. "87"
            score /= 100
        return round(max(min(score, 1.0), 0.1), 3)

    @staticmethod
    def _load_json_object(raw) -> dict:
        """Best-effort decode of ``raw`` into a dict.

        Returns ``raw`` itself if it is already a dict. Returns ``{}`` when ``raw`` is empty,
        malformed, or not a JSON object.
        """
        if isinstance(raw, dict):
            return raw
        if not raw:
            return {}
        try:
            value = json.loads(raw)
        except (TypeError, ValueError):
            return {}
        return value if isinstance(value, dict) else {}

    def _parse_tool_json(self, data) -> tuple[str, float]:
        """Extract ``(intent, score)`` from a ``{"name": ..., "parameters": {...}}`` object.

        A missing or non-string ``name`` becomes ``"chat_intent"``. A missing
        ``confidence_score`` defaults to 50 before normalisation.

        Raises:
            ValueError: If ``data`` is not a JSON object.
        """
        if not isinstance(data, dict):
            raise ValueError(f"Expected a JSON object, got {type(data).__name__}")
        intent = data.get("name")
        if not isinstance(intent, str) or not intent:
            intent = "chat_intent"
        parameters = data.get("parameters")
        if not isinstance(parameters, dict):
            parameters = {}
        return intent, self._normalize_score(parameters.get("confidence_score", 50))

    @staticmethod
    def _dump_response(response) -> str | None:
        """Serialise a model response for telemetry (JSON when the response supports it)."""
        if response is None:
            return None
        dump = getattr(response, "model_dump_json", None)
        return dump() if callable(dump) else str(response)

    def _capture_failure(self, event: str, error: Exception, response=None, messages=None) -> None:
        """Report a failure to PostHog on a best-effort basis.

        Building or sending the payload must never mask the original error or abort the
        retry loop, so any exception raised here is logged and swallowed.
        """
        try:
            properties = {"error": str(error)}
            if messages is not None:
                properties["messages"] = construct_debug_str(messages=messages)
            properties["response"] = self._dump_response(response)
            self.posthog.capture(distinct_id="intent_classifier", event=event, properties=properties)
        except Exception as telemetry_error:
            logger.warning(f"Failed to report '{event}' to PostHog: {telemetry_error}")

    def _parse_response(self, response) -> tuple[str, float]:
        """
        Extract ``(intent, confidence)`` from a model response.

        Supported shapes:
          1. A list whose first item is a ``FunctionCall`` (OpenAI-style tool call). The
             function name is the intent, and ``confidence_score`` is read from its JSON
             arguments. Missing or malformed arguments keep the intent with the default score.
          2. A string carrying a ``{"name": ..., "parameters": {...}}`` object. It may be
             wrapped in ``<|python_tag|>`` / ``<|eom_id|>``, or be a bare JSON object that
             names a function.

        Any other output yields ``("chat_intent", 0.5)``. Confidence is normalised to
        [0.1, 1.0] (e.g. ``"87"`` -> 0.87). Parse errors are logged and reported to PostHog
        rather than raised.
        """
        try:
            content = response.content
            # Case 1: Function call output (OpenAI-style tools)
            if isinstance(content, list) and content and isinstance(content[0], FunctionCall):
                call = content[0]
                # The function name alone identifies the intent; arguments only carry the score.
                arguments = self._load_json_object(call.arguments)
                return call.name, self._normalize_score(arguments.get("confidence_score", 50))
            # Case 2: String response (e.g. Llama-style or legacy)
            if isinstance(content, str):
                text = content.strip()
                if text.startswith("<|python_tag|>") or text.endswith("<|eom_id|>"):
                    payload = text.removeprefix("<|python_tag|>").removesuffix("<|eom_id|>")
                    return self._parse_tool_json(json.loads(payload))
                if text.startswith("{"):
                    # Bare JSON tool call; only trusted when it actually names a function.
                    data = self._load_json_object(text)
                    if isinstance(data.get("name"), str) and data["name"]:
                        return self._parse_tool_json(data)
                # Fallback: plain string output with no function call
                return "chat_intent", 0.5
        except Exception as e:
            logger.error(f"Error parsing model response: {e}")
            self._capture_failure("intent_classification_parsing_failed", e, response=response)
        return "chat_intent", 0.5

    async def _try_classify(self, messages, model_name, timeout, max_retries=3, backoff_factor=0.5, client=None):
        """
        Attempt to classify using the specified model with retry logic.

        Args:
            messages (List): The messages to classify.
            model_name (str): The name of the model, used for logging.
            timeout (float): Per-attempt timeout for the classification request.
            max_retries (int): Maximum number of attempts before giving up.
            backoff_factor (float): Base delay for exponential backoff between attempts
                (0.5s, 1s, 2s, ... with the default). No delay follows the final attempt.
            client: Client to call. Defaults to ``self.client`` (the primary model).

        Returns:
            tuple: The parsed intent and confidence score, or (None, None) if all attempts fail.
        """
        client = self.client if client is None else client
        for attempt in range(1, max_retries + 1):
            # Reset per attempt so a failure report never carries a stale response.
            response = None
            try:
                logger.info(f"Attempt {attempt}/{max_retries} - Classifying intent with model '{model_name}'")
                response = await asyncio.wait_for(client.create(messages=messages), timeout)
                return self._parse_response(response)
            except asyncio.TimeoutError:
                logger.error(f"Attempt {attempt} - Classification timed out for model '{model_name}'")
            except Exception as e:
                logger.error(f"Attempt {attempt} - Error during classification with model '{model_name}': {e}")
                self._capture_failure("intent_classification_failed", e, response=response, messages=messages)
            if attempt < max_retries:
                await asyncio.sleep(backoff_factor * (2 ** (attempt - 1)))  # Exponential backoff
        logger.error(f"All {max_retries} attempts failed for model '{model_name}'. Returning default intent.")
        return None, None

    async def classify_intent(self, message: TextMessage | MultiModalMessage, timeout: Optional[float] = 80.0) -> str:
        """Return the intent name for ``message``, or ``"chat_intent"`` when classification fails.

        Args:
            message: The latest user message. Its ``context`` supplies the prior conversation.
            timeout: Per-attempt timeout in seconds, passed to ``asyncio.wait_for`` (``None``
                waits indefinitely). Each model is retried up to 3 times, so worst-case latency
                is a multiple of this value.
        """
        try:
            logger.info("Classifying intent - START\n\n")
            context_messages = parse_context_to_messages(message.context)
            # NOTE(review): for a MultiModalMessage, `message.content` may not be plain text;
            # confirm how it should be rendered into this prompt.
            # NOTE(review): the prompt asks for a function named `intent`, while
            # `_parse_response` treats the called function's name as the intent. Confirm this
            # matches the tool schemas produced by IntentRegistry.
            messages = [
                SystemMessage(
                    content=(
                        "Classify the user's intent based on the full conversation history, including previous messages. "
                        "Ensure the classification considers context, as users may ask quick follow-up questions or switch topics. "
                        "Your classification must reflect the user's most relevant intent given the full history, but the most recent message from user: "
                        f"\"{message.content}\" is often the most important. "
                        "If the intent remains ambiguous, assign `chat_intent`. "
                        "Ensure the classification aligns with the following intent descriptions. "
                        "You MUST respond with a function call named `intent` and a confidence score integer between 0 and 100."
                    )
                )
            ] + context_messages

            # Try the primary model first.
            intent, score = await self._try_classify(messages, self.model_name, timeout)
            # Fall back through the backup models. self.client is left untouched, so the
            # primary model is still used on the next call.
            for model_name in self.backup_models:
                if intent is not None:
                    break
                if model_name == self.model_name:
                    continue  # already exhausted its retries above
                logger.warning(f"Falling back to model '{model_name}'")
                backup_client = self._get_backup_client(model_name)
                intent, score = await self._try_classify(messages, model_name, timeout, client=backup_client)
            # If all attempts failed, return default intent
            if intent is None:
                logger.error("All models timed out or failed. Returning default intent.")
                return "chat_intent"
            logger.info(f"Classifying intent - END. {intent} (score: {score})")
            return intent
        except Exception as e:
            logger.error(f"Intent classification failed: {e}")
            return "chat_intent"