import json
import sys
import tempfile
import threading
import types
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

import pytest


TOOLS_DIR = Path(__file__).resolve().parents[2] / "tools"


def _load_tool_module(module_name: str, filename: str):
    spec = spec_from_file_location(module_name, TOOLS_DIR / filename)
    assert spec and spec.loader
    module = module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


def _reset_modules(prefixes: tuple[str, ...]):
    for name in list(sys.modules):
        if name.startswith(prefixes):
            sys.modules.pop(name, None)


@pytest.fixture(autouse=True)
def _restore_tool_and_agent_modules():
    """Save and restore sys.modules entries so fakes don't leak to other tests."""
    original_modules = {
        name: module
        for name, module in sys.modules.items()
        if name in ("tools", "agent", "hermes_cli")
        or name.startswith("tools.")
        or name.startswith("agent.")
        or name.startswith("hermes_cli.")
    }
    try:
        yield
    finally:
        _reset_modules(("tools", "agent", "hermes_cli"))
        sys.modules.update(original_modules)


def _install_fake_tools_package(*, credential_mounts=None):
    _reset_modules(("tools", "agent", "hermes_cli"))

    hermes_cli = types.ModuleType("hermes_cli")
    hermes_cli.__path__ = []  # type: ignore[attr-defined]
    sys.modules["hermes_cli"] = hermes_cli
    sys.modules["hermes_cli.config"] = types.SimpleNamespace(
        get_hermes_home=lambda: Path(tempfile.gettempdir()) / "hermes-home",
    )

    tools_package = types.ModuleType("tools")
    tools_package.__path__ = [str(TOOLS_DIR)]  # type: ignore[attr-defined]
    sys.modules["tools"] = tools_package

    env_package = types.ModuleType("tools.environments")
    env_package.__path__ = [str(TOOLS_DIR / "environments")]  # type: ignore[attr-defined]
    sys.modules["tools.environments"] = env_package

    interrupt_event = threading.Event()
    sys.modules["tools.interrupt"] = types.SimpleNamespace(
        set_interrupt=lambda value=True: interrupt_event.set() if value else interrupt_event.clear(),
        is_interrupted=lambda: interrupt_event.is_set(),
        _interrupt_event=interrupt_event,
    )

    class _DummyBaseEnvironment:
        def __init__(self, cwd: str, timeout: int, env=None):
            self.cwd = cwd
            self.timeout = timeout
            self.env = env or {}

        def _prepare_command(self, command: str):
            return command, None

    sys.modules["tools.environments.base"] = types.SimpleNamespace(BaseEnvironment=_DummyBaseEnvironment)
    sys.modules["tools.managed_tool_gateway"] = types.SimpleNamespace(
        resolve_managed_tool_gateway=lambda vendor: types.SimpleNamespace(
            vendor=vendor,
            gateway_origin="https://modal-gateway.example.com",
            nous_user_token="user-token",
            managed_mode=True,
        )
    )
    sys.modules["tools.credential_files"] = types.SimpleNamespace(
        get_credential_file_mounts=lambda: list(credential_mounts or []),
    )

    return interrupt_event


class _FakeResponse:
    def __init__(self, status_code: int, payload=None, text: str = ""):
        self.status_code = status_code
        self._payload = payload
        self.text = text

    def json(self):
        if isinstance(self._payload, Exception):
            raise self._payload
        return self._payload


def test_managed_modal_execute_polls_until_completed(monkeypatch):
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    calls = []
    poll_count = {"value": 0}

    def fake_request(method, url, headers=None, json=None, timeout=None):
        calls.append((method, url, json, timeout))
        if method == "POST" and url.endswith("/v1/sandboxes"):
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/execs"):
            return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
        if method == "GET" and "/execs/" in url:
            poll_count["value"] += 1
            if poll_count["value"] == 1:
                return _FakeResponse(200, {"execId": url.rsplit("/", 1)[-1], "status": "running"})
            return _FakeResponse(200, {
                "execId": url.rsplit("/", 1)[-1],
                "status": "completed",
                "output": "hello",
                "returncode": 0,
            })
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("echo hello")
    env.cleanup()

    assert result == {"output": "hello", "returncode": 0}
    assert any(call[0] == "POST" and call[1].endswith("/execs") for call in calls)


def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch):
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")

    create_headers = []

    def fake_request(method, url, headers=None, json=None, timeout=None):
        if method == "POST" and url.endswith("/v1/sandboxes"):
            create_headers.append(headers or {})
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    env.cleanup()

    assert len(create_headers) == 1
    assert isinstance(create_headers[0].get("x-idempotency-key"), str)
    assert create_headers[0]["x-idempotency-key"]


def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
    interrupt_event = _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    calls = []

    def fake_request(method, url, headers=None, json=None, timeout=None):
        calls.append((method, url, json, timeout))
        if method == "POST" and url.endswith("/v1/sandboxes"):
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/execs"):
            return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
        if method == "GET" and "/execs/" in url:
            return _FakeResponse(200, {"execId": url.rsplit("/", 1)[-1], "status": "running"})
        if method == "POST" and url.endswith("/cancel"):
            return _FakeResponse(202, {"status": "cancelling"})
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    def fake_sleep(_seconds):
        interrupt_event.set()

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "sleep", fake_sleep)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("sleep 30")
    env.cleanup()

    assert result == {
        "output": "[Command interrupted - Modal sandbox exec cancelled]",
        "returncode": 130,
    }
    assert any(call[0] == "POST" and call[1].endswith("/cancel") for call in calls)
    poll_calls = [call for call in calls if call[0] == "GET" and "/execs/" in call[1]]
    cancel_calls = [call for call in calls if call[0] == "POST" and call[1].endswith("/cancel")]
    assert poll_calls[0][3] == (1.0, 5.0)
    assert cancel_calls[0][3] == (1.0, 5.0)


def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch):
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    def fake_request(method, url, headers=None, json=None, timeout=None):
        if method == "POST" and url.endswith("/v1/sandboxes"):
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/execs"):
            return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
        if method == "GET" and "/execs/" in url:
            return _FakeResponse(404, {"error": "not found"}, text="not found")
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("echo hello")
    env.cleanup()

    assert result["returncode"] == 1
    assert "not found" in result["output"].lower()


def test_managed_modal_create_and_cleanup_preserve_gateway_persistence_fields(monkeypatch):
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")

    create_payloads = []
    terminate_payloads = []

    def fake_request(method, url, headers=None, json=None, timeout=None):
        if method == "POST" and url.endswith("/v1/sandboxes"):
            create_payloads.append(json)
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/terminate"):
            terminate_payloads.append(json)
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)

    env = managed_modal.ManagedModalEnvironment(
        image="python:3.11",
        task_id="task-managed-persist",
        persistent_filesystem=False,
    )
    env.cleanup()

    assert create_payloads == [{
        "image": "python:3.11",
        "cwd": "/root",
        "cpu": 1.0,
        "memoryMiB": 5120.0,
        "timeoutMs": 3_600_000,
        "idleTimeoutMs": 300_000,
        "persistentFilesystem": False,
        "logicalKey": "task-managed-persist",
    }]
    assert terminate_payloads == [{"snapshotBeforeTerminate": False}]


def test_managed_modal_rejects_host_credential_passthrough():
    _install_fake_tools_package(
        credential_mounts=[{
            "host_path": "/tmp/token.json",
            "container_path": "/root/.hermes/token.json",
        }]
    )
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")

    with pytest.raises(ValueError, match="credential-file passthrough"):
        managed_modal.ManagedModalEnvironment(image="python:3.11")


def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    calls = []
    monotonic_values = iter([0.0, 0.0, 0.0, 12.5, 12.5])

    def fake_request(method, url, headers=None, json=None, timeout=None):
        calls.append((method, url, json, timeout))
        if method == "POST" and url.endswith("/v1/sandboxes"):
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/execs"):
            return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
        if method == "GET" and "/execs/" in url:
            return _FakeResponse(200, {"execId": url.rsplit("/", 1)[-1], "status": "running"})
        if method == "POST" and url.endswith("/cancel"):
            return _FakeResponse(202, {"status": "cancelling"})
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "monotonic", lambda: next(monotonic_values))
    monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("sleep 30", timeout=2)
    env.cleanup()

    assert result == {
        "output": "Managed Modal exec timed out after 2s",
        "returncode": 124,
    }
    assert any(call[0] == "POST" and call[1].endswith("/cancel") for call in calls)
