"""Google Code Assist API client — project discovery, onboarding, quota.

The Code Assist API powers Google's official gemini-cli. It sits at
``cloudcode-pa.googleapis.com`` and provides:

- Free tier access (generous daily quota) for personal Google accounts
- Paid tier access via GCP projects with billing / Workspace / Standard / Enterprise

This module handles the control-plane dance needed before inference:

1. ``load_code_assist()`` — probe the user's account to learn what tier they're on
   and whether a ``cloudaicompanionProject`` is already assigned.
2. ``onboard_user()`` — if the user hasn't been onboarded yet (new account, fresh
   free tier, etc.), call this with the chosen tier + project id. Supports LRO
   polling for slow provisioning.
3. ``retrieve_user_quota()`` — fetch the ``buckets[]`` array showing remaining
   quota per model, used by the ``/gquota`` slash command.

VPC-SC handling: enterprise accounts under a VPC Service Controls perimeter
will get ``SECURITY_POLICY_VIOLATED`` on ``load_code_assist``. We catch this
and force the account to ``standard-tier`` so the call chain still succeeds.

Derived from opencode-gemini-auth (MIT) and clawdbot/extensions/google. The
request/response shapes are specific to Google's internal Code Assist API,
documented nowhere public — we copy them from the reference implementations.
"""

from __future__ import annotations

import json
import logging
import time
import urllib.error
import urllib.parse
import urllib.request
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


# =============================================================================
# Constants
# =============================================================================

CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com"

# Fallback endpoints tried when prod returns an error during project discovery
FALLBACK_ENDPOINTS = [
    "https://daily-cloudcode-pa.sandbox.googleapis.com",
    "https://autopush-cloudcode-pa.sandbox.googleapis.com",
]

# Tier identifiers that Google's API uses
FREE_TIER_ID = "free-tier"
LEGACY_TIER_ID = "legacy-tier"
STANDARD_TIER_ID = "standard-tier"

# Default HTTP headers matching gemini-cli's fingerprint.
# Google may reject unrecognized User-Agents on these internal endpoints.
_GEMINI_CLI_USER_AGENT = "google-api-nodejs-client/9.15.1 (gzip)"
_X_GOOG_API_CLIENT = "gl-node/24.0.0"
_DEFAULT_REQUEST_TIMEOUT = 30.0
_ONBOARDING_POLL_ATTEMPTS = 12
_ONBOARDING_POLL_INTERVAL_SECONDS = 5.0


class CodeAssistError(RuntimeError):
    """Exception raised by the Code Assist (``cloudcode-pa``) integration.

    Carries HTTP status / response / retry-after metadata so the agent's
    ``error_classifier._extract_status_code`` and the main loop's Retry-After
    handling (which walks ``error.response.headers``) pick up the right
    signals.  Without these, 429s from the OAuth path look like opaque
    ``RuntimeError`` and skip the rate-limit path.
    """

    def __init__(
        self,
        message: str,
        *,
        code: str = "code_assist_error",
        status_code: Optional[int] = None,
        response: Any = None,
        retry_after: Optional[float] = None,
        details: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(message)
        self.code = code
        # ``status_code`` is picked up by ``agent.error_classifier._extract_status_code``
        # so a 429 from Code Assist classifies as FailoverReason.rate_limit and
        # triggers the main loop's fallback_providers chain the same way SDK
        # errors do.
        self.status_code = status_code
        # ``response`` is the underlying ``httpx.Response`` (or a shim with a
        # ``.headers`` mapping and ``.json()`` method).  The main loop reads
        # ``error.response.headers["Retry-After"]`` to honor Google's retry
        # hints when the backend throttles us.
        self.response = response
        # Parsed ``Retry-After`` seconds (kept separately for convenience —
        # Google returns retry hints in both the header and the error body's
        # ``google.rpc.RetryInfo`` details, and we pick whichever we found).
        self.retry_after = retry_after
        # Parsed structured error details from the Google error envelope
        # (e.g. ``{"reason": "MODEL_CAPACITY_EXHAUSTED", "status": "RESOURCE_EXHAUSTED"}``).
        # Useful for logging and for tests that want to assert on specifics.
        self.details = details or {}


class ProjectIdRequiredError(CodeAssistError):
    def __init__(self, message: str = "GCP project id required for this tier") -> None:
        super().__init__(message, code="code_assist_project_id_required")


# =============================================================================
# HTTP primitive (auth via Bearer token passed per-call)
# =============================================================================

def _build_headers(access_token: str, *, user_agent_model: str = "") -> Dict[str, str]:
    ua = _GEMINI_CLI_USER_AGENT
    if user_agent_model:
        ua = f"{ua} model/{user_agent_model}"
    return {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {access_token}",
        "User-Agent": ua,
        "X-Goog-Api-Client": _X_GOOG_API_CLIENT,
        "x-activity-request-id": str(uuid.uuid4()),
    }


def _client_metadata() -> Dict[str, str]:
    """Match Google's gemini-cli exactly — unrecognized metadata may be rejected."""
    return {
        "ideType": "IDE_UNSPECIFIED",
        "platform": "PLATFORM_UNSPECIFIED",
        "pluginType": "GEMINI",
    }


def _post_json(
    url: str,
    body: Dict[str, Any],
    access_token: str,
    *,
    timeout: float = _DEFAULT_REQUEST_TIMEOUT,
    user_agent_model: str = "",
) -> Dict[str, Any]:
    data = json.dumps(body).encode("utf-8")
    request = urllib.request.Request(
        url, data=data, method="POST",
        headers=_build_headers(access_token, user_agent_model=user_agent_model),
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            raw = response.read().decode("utf-8", errors="replace")
            return json.loads(raw) if raw else {}
    except urllib.error.HTTPError as exc:
        detail = ""
        try:
            detail = exc.read().decode("utf-8", errors="replace")
        except Exception:
            pass
        # Special case: VPC-SC violation should be distinguishable
        if _is_vpc_sc_violation(detail):
            raise CodeAssistError(
                f"VPC-SC policy violation: {detail}",
                code="code_assist_vpc_sc",
            ) from exc
        raise CodeAssistError(
            f"Code Assist HTTP {exc.code}: {detail or exc.reason}",
            code=f"code_assist_http_{exc.code}",
        ) from exc
    except urllib.error.URLError as exc:
        raise CodeAssistError(
            f"Code Assist request failed: {exc}",
            code="code_assist_network_error",
        ) from exc


def _is_vpc_sc_violation(body: str) -> bool:
    """Detect a VPC Service Controls violation from a response body."""
    if not body:
        return False
    try:
        parsed = json.loads(body)
    except (json.JSONDecodeError, ValueError):
        return "SECURITY_POLICY_VIOLATED" in body
    # Walk the nested error structure Google uses
    error = parsed.get("error") if isinstance(parsed, dict) else None
    if not isinstance(error, dict):
        return False
    details = error.get("details") or []
    if isinstance(details, list):
        for item in details:
            if isinstance(item, dict):
                reason = item.get("reason") or ""
                if reason == "SECURITY_POLICY_VIOLATED":
                    return True
    msg = str(error.get("message", ""))
    return "SECURITY_POLICY_VIOLATED" in msg


# =============================================================================
# load_code_assist — discovers current tier + assigned project
# =============================================================================

@dataclass
class CodeAssistProjectInfo:
    """Result from ``load_code_assist``."""
    current_tier_id: str = ""
    cloudaicompanion_project: str = ""   # Google-managed project (free tier)
    allowed_tiers: List[str] = field(default_factory=list)
    raw: Dict[str, Any] = field(default_factory=dict)


def load_code_assist(
    access_token: str,
    *,
    project_id: str = "",
    user_agent_model: str = "",
) -> CodeAssistProjectInfo:
    """Call ``POST /v1internal:loadCodeAssist`` with prod → sandbox fallback.

    Returns whatever tier + project info Google reports. On VPC-SC violations,
    returns a synthetic ``standard-tier`` result so the chain can continue.
    """
    body: Dict[str, Any] = {
        "metadata": {
            "duetProject": project_id,
            **_client_metadata(),
        },
    }
    if project_id:
        body["cloudaicompanionProject"] = project_id

    endpoints = [CODE_ASSIST_ENDPOINT] + FALLBACK_ENDPOINTS
    last_err: Optional[Exception] = None
    for endpoint in endpoints:
        url = f"{endpoint}/v1internal:loadCodeAssist"
        try:
            resp = _post_json(url, body, access_token, user_agent_model=user_agent_model)
            return _parse_load_response(resp)
        except CodeAssistError as exc:
            if exc.code == "code_assist_vpc_sc":
                logger.info("VPC-SC violation on %s — defaulting to standard-tier", endpoint)
                return CodeAssistProjectInfo(
                    current_tier_id=STANDARD_TIER_ID,
                    cloudaicompanion_project=project_id,
                )
            last_err = exc
            logger.warning("loadCodeAssist failed on %s: %s", endpoint, exc)
            continue
    if last_err:
        raise last_err
    return CodeAssistProjectInfo()


def _parse_load_response(resp: Dict[str, Any]) -> CodeAssistProjectInfo:
    current_tier = resp.get("currentTier") or {}
    tier_id = str(current_tier.get("id") or "") if isinstance(current_tier, dict) else ""
    project = str(resp.get("cloudaicompanionProject") or "")
    allowed = resp.get("allowedTiers") or []
    allowed_ids: List[str] = []
    if isinstance(allowed, list):
        for t in allowed:
            if isinstance(t, dict):
                tid = str(t.get("id") or "")
                if tid:
                    allowed_ids.append(tid)
    return CodeAssistProjectInfo(
        current_tier_id=tier_id,
        cloudaicompanion_project=project,
        allowed_tiers=allowed_ids,
        raw=resp,
    )


# =============================================================================
# onboard_user — provisions a new user on a tier (with LRO polling)
# =============================================================================

def onboard_user(
    access_token: str,
    *,
    tier_id: str,
    project_id: str = "",
    user_agent_model: str = "",
) -> Dict[str, Any]:
    """Call ``POST /v1internal:onboardUser`` to provision the user.

    For paid tiers, ``project_id`` is REQUIRED (raises ProjectIdRequiredError).
    For free tiers, ``project_id`` is optional — Google will assign one.

    Returns the final operation response. Polls ``/v1internal/<name>`` for up
    to ``_ONBOARDING_POLL_ATTEMPTS`` × ``_ONBOARDING_POLL_INTERVAL_SECONDS``
    (default: 12 × 5s = 1 min).
    """
    if tier_id != FREE_TIER_ID and tier_id != LEGACY_TIER_ID and not project_id:
        raise ProjectIdRequiredError(
            f"Tier {tier_id!r} requires a GCP project id. "
            "Set HERMES_GEMINI_PROJECT_ID or GOOGLE_CLOUD_PROJECT."
        )

    body: Dict[str, Any] = {
        "tierId": tier_id,
        "metadata": _client_metadata(),
    }
    if project_id:
        body["cloudaicompanionProject"] = project_id

    endpoint = CODE_ASSIST_ENDPOINT
    url = f"{endpoint}/v1internal:onboardUser"
    resp = _post_json(url, body, access_token, user_agent_model=user_agent_model)

    # Poll if LRO (long-running operation)
    if not resp.get("done"):
        op_name = resp.get("name", "")
        if not op_name:
            return resp
        for attempt in range(_ONBOARDING_POLL_ATTEMPTS):
            time.sleep(_ONBOARDING_POLL_INTERVAL_SECONDS)
            poll_url = f"{endpoint}/v1internal/{op_name}"
            try:
                poll_resp = _post_json(poll_url, {}, access_token, user_agent_model=user_agent_model)
            except CodeAssistError as exc:
                logger.warning("Onboarding poll attempt %d failed: %s", attempt + 1, exc)
                continue
            if poll_resp.get("done"):
                return poll_resp
        logger.warning("Onboarding did not complete within %d attempts", _ONBOARDING_POLL_ATTEMPTS)
    return resp


# =============================================================================
# retrieve_user_quota — for /gquota
# =============================================================================

@dataclass
class QuotaBucket:
    model_id: str
    token_type: str = ""
    remaining_fraction: float = 0.0
    reset_time_iso: str = ""
    raw: Dict[str, Any] = field(default_factory=dict)


def retrieve_user_quota(
    access_token: str,
    *,
    project_id: str = "",
    user_agent_model: str = "",
) -> List[QuotaBucket]:
    """Call ``POST /v1internal:retrieveUserQuota`` and parse ``buckets[]``."""
    body: Dict[str, Any] = {}
    if project_id:
        body["project"] = project_id
    url = f"{CODE_ASSIST_ENDPOINT}/v1internal:retrieveUserQuota"
    resp = _post_json(url, body, access_token, user_agent_model=user_agent_model)
    raw_buckets = resp.get("buckets") or []
    buckets: List[QuotaBucket] = []
    if not isinstance(raw_buckets, list):
        return buckets
    for b in raw_buckets:
        if not isinstance(b, dict):
            continue
        buckets.append(QuotaBucket(
            model_id=str(b.get("modelId") or ""),
            token_type=str(b.get("tokenType") or ""),
            remaining_fraction=float(b.get("remainingFraction") or 0.0),
            reset_time_iso=str(b.get("resetTime") or ""),
            raw=b,
        ))
    return buckets


# =============================================================================
# Project context resolution
# =============================================================================

@dataclass
class ProjectContext:
    """Resolved state for a given OAuth session."""
    project_id: str = ""           # effective project id sent on requests
    managed_project_id: str = ""   # Google-assigned project (free tier)
    tier_id: str = ""
    source: str = ""               # "env", "config", "discovered", "onboarded"


def resolve_project_context(
    access_token: str,
    *,
    configured_project_id: str = "",
    env_project_id: str = "",
    user_agent_model: str = "",
) -> ProjectContext:
    """Figure out what project id + tier to use for requests.

    Priority:
      1. If configured_project_id or env_project_id is set, use that directly
         and short-circuit (no discovery needed).
      2. Otherwise call loadCodeAssist to see what Google says.
      3. If no tier assigned yet, onboard the user (free tier default).
    """
    # Short-circuit: caller provided a project id
    if configured_project_id:
        return ProjectContext(
            project_id=configured_project_id,
            tier_id=STANDARD_TIER_ID,  # assume paid since they specified one
            source="config",
        )
    if env_project_id:
        return ProjectContext(
            project_id=env_project_id,
            tier_id=STANDARD_TIER_ID,
            source="env",
        )

    # Discover via loadCodeAssist
    info = load_code_assist(access_token, user_agent_model=user_agent_model)

    effective_project = info.cloudaicompanion_project
    tier = info.current_tier_id

    if not tier:
        # User hasn't been onboarded — provision them on free tier
        onboard_resp = onboard_user(
            access_token,
            tier_id=FREE_TIER_ID,
            project_id="",
            user_agent_model=user_agent_model,
        )
        # Re-parse from the onboard response
        response_body = onboard_resp.get("response") or {}
        if isinstance(response_body, dict):
            effective_project = (
                effective_project
                or str(response_body.get("cloudaicompanionProject") or "")
            )
        tier = FREE_TIER_ID
        source = "onboarded"
    else:
        source = "discovered"

    return ProjectContext(
        project_id=effective_project,
        managed_project_id=effective_project if tier == FREE_TIER_ID else "",
        tier_id=tier,
        source=source,
    )
