"""Tests for agent/model_metadata.py — token estimation, context lengths,
probing, caching, and error parsing.

Coverage levels:
  Token estimation       — concrete value assertions, edge cases
  Context length lookup  — resolution order, fuzzy match, cache priority
  API metadata fetch     — caching, TTL, canonical slugs, stale fallback
  Probe tiers            — descending, boundaries, extreme inputs
  Error parsing          — OpenAI, Ollama, Anthropic, edge cases
  Persistent cache       — save/load, corruption, update, provider isolation
"""

import time
from unittest.mock import patch, MagicMock

import pytest
import yaml

from agent.model_metadata import (
    CONTEXT_PROBE_TIERS,
    DEFAULT_CONTEXT_LENGTHS,
    _strip_provider_prefix,
    estimate_tokens_rough,
    estimate_messages_tokens_rough,
    get_model_context_length,
    get_next_probe_tier,
    get_cached_context_length,
    parse_context_limit_from_error,
    save_context_length,
    fetch_model_metadata,
    _MODEL_CACHE_TTL,
)


# =========================================================================
# Token estimation
# =========================================================================

class TestEstimateTokensRough:
    def test_empty_string(self):
        assert estimate_tokens_rough("") == 0

    def test_none_returns_zero(self):
        assert estimate_tokens_rough(None) == 0

    def test_known_length(self):
        assert estimate_tokens_rough("a" * 400) == 100

    def test_short_text(self):
        # "hello" = 5 chars → ceil(5/4) = 2
        assert estimate_tokens_rough("hello") == 2

    def test_proportional(self):
        short = estimate_tokens_rough("hello world")
        long = estimate_tokens_rough("hello world " * 100)
        assert long > short

    def test_unicode_multibyte(self):
        """Unicode chars are still 1 Python char each — 4 chars/token holds."""
        text = "你好世界"  # 4 CJK characters
        assert estimate_tokens_rough(text) == 1
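
    # The concrete values above all follow a chars/4 heuristic with ceiling
    # division. A minimal sketch of the assumed behavior (illustrative only,
    # not the actual implementation):
    #
    #     estimate_tokens_rough(text) -> 0 if not text else (len(text) + 3) // 4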


class TestEstimateMessagesTokensRough:
    def test_empty_list(self):
        assert estimate_messages_tokens_rough([]) == 0

    def test_single_message_concrete_value(self):
        """Verify against known str(msg) length (ceiling division)."""
        msg = {"role": "user", "content": "a" * 400}
        result = estimate_messages_tokens_rough([msg])
        n = len(str(msg))
        expected = (n + 3) // 4
        assert result == expected

    def test_multiple_messages_additive(self):
        msgs = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there, how can I help?"},
        ]
        result = estimate_messages_tokens_rough(msgs)
        n = sum(len(str(m)) for m in msgs)
        expected = (n + 3) // 4
        assert result == expected

    def test_tool_call_message(self):
        """Tool call messages with no 'content' key still contribute tokens."""
        msg = {"role": "assistant", "content": None,
               "tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]}
        result = estimate_messages_tokens_rough([msg])
        assert result > 0
        assert result == (len(str(msg)) + 3) // 4

    def test_message_with_list_content(self):
        """Vision messages with multimodal content arrays."""
        msg = {"role": "user", "content": [
            {"type": "text", "text": "describe"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}
        ]}
        result = estimate_messages_tokens_rough([msg])
        assert result == (len(str(msg)) + 3) // 4


# =========================================================================
# Default context lengths
# =========================================================================

class TestDefaultContextLengths:
    def test_claude_models_context_lengths(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "claude" not in key:
                continue
            # Claude 4.6+ models (4.6 and 4.7) have 1M context at standard
            # API pricing (no long-context premium).  Older Claude 4.x and
            # 3.x models cap at 200k.
            if any(tag in key for tag in ("4.6", "4-6", "4.7", "4-7")):
                assert value == 1000000, f"{key} should be 1000000"
            else:
                assert value == 200000, f"{key} should be 200000"

    def test_gpt4_non_41_models_128k(self):
        # gpt-4.1* models have 1M context (covered by the next test);
        # every other gpt-4* entry should be 128k.
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "gpt-4" in key and "gpt-4.1" not in key:
                assert value == 128000, f"{key} should be 128000"

    def test_gpt41_models_1m(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "gpt-4.1" in key:
                assert value == 1047576, f"{key} should be 1047576"

    def test_gemini_models_1m(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "gemini" in key:
                assert value == 1048576, f"{key} should be 1048576"

    def test_grok_models_context_lengths(self):
        # xAI /v1/models does not return context_length metadata, so
        # DEFAULT_CONTEXT_LENGTHS must cover the Grok family explicitly.
        # Values sourced from models.dev (2026-04).
        expected = {
            "grok-4.20": 2000000,
            "grok-4-1-fast": 2000000,
            "grok-4-fast": 2000000,
            "grok-4": 256000,
            "grok-code-fast": 256000,
            "grok-3": 131072,
            "grok-2": 131072,
            "grok-2-vision": 8192,
            "grok": 131072,
        }
        for key, value in expected.items():
            assert key in DEFAULT_CONTEXT_LENGTHS, f"{key} missing from DEFAULT_CONTEXT_LENGTHS"
            assert DEFAULT_CONTEXT_LENGTHS[key] == value, (
                f"{key} should be {value}, got {DEFAULT_CONTEXT_LENGTHS[key]}"
            )

    def test_grok_substring_matching(self):
        # Longest-first substring matching must resolve the real xAI model
        # IDs to the correct fallback entries without 128k probe-down.
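        # For example, "grok-4-fast-reasoning" must match the longer key
        # "grok-4-fast" (2M) rather than the shorter "grok-4" (256K).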
        from agent.model_metadata import get_model_context_length
        from unittest.mock import patch as mock_patch

        # Fake the provider/API/cache layers so the lookup falls through
        # to DEFAULT_CONTEXT_LENGTHS.
        with mock_patch("agent.model_metadata.fetch_model_metadata", return_value={}),              mock_patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}),              mock_patch("agent.model_metadata.get_cached_context_length", return_value=None):
            cases = [
                ("grok-4.20-0309-reasoning", 2000000),
                ("grok-4.20-0309-non-reasoning", 2000000),
                ("grok-4.20-multi-agent-0309", 2000000),
                ("grok-4-1-fast-reasoning", 2000000),
                ("grok-4-1-fast-non-reasoning", 2000000),
                ("grok-4-fast-reasoning", 2000000),
                ("grok-4-fast-non-reasoning", 2000000),
                ("grok-4", 256000),
                ("grok-4-0709", 256000),
                ("grok-code-fast-1", 256000),
                ("grok-3", 131072),
                ("grok-3-mini", 131072),
                ("grok-3-mini-fast", 131072),
                ("grok-2", 131072),
                ("grok-2-vision", 8192),
                ("grok-2-vision-1212", 8192),
                ("grok-beta", 131072),
            ]
            for model_id, expected_ctx in cases:
                actual = get_model_context_length(model_id)
                assert actual == expected_ctx, (
                    f"{model_id}: expected {expected_ctx}, got {actual}"
                )

    def test_deepseek_v4_models_1m_context(self):
        from agent.model_metadata import get_model_context_length
        from unittest.mock import patch as mock_patch

        expected_keys = {
            "deepseek-v4-pro": 1_000_000,
            "deepseek-v4-flash": 1_000_000,
            "deepseek-chat": 1_000_000,
            "deepseek-reasoner": 1_000_000,
        }
        for key, value in expected_keys.items():
            assert key in DEFAULT_CONTEXT_LENGTHS, f"{key} missing"
            assert DEFAULT_CONTEXT_LENGTHS[key] == value, (
                f"{key} should be {value}, got {DEFAULT_CONTEXT_LENGTHS[key]}"
            )

        # Longest-first substring matching must resolve both the bare V4
        # ids (native DeepSeek) and the vendor-prefixed forms (OpenRouter
        # / Nous Portal) to 1M without probing down to the legacy 128K
        # ``deepseek`` substring fallback.
        with mock_patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             mock_patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             mock_patch("agent.model_metadata.get_cached_context_length", return_value=None):
            cases = [
                ("deepseek-v4-pro", 1_000_000),
                ("deepseek-v4-flash", 1_000_000),
                ("deepseek/deepseek-v4-pro", 1_000_000),
                ("deepseek/deepseek-v4-flash", 1_000_000),
                ("deepseek-chat", 1_000_000),
                ("deepseek-reasoner", 1_000_000),
            ]
            for model_id, expected_ctx in cases:
                actual = get_model_context_length(model_id)
                assert actual == expected_ctx, (
                    f"{model_id}: expected {expected_ctx}, got {actual}"
                )

    def test_all_values_positive(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            assert value > 0, f"{key} has non-positive context length"

    def test_dict_is_not_empty(self):
        assert len(DEFAULT_CONTEXT_LENGTHS) >= 10


# =========================================================================
# Codex OAuth context-window resolution (provider="openai-codex")
# =========================================================================

class TestCodexOAuthContextLength:
    """ChatGPT Codex OAuth imposes lower context limits than the direct
    OpenAI API for the same slugs. Verified Apr 2026 via live probe of
    chatgpt.com/backend-api/codex/models: every model returns 272k, while
    models.dev reports 1.05M for gpt-5.5/gpt-5.4 and 400k for the rest.
    """

    def setup_method(self):
        import agent.model_metadata as mm
        mm._codex_oauth_context_cache = {}
        mm._codex_oauth_context_cache_time = 0.0

    def test_fallback_table_used_without_token(self):
        """With no access token, the hardcoded Codex fallback table wins
        over models.dev (which reports 1.05M for gpt-5.5 but Codex is 272k).
        """
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.save_context_length"):
            for model in (
                "gpt-5.5",
                "gpt-5.4",
                "gpt-5.4-mini",
                "gpt-5.3-codex",
                "gpt-5.2-codex",
                "gpt-5.1-codex-max",
                "gpt-5.1-codex-mini",
            ):
                ctx = get_model_context_length(
                    model=model,
                    base_url="https://chatgpt.com/backend-api/codex",
                    api_key="",
                    provider="openai-codex",
                )
                assert ctx == 272_000, (
                    f"Codex {model}: expected 272000 fallback, got {ctx} "
                    "(models.dev leakage?)"
                )

    def test_live_probe_overrides_fallback(self):
        """When a token is provided, the live /models probe is preferred
        and its context_window drives the result."""
        from agent.model_metadata import get_model_context_length

        fake_response = MagicMock()
        fake_response.status_code = 200
        fake_response.json.return_value = {
            "models": [
                {"slug": "gpt-5.5", "context_window": 300_000},
                {"slug": "gpt-5.4", "context_window": 400_000},
            ]
        }

        with patch("agent.model_metadata.requests.get", return_value=fake_response), \
             patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.save_context_length"):
            ctx_55 = get_model_context_length(
                model="gpt-5.5",
                base_url="https://chatgpt.com/backend-api/codex",
                api_key="fake-token",
                provider="openai-codex",
            )
            ctx_54 = get_model_context_length(
                model="gpt-5.4",
                base_url="https://chatgpt.com/backend-api/codex",
                api_key="fake-token",
                provider="openai-codex",
            )
        assert ctx_55 == 300_000
        assert ctx_54 == 400_000

    def test_probe_failure_falls_back_to_hardcoded(self):
        """If the probe fails (non-200 / network error), we still return
        the hardcoded 272k rather than leaking through to models.dev 1.05M."""
        from agent.model_metadata import get_model_context_length

        fake_response = MagicMock()
        fake_response.status_code = 401
        fake_response.json.return_value = {}

        with patch("agent.model_metadata.requests.get", return_value=fake_response), \
             patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.save_context_length"):
            ctx = get_model_context_length(
                model="gpt-5.5",
                base_url="https://chatgpt.com/backend-api/codex",
                api_key="expired-token",
                provider="openai-codex",
            )
        assert ctx == 272_000

    def test_non_codex_providers_unaffected(self):
        """Resolving gpt-5.5 on non-Codex providers must NOT use the Codex
        272k override — OpenRouter / direct OpenAI API have different limits.
        """
        from agent.model_metadata import get_model_context_length

        # OpenRouter — should hit its own catalog path first; when mocked
        # empty, falls through to hardcoded DEFAULT_CONTEXT_LENGTHS (1.05M,
        # matching the real direct-API value — Codex OAuth's 272k cap is
        # provider-specific and must not leak here).
        with patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.models_dev.lookup_models_dev_context", return_value=None):
            ctx = get_model_context_length(
                model="openai/gpt-5.5",
                base_url="https://openrouter.ai/api/v1",
                api_key="",
                provider="openrouter",
            )
        assert ctx == 1_050_000, (
            f"Non-Codex gpt-5.5 resolved to {ctx}; Codex 272k override "
            "leaked outside openai-codex provider"
        )

    def test_stale_codex_cache_over_400k_is_invalidated(self, tmp_path, monkeypatch):
        """Pre-PR #14935 builds cached gpt-5.5 at 1.05M (from models.dev)
        before the Codex-aware branch existed. Upgrading users keep that
        stale entry on disk and the cache-first lookup returns it forever.
        Codex OAuth caps at 272k for every slug, so any cached Codex
        entry >= 400k must be dropped and re-resolved via the live probe.
        """
        from agent import model_metadata as mm

        # Isolate the cache file to tmp_path
        cache_file = tmp_path / "context_length_cache.yaml"
        monkeypatch.setattr(mm, "_get_context_cache_path", lambda: cache_file)

        base_url = "https://chatgpt.com/backend-api/codex/"
        stale_key = f"gpt-5.5@{base_url}"
        other_key = "other-model@https://api.openai.com/v1/"
        import yaml as _yaml
        cache_file.write_text(_yaml.dump({"context_lengths": {
            stale_key: 1_050_000,   # stale pre-fix value
            other_key: 128_000,     # unrelated, must survive
        }}))

        fake_response = MagicMock()
        fake_response.status_code = 200
        fake_response.json.return_value = {
            "models": [{"slug": "gpt-5.5", "context_window": 272_000}]
        }

        with patch("agent.model_metadata.requests.get", return_value=fake_response), \
             patch("agent.model_metadata.save_context_length") as mock_save:
            ctx = mm.get_model_context_length(
                model="gpt-5.5",
                base_url=base_url,
                api_key="fake-token",
                provider="openai-codex",
            )

        assert ctx == 272_000, f"Stale entry should have been re-resolved to 272k, got {ctx}"
        # Live save was called with the fresh value
        mock_save.assert_called_with("gpt-5.5", base_url, 272_000)
        # The stale entry was removed from disk; unrelated entries survived
        remaining = _yaml.safe_load(cache_file.read_text()).get("context_lengths", {})
        assert stale_key not in remaining, "Stale entry was not invalidated from the cache file"
        assert remaining.get(other_key) == 128_000, "Unrelated cache entries must not be touched"

    def test_fresh_codex_cache_under_400k_is_respected(self, tmp_path, monkeypatch):
        """Codex entries at the correct 272k must NOT be invalidated —
        only stale pre-fix values (>= 400k) get dropped."""
        from agent import model_metadata as mm

        cache_file = tmp_path / "context_length_cache.yaml"
        monkeypatch.setattr(mm, "_get_context_cache_path", lambda: cache_file)

        base_url = "https://chatgpt.com/backend-api/codex/"
        import yaml as _yaml
        cache_file.write_text(_yaml.dump({"context_lengths": {
            f"gpt-5.5@{base_url}": 272_000,
        }}))

        # If the invalidation incorrectly fired, this would be called; assert it isn't.
        with patch("agent.model_metadata.requests.get") as mock_get:
            ctx = mm.get_model_context_length(
                model="gpt-5.5",
                base_url=base_url,
                api_key="fake-token",
                provider="openai-codex",
            )
        assert ctx == 272_000
        mock_get.assert_not_called()

    def test_stale_invalidation_scoped_to_codex_provider(self, tmp_path, monkeypatch):
        """A cached 1M entry for a non-Codex provider (e.g. Anthropic opus on
        OpenRouter, legitimately 1M) must NOT be invalidated by this guard."""
        from agent import model_metadata as mm

        cache_file = tmp_path / "context_length_cache.yaml"
        monkeypatch.setattr(mm, "_get_context_cache_path", lambda: cache_file)

        base_url = "https://openrouter.ai/api/v1"
        import yaml as _yaml
        cache_file.write_text(_yaml.dump({"context_lengths": {
            f"anthropic/claude-opus-4.6@{base_url}": 1_000_000,
        }}))

        ctx = mm.get_model_context_length(
            model="anthropic/claude-opus-4.6",
            base_url=base_url,
            api_key="fake",
            provider="openrouter",
        )
        assert ctx == 1_000_000, "Non-codex 1M cache entries must be respected"


# =========================================================================
# get_model_context_length — resolution order
# =========================================================================

class TestGetModelContextLength:
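    """Apparent priority, inferred from the assertions below (a hedged sketch,
    not a spec): explicit config_context_length first, then the persistent
    cache (requires base_url), then endpoint/catalog API metadata, then the
    DEFAULT_CONTEXT_LENGTHS substring match, with CONTEXT_PROBE_TIERS[0] as
    the final fallback.
    """
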
    @patch("agent.model_metadata.fetch_model_metadata")
    def test_known_model_from_api(self, mock_fetch):
        mock_fetch.return_value = {
            "test/model": {"context_length": 32000}
        }
        assert get_model_context_length("test/model") == 32000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_fallback_to_defaults(self, mock_fetch):
        mock_fetch.return_value = {}
        assert get_model_context_length("anthropic/claude-sonnet-4") == 200000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_unknown_model_returns_first_probe_tier(self, mock_fetch):
        mock_fetch.return_value = {}
        assert get_model_context_length("unknown/never-heard-of-this") == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_partial_match_in_defaults(self, mock_fetch):
        mock_fetch.return_value = {}
        assert get_model_context_length("openai/gpt-4o") == 128000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_qwen3_coder_plus_context_length(self, mock_fetch):
        """qwen3-coder-plus has a 1M context window, not the generic 128K Qwen default."""
        mock_fetch.return_value = {}
        assert get_model_context_length("qwen3-coder-plus") == 1000000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_qwen3_coder_context_length(self, mock_fetch):
        """qwen3-coder has a 256K context window, not the generic 128K Qwen default."""
        mock_fetch.return_value = {}
        assert get_model_context_length("qwen3-coder") == 262144

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_qwen_generic_context_length(self, mock_fetch):
        """Generic qwen models still get the 128K default."""
        mock_fetch.return_value = {}
        assert get_model_context_length("qwen3-plus") == 131072

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_api_missing_context_length_key(self, mock_fetch):
        """Model in API but without context_length → defaults to the top
        probe tier (currently 256K)."""
        mock_fetch.return_value = {"test/model": {"name": "Test"}}
        assert get_model_context_length("test/model") == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_cache_takes_priority_over_api(self, mock_fetch, tmp_path):
        """Persistent cache should be checked BEFORE API metadata."""
        mock_fetch.return_value = {"my/model": {"context_length": 999999}}
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("my/model", "http://local", 32768)
            result = get_model_context_length("my/model", base_url="http://local")
            assert result == 32768  # cache wins over API's 999999

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_no_base_url_skips_cache(self, mock_fetch, tmp_path):
        """Without base_url, cache lookup is skipped."""
        mock_fetch.return_value = {}
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("custom/model", "http://local", 32768)
            # No base_url → cache skipped → falls to probe tier
            result = get_model_context_length("custom/model")
            assert result == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_metadata_beats_fuzzy_default(self, mock_endpoint_fetch, mock_fetch):
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {
            "zai-org/GLM-5-TEE": {"context_length": 65536}
        }

        result = get_model_context_length(
            "zai-org/GLM-5-TEE",
            base_url="https://llm.chutes.ai/v1",
            api_key="test-key",
        )

        assert result == 65536

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_without_metadata_skips_name_based_default(self, mock_endpoint_fetch, mock_fetch):
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {}

        result = get_model_context_length(
            "zai-org/GLM-5-TEE",
            base_url="https://llm.chutes.ai/v1",
            api_key="test-key",
        )

        assert result == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_single_model_fallback(self, mock_endpoint_fetch, mock_fetch):
        """Single-model servers: use the only model even if name doesn't match."""
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {
            "Qwen3.5-9B-Q4_K_M.gguf": {"context_length": 131072}
        }

        result = get_model_context_length(
            "qwen3.5:9b",
            base_url="http://myserver.example.com:8080/v1",
            api_key="test-key",
        )

        assert result == 131072

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_fuzzy_substring_match(self, mock_endpoint_fetch, mock_fetch):
        """Fuzzy match: configured model name is substring of endpoint model."""
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {
            "org/llama-3.3-70b-instruct-fp8": {"context_length": 131072},
            "org/qwen-2.5-72b": {"context_length": 32768},
        }

        result = get_model_context_length(
            "llama-3.3-70b-instruct",
            base_url="http://myserver.example.com:8080/v1",
            api_key="test-key",
        )

        assert result == 131072

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_config_context_length_overrides_all(self, mock_fetch):
        """Explicit config_context_length takes priority over everything."""
        mock_fetch.return_value = {
            "test/model": {"context_length": 200000}
        }

        result = get_model_context_length(
            "test/model",
            config_context_length=65536,
        )

        assert result == 65536

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_config_context_length_zero_is_ignored(self, mock_fetch):
        """config_context_length=0 should be treated as unset."""
        mock_fetch.return_value = {}

        result = get_model_context_length(
            "anthropic/claude-sonnet-4",
            config_context_length=0,
        )

        assert result == 200000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_config_context_length_none_is_ignored(self, mock_fetch):
        """config_context_length=None should be treated as unset."""
        mock_fetch.return_value = {}

        result = get_model_context_length(
            "anthropic/claude-sonnet-4",
            config_context_length=None,
        )

        assert result == 200000


# =========================================================================
# Bedrock context resolution — must run BEFORE custom-endpoint probe
# =========================================================================

class TestBedrockContextResolution:
    """Regression tests for Bedrock context-length resolution order.

    Bug: because ``bedrock-runtime.<region>.amazonaws.com`` is not listed in
    ``_URL_TO_PROVIDER``, ``_is_known_provider_base_url`` returned False and
    the custom-endpoint probe at step 2 ran first — fetching ``/models`` from
    Bedrock (which it doesn't serve), returning the 128K default-fallback
    before execution ever reached the Bedrock branch.

    Fix: promote the Bedrock branch ahead of the custom-endpoint probe.
    """

    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_bedrock_provider_returns_static_table_before_probe(self, mock_fetch):
        """provider='bedrock' resolves via static table, bypasses /models probe."""
        ctx = get_model_context_length(
            "anthropic.claude-opus-4-v1:0",
            provider="bedrock",
            base_url="https://bedrock-runtime.us-east-1.amazonaws.com",
        )
        # Must return the static Bedrock table value (200K for Claude),
        # NOT DEFAULT_FALLBACK_CONTEXT (128K).
        assert ctx == 200000
        mock_fetch.assert_not_called()

    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_bedrock_url_without_provider_hint(self, mock_fetch):
        """bedrock-runtime host infers Bedrock even when provider is omitted."""
        ctx = get_model_context_length(
            "anthropic.claude-sonnet-4-v1:0",
            base_url="https://bedrock-runtime.us-west-2.amazonaws.com",
        )
        assert ctx == 200000
        mock_fetch.assert_not_called()

    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_non_bedrock_url_still_probes(self, mock_fetch):
        """Non-Bedrock hosts still reach the custom-endpoint probe."""
        mock_fetch.return_value = {"some-model": {"context_length": 50000}}
        ctx = get_model_context_length(
            "some-model",
            base_url="https://api.example.com/v1",
        )
        assert ctx == 50000
        assert mock_fetch.called


# =========================================================================
# _strip_provider_prefix — Ollama model:tag vs provider:model
# =========================================================================

class TestStripProviderPrefix:
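    """The cases below imply prefix stripping is keyed to a known-provider
    list: "local:", "openrouter:", "anthropic:", "stepfun:" are stripped,
    while Ollama model:tag names, URLs, and bare IDs pass through unchanged
    (an inference from these tests, not a statement about the implementation).
    """
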
    def test_known_provider_prefix_is_stripped(self):
        assert _strip_provider_prefix("local:my-model") == "my-model"
        assert _strip_provider_prefix("openrouter:anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"
        assert _strip_provider_prefix("anthropic:claude-sonnet-4") == "claude-sonnet-4"
        assert _strip_provider_prefix("stepfun:step-3.5-flash") == "step-3.5-flash"

    def test_ollama_model_tag_preserved(self):
        """Ollama model:tag format must NOT be stripped."""
        assert _strip_provider_prefix("qwen3.5:27b") == "qwen3.5:27b"
        assert _strip_provider_prefix("llama3.3:70b") == "llama3.3:70b"
        assert _strip_provider_prefix("gemma2:9b") == "gemma2:9b"
        assert _strip_provider_prefix("codellama:13b-instruct-q4_0") == "codellama:13b-instruct-q4_0"

    def test_http_urls_preserved(self):
        assert _strip_provider_prefix("http://example.com") == "http://example.com"
        assert _strip_provider_prefix("https://example.com") == "https://example.com"

    def test_no_colon_returns_unchanged(self):
        assert _strip_provider_prefix("gpt-4o") == "gpt-4o"
        assert _strip_provider_prefix("anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_ollama_model_tag_not_mangled_in_context_lookup(self, mock_fetch):
        """Ensure 'qwen3.5:27b' is NOT reduced to '27b' during context length lookup.

        We mock a custom endpoint that knows 'qwen3.5:27b' — the full name
        must reach the endpoint metadata lookup intact.
        """
        mock_fetch.return_value = {}
        with patch("agent.model_metadata.fetch_endpoint_model_metadata") as mock_ep, \
             patch("agent.model_metadata._is_custom_endpoint", return_value=True):
            mock_ep.return_value = {"qwen3.5:27b": {"context_length": 32768}}
            result = get_model_context_length(
                "qwen3.5:27b",
                base_url="http://localhost:11434/v1",
            )
        assert result == 32768


# =========================================================================
# fetch_model_metadata — caching, TTL, slugs, failures
# =========================================================================

class TestFetchModelMetadata:
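    """The mocks below assume an OpenRouter-style catalog response, i.e.
    GET .../models returning {"data": [{"id": ..., "context_length": ...,
    "canonical_slug": ..., "name": ...}]}, memoized in module-level globals
    and refreshed after _MODEL_CACHE_TTL seconds.
    """
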
    def _reset_cache(self):
        import agent.model_metadata as mm
        mm._model_metadata_cache = {}
        mm._model_metadata_cache_time = 0

    @patch("agent.model_metadata.requests.get")
    def test_caches_result(self, mock_get):
        self._reset_cache()
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [{"id": "test/model", "context_length": 99999, "name": "Test"}]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        result1 = fetch_model_metadata(force_refresh=True)
        assert "test/model" in result1
        assert mock_get.call_count == 1

        result2 = fetch_model_metadata()
        assert "test/model" in result2
        assert mock_get.call_count == 1  # cached

    @patch("agent.model_metadata.requests.get")
    def test_api_failure_returns_empty_on_cold_cache(self, mock_get):
        self._reset_cache()
        mock_get.side_effect = Exception("Network error")
        result = fetch_model_metadata(force_refresh=True)
        assert result == {}

    @patch("agent.model_metadata.requests.get")
    def test_api_failure_returns_stale_cache(self, mock_get):
        """On API failure with existing cache, stale data is returned."""
        import agent.model_metadata as mm
        mm._model_metadata_cache = {"old/model": {"context_length": 50000}}
        mm._model_metadata_cache_time = 0  # expired

        mock_get.side_effect = Exception("Network error")
        result = fetch_model_metadata(force_refresh=True)
        assert "old/model" in result
        assert result["old/model"]["context_length"] == 50000

    @patch("agent.model_metadata.requests.get")
    def test_canonical_slug_aliasing(self, mock_get):
        """Models with canonical_slug get indexed under both IDs."""
        self._reset_cache()
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [{
                "id": "anthropic/claude-3.5-sonnet:beta",
                "canonical_slug": "anthropic/claude-3.5-sonnet",
                "context_length": 200000,
                "name": "Claude 3.5 Sonnet"
            }]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        result = fetch_model_metadata(force_refresh=True)
        # Both the original ID and canonical slug should work
        assert "anthropic/claude-3.5-sonnet:beta" in result
        assert "anthropic/claude-3.5-sonnet" in result
        assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000

    @patch("agent.model_metadata.requests.get")
    def test_provider_prefixed_models_get_bare_aliases(self, mock_get):
        self._reset_cache()
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [{
                "id": "provider/test-model",
                "context_length": 123456,
                "name": "Provider: Test Model",
            }]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        result = fetch_model_metadata(force_refresh=True)

        assert result["provider/test-model"]["context_length"] == 123456
        assert result["test-model"]["context_length"] == 123456

    @patch("agent.model_metadata.requests.get")
    def test_ttl_expiry_triggers_refetch(self, mock_get):
        """Cache expires after _MODEL_CACHE_TTL seconds."""
        import agent.model_metadata as mm
        self._reset_cache()

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [{"id": "m1", "context_length": 1000, "name": "M1"}]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        fetch_model_metadata(force_refresh=True)
        assert mock_get.call_count == 1

        # Simulate TTL expiry
        mm._model_metadata_cache_time = time.time() - _MODEL_CACHE_TTL - 1
        fetch_model_metadata()
        assert mock_get.call_count == 2  # refetched

    @patch("agent.model_metadata.requests.get")
    def test_malformed_json_no_data_key(self, mock_get):
        """API returns JSON without 'data' key — empty cache, no crash."""
        self._reset_cache()
        mock_response = MagicMock()
        mock_response.json.return_value = {"error": "something"}
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        result = fetch_model_metadata(force_refresh=True)
        assert result == {}


# =========================================================================
# Context probe tiers
# =========================================================================

class TestContextProbeTiers:
    def test_tiers_descending(self):
        for i in range(len(CONTEXT_PROBE_TIERS) - 1):
            assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1]

    def test_first_tier_is_256k(self):
        assert CONTEXT_PROBE_TIERS[0] == 256_000

    def test_last_tier_is_8k(self):
        assert CONTEXT_PROBE_TIERS[-1] == 8_000
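
    # Together with TestGetNextProbeTier below, these tests pin the assumed
    # ladder to roughly [256_000, 128_000, 64_000, 32_000, 16_000, 8_000];
    # get_next_probe_tier() appears to return the highest tier strictly
    # below its argument, or None once nothing remains below 8k.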


class TestGetNextProbeTier:
    def test_from_256k(self):
        assert get_next_probe_tier(256_000) == 128_000

    def test_from_128k(self):
        assert get_next_probe_tier(128_000) == 64_000

    def test_from_64k(self):
        assert get_next_probe_tier(64_000) == 32_000

    def test_from_32k(self):
        assert get_next_probe_tier(32_000) == 16_000

    def test_from_8k_returns_none(self):
        assert get_next_probe_tier(8_000) is None

    def test_from_below_min_returns_none(self):
        assert get_next_probe_tier(4_000) is None

    def test_from_arbitrary_value(self):
        assert get_next_probe_tier(100_000) == 64_000

    def test_above_max_tier(self):
        """Value above 256K should return 256K."""
        assert get_next_probe_tier(500_000) == 256_000

    def test_zero_returns_none(self):
        assert get_next_probe_tier(0) is None


# =========================================================================
# Error message parsing
# =========================================================================

class TestParseContextLimitFromError:
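    """Expected behavior, as exercised below: extract the *limit* (not the
    offending input size), and return None for implausible values such as
    tiny numbers (42), MiniMax delta-only messages, or absurdly large
    figures (> ~10M tokens).
    """
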
    def test_openai_format(self):
        msg = "This model's maximum context length is 32768 tokens. However, your messages resulted in 45000 tokens."
        assert parse_context_limit_from_error(msg) == 32768

    def test_context_length_exceeded(self):
        msg = "context_length_exceeded: maximum context length is 131072"
        assert parse_context_limit_from_error(msg) == 131072

    def test_context_size_exceeded(self):
        msg = "Maximum context size 65536 exceeded"
        assert parse_context_limit_from_error(msg) == 65536

    def test_no_limit_in_message(self):
        assert parse_context_limit_from_error("Something went wrong with the API") is None

    def test_unreasonable_small_number_rejected(self):
        assert parse_context_limit_from_error("context length is 42 tokens") is None

    def test_ollama_format(self):
        msg = "Context size has been exceeded. Maximum context size is 32768"
        assert parse_context_limit_from_error(msg) == 32768

    def test_anthropic_format(self):
        msg = "prompt is too long: 250000 tokens > 200000 maximum"
        # Should extract 200000 (the limit), not 250000 (the input size)
        assert parse_context_limit_from_error(msg) == 200000

    def test_lmstudio_format(self):
        msg = "Error: context window of 4096 tokens exceeded"
        assert parse_context_limit_from_error(msg) == 4096

    def test_minimax_delta_only_message_returns_none(self):
        msg = "invalid params, context window exceeds limit (2013)"
        assert parse_context_limit_from_error(msg) is None

    def test_completely_unrelated_error(self):
        assert parse_context_limit_from_error("Invalid API key") is None

    def test_empty_string(self):
        assert parse_context_limit_from_error("") is None

    def test_number_outside_reasonable_range(self):
        """Very large number (>10M) should be rejected."""
        msg = "maximum context length is 99999999999"
        assert parse_context_limit_from_error(msg) is None


# =========================================================================
# Persistent context length cache
# =========================================================================

class TestContextLengthCache:
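    """On-disk layout assumed throughout (the Codex stale-cache tests above
    write it directly):

        context_lengths:
          "model@base_url": <int tokens>
    """
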
    def test_save_and_load(self, tmp_path):
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("test/model", "http://localhost:8080/v1", 32768)
            assert get_cached_context_length("test/model", "http://localhost:8080/v1") == 32768

    def test_missing_cache_returns_none(self, tmp_path):
        cache_file = tmp_path / "nonexistent.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            assert get_cached_context_length("test/model", "http://x") is None

    def test_multiple_models_cached(self, tmp_path):
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("model-a", "http://a", 64000)
            save_context_length("model-b", "http://b", 128000)
            assert get_cached_context_length("model-a", "http://a") == 64000
            assert get_cached_context_length("model-b", "http://b") == 128000

    def test_same_model_different_providers(self, tmp_path):
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("llama-3", "http://local:8080", 32768)
            save_context_length("llama-3", "https://openrouter.ai/api/v1", 131072)
            assert get_cached_context_length("llama-3", "http://local:8080") == 32768
            assert get_cached_context_length("llama-3", "https://openrouter.ai/api/v1") == 131072

    def test_idempotent_save(self, tmp_path):
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("model", "http://x", 32768)
            save_context_length("model", "http://x", 32768)
            with open(cache_file) as f:
                data = yaml.safe_load(f)
            assert len(data["context_lengths"]) == 1

    def test_update_existing_value(self, tmp_path):
        """Saving a different value for the same key overwrites it."""
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("model", "http://x", 128000)
            save_context_length("model", "http://x", 64000)
            assert get_cached_context_length("model", "http://x") == 64000

    def test_corrupted_yaml_returns_empty(self, tmp_path):
        """Corrupted cache file is handled gracefully."""
        cache_file = tmp_path / "cache.yaml"
        cache_file.write_text("{{{{not valid yaml: [[[")
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            assert get_cached_context_length("model", "http://x") is None

    def test_wrong_structure_returns_none(self, tmp_path):
        """YAML that loads but has wrong structure."""
        cache_file = tmp_path / "cache.yaml"
        cache_file.write_text("just_a_string\n")
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            assert get_cached_context_length("model", "http://x") is None

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_cached_value_takes_priority(self, mock_fetch, tmp_path):
        mock_fetch.return_value = {}
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("unknown/model", "http://local", 65536)
            assert get_model_context_length("unknown/model", base_url="http://local") == 65536

    def test_special_chars_in_model_name(self, tmp_path):
        """Model names with colons, slashes, etc. don't break the cache."""
        cache_file = tmp_path / "cache.yaml"
        model = "anthropic/claude-3.5-sonnet:beta"
        url = "https://api.example.com/v1"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length(model, url, 200000)
            assert get_cached_context_length(model, url) == 200000
