# llm_client.py
# LLM Clients in the SAME style as ShopifyClient:
#  - __init__(token_metadata, credentials, Integration_id)
#  - headers()
#  - get_new_token()
#  - fetch_page(endpoint, params)
#  - stream_batches(endpoint, batch_size)

import requests
from rest_framework.exceptions import PermissionDenied


# Gemini list endpoints all paginate the same way: page size via "pageSize",
# cursor via "pageToken" (see GeminiClient.stream_batches).
GEMINI_CONFIG = {
    endpoint: {
        "type": "page_token",
        "page_param": "pageSize",
        "page_key": "pageToken",
    }
    for endpoint in ('models', 'files', 'cachedContents')
}


# DeepSeek endpoints used here do not paginate; "None" is a sentinel string,
# matching how the other provider configs are keyed.
DEEPSEEK_CONFIG = {
    endpoint: {"type": "None"}
    for endpoint in ('models', 'user/balance')
}

# OpenAI endpoints fetched as a single flat page; "None" is a sentinel string
# meaning "no pagination handling here".
OPENAI_CONFIG = {
    endpoint: {"type": "None"}
    for endpoint in (
        'models', 'certificates', 'chat/completions', 'evals',
        'fine_tuning/jobs', 'batches', 'files', 'vector_stores', 'containers',
        'assistants', 'admin_api_keys', 'invites', 'users', 'projects',
        'audit_logs', 'chatkit/threads', 'groups', 'usage/embeddings',
        'usage/moderations', 'roles', 'usage/completions', 'usage/images',
        'usage/audio_speeches', 'usage/audio_transcriptions',
        'usage/vector_stores', 'usage/code_interpreter_sessions', 'costs',
        'audio/voice_consents', 'videos',
    )
}







# Endpoints served under the "organization/" URL prefix on the OpenAI API
# (see OpenAIClient.fetch_page).
# Bug fix: 'certificates' was listed twice; the duplicate is removed.
# Membership checks (the only visible use) are unchanged.
OPENAI_ORG_APIS = [
    'certificates', 'admin_api_keys', 'invites', 'users', 'projects',
    'audit_logs', 'usage/completions', 'usage/embeddings', 'usage/moderations',
    'usage/images', 'usage/audio_speeches', 'usage/audio_transcriptions',
    'usage/vector_stores', 'usage/code_interpreter_sessions', 'costs',
    'roles', 'groups',
]


class BaseLLMClient:
    """Common interface for all LLM provider clients (mirrors ShopifyClient).

    Subclasses must implement ``fetch_page`` and may override
    ``stream_batches`` when an endpoint paginates.
    """

    def __init__(self, token_metadata, credentials, Integration_id):
        # Guard against callers passing None for either mapping.
        self.token_metadata = token_metadata or {}
        self.credentials = credentials or {}
        self.Integration_id = Integration_id

    def headers(self):
        """Base JSON headers; subclasses add their own auth header on top."""
        return {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def get_new_token(self):
        # API-key providers don't need refresh; keep for compatibility.
        pass

    def fetch_page(self, endpoint: str, params: dict):
        """Return one page of results for ``endpoint``; subclass must implement."""
        raise NotImplementedError

    def stream_batches(self, endpoint: str, batch_size: int = 100, params: dict = None):
        """
        Same semantics as ShopifyClient.stream_batches:
        yields a list (batch) each time.

        Each item in the batch is a dict:
          { "prompt": ..., "response": ..., "raw": ... }

        Bug fix: this method previously had no body, so subclasses that did
        not override it (Anthropic/Azure/MetaLlama) returned None and any
        ``for batch in client.stream_batches(...)`` raised TypeError.  The
        default now yields the single page produced by ``fetch_page``.
        """
        page = self.fetch_page(endpoint, params or {})
        if page is None:
            return
        # fetch_page may return either a list (already a batch) or a single
        # prompt/response dict; normalise to one batch either way.
        yield page if isinstance(page, list) else [page]


class OpenAIClient(BaseLLMClient):
    """OpenAI REST client using a bearer API token.

    Endpoints listed in ``OPENAI_ORG_APIS`` are requested under the
    ``organization/`` URL prefix; everything else goes straight under the
    base URL.
    """

    def __init__(self, token_metadata, credentials, Integration_id):
        super().__init__(token_metadata, credentials, Integration_id)
        self.api_key = self.credentials["api_token"]
        # NOTE(review): the default base URL ends with "/" and endpoint paths
        # are appended directly, so a custom site_url must also end with "/".
        self.site_url = (self.credentials.get("site_url") or "https://api.openai.com/v1/")
        self.model = self.credentials.get("default_model") or "gpt-4o-mini"
        self.cfg = OPENAI_CONFIG

    def headers(self):
        """Bearer-token JSON headers for every OpenAI request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def fetch_page(self, endpoint: str, params: dict):
        """Fetch one page of ``endpoint``.

        Returns the response's ``data`` list on HTTP 200, ``None`` on other
        non-403 statuses; raises PermissionDenied on 403.
        """
        if endpoint in OPENAI_ORG_APIS:
            base_url = f"{self.site_url}organization/{endpoint}"
        else:
            base_url = f"{self.site_url}{endpoint}"

        # Bug fix: caller-supplied query params were previously ignored; a
        # timeout is added so a hung connection cannot block forever.
        r = requests.get(
            base_url,
            headers=self.headers(),
            params=params or None,
            timeout=60,
        )
        if r.status_code == 200:
            return r.json().get('data')
        if r.status_code == 403:
            raise PermissionDenied({
                "message": "Require Permission"
            })
        return None

    def stream_batches(self, endpoint: str, batch_size: int = 100):
        """Yield a single batch: these endpoints are fetched as one page."""
        yield self.fetch_page(endpoint, {})

class DeepSeekClient(BaseLLMClient):
    """DeepSeek REST client using a bearer API token."""

    def __init__(self, token_metadata, credentials, Integration_id):
        super().__init__(token_metadata, credentials, Integration_id)
        self.api_key = self.credentials["api_token"]
        self.site_url = (self.credentials.get("site_url") or "https://api.deepseek.com")
        self.model = self.credentials.get("default_model") or "deepseek-chat"
        self.cfg = DEEPSEEK_CONFIG

    def headers(self):
        """Bearer-token JSON headers for every DeepSeek request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def fetch_page(self, endpoint: str, params: dict):
        """Fetch one page of ``endpoint``.

        ``models`` responses carry their payload under ``data``; the balance
        endpoint carries it under ``balance_infos``.  Returns ``None`` on
        non-403 failures; raises PermissionDenied on 403.
        """
        # Bug fixes: an unused ``prompt`` local was removed, the caller's
        # query params are now forwarded, and a timeout was added.
        url = f"{self.site_url}/{endpoint}"
        r = requests.get(
            url,
            headers=self.headers(),
            params=params or None,
            timeout=60,
        )
        if r.status_code == 200:
            payload = r.json()
            return payload.get('data') if endpoint == 'models' else payload.get('balance_infos')
        if r.status_code == 403:
            raise PermissionDenied({
                "message": "Require Permission"
            })
        return None

    def stream_batches(self, endpoint: str, batch_size: int = 100):
        """Yield a single batch: these endpoints are fetched as one page."""
        yield self.fetch_page(endpoint, {})


class GeminiClient(BaseLLMClient):
    """Google Gemini (generativelanguage) client.

    List endpoints (``models``, ``files``, ``cachedContents``) paginate with
    ``pageSize``/``pageToken`` as declared in ``GEMINI_CONFIG``.
    """

    def __init__(self, token_metadata, credentials, Integration_id):
        super().__init__(token_metadata, credentials, Integration_id)
        self.api_key = self.credentials["api_token"]
        self.site_url = (self.credentials.get("site_url") or "https://generativelanguage.googleapis.com/v1beta/")
        self.model = self.credentials.get("default_model") or "gemini-1.5-flash"
        self.cfg = GEMINI_CONFIG

    def headers(self):
        # Bug fix: the API key was stored but never attached to any request,
        # so every call went out unauthenticated.  Gemini accepts the key in
        # the x-goog-api-key header.
        return {
            "x-goog-api-key": self.api_key,
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def fetch_page(self, endpoint: str, params: dict):
        """Fetch one page of a Gemini list endpoint.

        Returns the parsed JSON payload on HTTP 200 (a dict holding the
        endpoint's item list and, when more pages exist, ``nextPageToken``),
        ``[]`` on other non-403 statuses; raises PermissionDenied on 403.
        """
        url = f"{self.site_url}{endpoint}"
        if not params:
            params = {"pageSize": 100}

        # Bug fix: these list endpoints are read-only and must be requested
        # with GET, not POST; a timeout is also added.
        r = requests.get(
            url,
            headers=self.headers(),
            params=params,
            timeout=60,
        )
        if r.status_code == 200:
            return r.json()
        if r.status_code == 403:
            raise PermissionDenied({
                "message": "Require Permission"
            })
        return []

    def stream_batches(self, endpoint: str, batch_size: int = 10000):
        """Yield lists of items, following ``nextPageToken`` until exhausted.

        Bug fixes vs. the previous version:
          * ``nextPageToken`` was read but never fed back into the next
            request, so the same first page was fetched forever;
          * the token was read via ``response[0]`` before the emptiness
            check, crashing on empty or error responses;
          * ``response.get(...)`` was called on values that could be lists.
        """
        cfg = self.cfg[endpoint]
        batch = []
        next_page = None
        while True:
            params = {cfg["page_param"]: 100}
            if next_page:
                params[cfg["page_key"]] = next_page

            response = self.fetch_page(endpoint, params)
            if not response:
                break

            # The page's items live under the endpoint name in the payload.
            items = response.get(endpoint, []) if isinstance(response, dict) else response
            batch.extend(items)
            while len(batch) >= batch_size:
                yield batch[:batch_size]
                batch = batch[batch_size:]

            next_page = response.get("nextPageToken") if isinstance(response, dict) else None
            if not next_page:
                break
        if batch:
            yield batch
                
    


class AnthropicClient(BaseLLMClient):
    """Anthropic Messages API client (``x-api-key`` auth)."""

    def __init__(self, token_metadata, credentials, Integration_id):
        super().__init__(token_metadata, credentials, Integration_id)
        self.api_key = self.credentials["api_key"]
        self.site_url = (self.credentials.get("site_url") or "https://api.anthropic.com").rstrip("/")
        self.model = self.credentials.get("default_model") or "claude-3-haiku-20240307"
        self.anthropic_version = self.credentials.get("anthropic_version") or "2023-06-01"

    def headers(self):
        """Anthropic auth and API-version headers plus the JSON defaults."""
        return {
            "x-api-key": self.api_key,
            "anthropic-version": self.anthropic_version,
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def fetch_page(self, endpoint: str, params: dict):
        """Send one prompt to /v1/messages; return a prompt/response/raw dict.

        Raises PermissionDenied on any non-200 status.
        """
        prompt = (params or {}).get("prompt") or "ping"
        payload = {
            "model": self.model,
            "max_tokens": 256,
            "messages": [{"role": "user", "content": prompt}],
        }
        resp = requests.post(
            f"{self.site_url}/v1/messages",
            headers=self.headers(),
            json=payload,
            timeout=60,
        )
        if resp.status_code != 200:
            raise PermissionDenied({
                "message": "Require Permission"
            })

        data = resp.json()
        # Concatenate the text of every content block; fall back to a plain
        # string dump if the payload has an unexpected shape.
        try:
            text = "".join(block.get("text", "") for block in data.get("content", []))
        except Exception:
            text = str(data)

        return {"prompt": prompt, "response": text, "raw": data}


class AzureOpenAIClient(BaseLLMClient):
    """Azure OpenAI chat-completions client (``api-key`` auth, per-deployment URL)."""

    def __init__(self, token_metadata, credentials, Integration_id):
        super().__init__(token_metadata, credentials, Integration_id)
        self.api_key = self.credentials["api_key"]
        self.endpoint = self.credentials["endpoint"].rstrip("/")
        self.deployment = self.credentials["deployment"]
        self.api_version = self.credentials.get("api_version") or "2024-10-21"

    def headers(self):
        """Azure uses an ``api-key`` header rather than a bearer token."""
        return {
            "api-key": self.api_key,
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def fetch_page(self, endpoint: str, params: dict):
        """POST one chat prompt to the deployment; return prompt/response/raw.

        Raises PermissionDenied on any non-200 status.
        """
        prompt = (params or {}).get("prompt") or "ping"
        url = (
            f"{self.endpoint}/openai/deployments/{self.deployment}"
            f"/chat/completions?api-version={self.api_version}"
        )
        body = {
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 256,
        }

        resp = requests.post(url, headers=self.headers(), json=body, timeout=60)
        if resp.status_code != 200:
            raise PermissionDenied({
                "message": "Require Permission"
            })

        data = resp.json()
        reply = data["choices"][0]["message"]["content"]
        return {"prompt": prompt, "response": reply, "raw": data}


class MetaLlamaClient(BaseLLMClient):
    """
    OpenAI-compatible self-hosted LLaMA (vLLM/Ollama/TGI).

    Required:
      - base_url OR site_url (example: http://localhost:8000 OR http://localhost:8000/v1)
    Optional:
      - api_key
      - default_model
    """

    def __init__(self, token_metadata, credentials, Integration_id):
        super().__init__(token_metadata, credentials, Integration_id)

        raw_url = self.credentials.get("base_url") or self.credentials.get("site_url") or ""
        self.site_url = raw_url.rstrip("/")
        if not self.site_url:
            raise ValueError("base_url (or site_url) is required for meta_llama")

        self.api_key = self.credentials.get("api_key")  # optional
        self.model = self.credentials.get("default_model") or "llama-3.1-8b-instruct"

    def headers(self):
        """JSON headers; bearer auth is attached only when an api_key exists."""
        base = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if self.api_key:
            base["Authorization"] = f"Bearer {self.api_key}"
        return base

    def fetch_page(self, endpoint: str, params: dict):
        """Send one chat prompt to the OpenAI-compatible server.

        Raises PermissionDenied on any non-200 status.
        """
        prompt = (params or {}).get("prompt") or "ping"

        # Tolerate site_url both with and without a trailing "/v1" segment.
        if self.site_url.endswith("/v1"):
            url = f"{self.site_url}/chat/completions"
        else:
            url = f"{self.site_url}/v1/chat/completions"

        resp = requests.post(
            url,
            headers=self.headers(),
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": prompt}],
            },
            timeout=60,
        )
        if resp.status_code != 200:
            raise PermissionDenied({
                "message": "Require Permission"
            })

        data = resp.json()
        return {
            "prompt": prompt,
            "response": data["choices"][0]["message"]["content"],
            "raw": data,
        }


# Registry (so IntegrationExtractor can pick correct client class)
# Maps the provider slug stored on an integration to the client class to
# instantiate; every class shares the BaseLLMClient constructor signature.
LLM_CLIENTS = {
    "openai": OpenAIClient,
    "deepseek": DeepSeekClient,
    "gemini": GeminiClient,
    "anthropic": AnthropicClient,
    "azure_openai": AzureOpenAIClient,
    "meta_llama": MetaLlamaClient,
}