"""Tests for the Home Assistant tool module.

Tests real logic: entity filtering, payload building, response parsing,
handler validation, and availability gating.
"""

import json
from unittest.mock import patch

import pytest

from tools.homeassistant_tool import (
    _check_ha_available,
    _filter_and_summarize,
    _build_service_payload,
    _parse_service_response,
    _get_headers,
    _handle_get_state,
    _handle_call_service,
    _BLOCKED_DOMAINS,
    _ENTITY_ID_RE,
    _SERVICE_NAME_RE,
)


# ---------------------------------------------------------------------------
# Sample HA state data (matches real HA /api/states response shape)
# ---------------------------------------------------------------------------

# Seven entities across six domains. Area-filter tests rely on the "kitchen" and
# "bedroom" substrings appearing in friendly names, and on the humidity sensor
# carrying an explicit "area" attribute.
SAMPLE_STATES = [
    {
        "entity_id": "light.bedroom",
        "state": "on",
        "attributes": {"friendly_name": "Bedroom Light", "brightness": 200},
    },
    {
        "entity_id": "light.kitchen",
        "state": "off",
        "attributes": {"friendly_name": "Kitchen Light"},
    },
    {
        "entity_id": "switch.fan",
        "state": "on",
        "attributes": {"friendly_name": "Living Room Fan"},
    },
    {
        "entity_id": "sensor.temperature",
        "state": "22.5",
        "attributes": {
            "friendly_name": "Kitchen Temperature",
            "unit_of_measurement": "C",
        },
    },
    {
        "entity_id": "climate.thermostat",
        "state": "heat",
        "attributes": {"friendly_name": "Main Thermostat", "current_temperature": 21},
    },
    {
        "entity_id": "binary_sensor.motion",
        "state": "off",
        "attributes": {"friendly_name": "Hallway Motion"},
    },
    {
        "entity_id": "sensor.humidity",
        "state": "55",
        "attributes": {"friendly_name": "Bedroom Humidity", "area": "bedroom"},
    },
]


# ---------------------------------------------------------------------------
# Entity filtering and summarization
# ---------------------------------------------------------------------------


class TestFilterAndSummarize:
    """Domain/area filtering and the summary shape of ``_filter_and_summarize``."""

    def test_no_filters_returns_all(self):
        result = _filter_and_summarize(SAMPLE_STATES)
        assert result["count"] == 7
        ids = {e["entity_id"] for e in result["entities"]}
        assert "light.bedroom" in ids
        assert "climate.thermostat" in ids

    def test_domain_filter_lights(self):
        result = _filter_and_summarize(SAMPLE_STATES, domain="light")
        assert result["count"] == 2
        for e in result["entities"]:
            assert e["entity_id"].startswith("light.")

    def test_domain_filter_sensor(self):
        result = _filter_and_summarize(SAMPLE_STATES, domain="sensor")
        assert result["count"] == 2
        ids = {e["entity_id"] for e in result["entities"]}
        assert ids == {"sensor.temperature", "sensor.humidity"}

    def test_domain_filter_no_matches(self):
        result = _filter_and_summarize(SAMPLE_STATES, domain="media_player")
        assert result["count"] == 0
        assert result["entities"] == []

    def test_area_filter_by_friendly_name(self):
        result = _filter_and_summarize(SAMPLE_STATES, area="kitchen")
        assert result["count"] == 2
        ids = {e["entity_id"] for e in result["entities"]}
        assert "light.kitchen" in ids
        assert "sensor.temperature" in ids

    def test_area_filter_by_area_attribute(self):
        """An entity matches through ``attributes["area"]`` alone."""
        # Neither the friendly_name nor the entity_id mentions "bedroom", so the
        # area attribute is the only way the first entity can match. In
        # SAMPLE_STATES, "Bedroom Humidity" would also match by friendly_name,
        # so that fixture cannot isolate this code path.
        states = [
            {
                "entity_id": "sensor.humidity",
                "state": "55",
                "attributes": {"friendly_name": "Humidity", "area": "bedroom"},
            },
            {
                "entity_id": "sensor.pressure",
                "state": "1013",
                "attributes": {"friendly_name": "Pressure", "area": "garage"},
            },
        ]
        result = _filter_and_summarize(states, area="bedroom")
        assert result["count"] == 1
        assert result["entities"][0]["entity_id"] == "sensor.humidity"

    def test_area_filter_bedroom_sample(self):
        result = _filter_and_summarize(SAMPLE_STATES, area="bedroom")
        ids = {e["entity_id"] for e in result["entities"]}
        # "Bedroom Light" and "Bedroom Humidity" both mention the area by name.
        assert "light.bedroom" in ids
        assert "sensor.humidity" in ids

    def test_area_filter_case_insensitive(self):
        result = _filter_and_summarize(SAMPLE_STATES, area="KITCHEN")
        assert result["count"] == 2

    def test_combined_domain_and_area(self):
        result = _filter_and_summarize(SAMPLE_STATES, domain="sensor", area="kitchen")
        assert result["count"] == 1
        assert result["entities"][0]["entity_id"] == "sensor.temperature"

    def test_summary_includes_friendly_name(self):
        result = _filter_and_summarize(SAMPLE_STATES, domain="climate")
        assert result["entities"][0]["friendly_name"] == "Main Thermostat"
        assert result["entities"][0]["state"] == "heat"

    def test_empty_states_list(self):
        result = _filter_and_summarize([])
        assert result["count"] == 0
        assert result["entities"] == []

    def test_missing_attributes_handled(self):
        # A state object without an "attributes" key must not raise, and the
        # summary falls back to an empty friendly_name.
        states = [{"entity_id": "light.x", "state": "on"}]
        result = _filter_and_summarize(states)
        assert result["count"] == 1
        assert result["entities"][0]["friendly_name"] == ""


# ---------------------------------------------------------------------------
# Service payload building
# ---------------------------------------------------------------------------


class TestBuildServicePayload:
    """``_build_service_payload`` merges an optional entity_id into service data."""

    def test_entity_id_only(self):
        assert _build_service_payload(entity_id="light.bedroom") == {
            "entity_id": "light.bedroom"
        }

    def test_data_only(self):
        assert _build_service_payload(data={"brightness": 255}) == {"brightness": 255}

    def test_entity_id_and_data(self):
        built = _build_service_payload(
            entity_id="light.bedroom",
            data={"brightness": 200, "color_name": "blue"},
        )
        expected_fields = {
            "entity_id": "light.bedroom",
            "brightness": 200,
            "color_name": "blue",
        }
        for field_name, field_value in expected_fields.items():
            assert built[field_name] == field_value

    def test_no_args_returns_empty(self):
        assert _build_service_payload() == {}

    def test_entity_id_param_takes_precedence_over_data(self):
        built = _build_service_payload(entity_id="light.a", data={"entity_id": "light.b"})
        # When both are supplied, the dedicated keyword argument wins.
        assert built["entity_id"] == "light.a"


# ---------------------------------------------------------------------------
# Service response parsing
# ---------------------------------------------------------------------------


class TestParseServiceResponse:
    """``_parse_service_response`` normalizes HA service-call responses."""

    def test_list_response_extracts_entities(self):
        raw_states = [
            {"entity_id": entity, "state": "on", "attributes": {}}
            for entity in ("light.bedroom", "light.kitchen")
        ]
        parsed = _parse_service_response("light", "turn_on", raw_states)
        assert parsed["success"] is True
        assert parsed["service"] == "light.turn_on"
        affected = parsed["affected_entities"]
        assert len(affected) == 2
        assert affected[0]["entity_id"] == "light.bedroom"

    def test_empty_list_response(self):
        parsed = _parse_service_response("scene", "turn_on", [])
        assert parsed["success"] is True
        assert parsed["affected_entities"] == []

    def test_non_list_response(self):
        # Some HA services answer with a dict rather than a list of states.
        parsed = _parse_service_response("script", "run", {"result": "ok"})
        assert parsed["success"] is True
        assert parsed["affected_entities"] == []

    def test_none_response(self):
        parsed = _parse_service_response("automation", "trigger", None)
        assert parsed["success"] is True
        assert parsed["affected_entities"] == []

    def test_service_name_format(self):
        parsed = _parse_service_response("climate", "set_temperature", [])
        assert parsed["service"] == "climate.set_temperature"


# ---------------------------------------------------------------------------
# Handler validation (no mocks - these paths don't reach the network)
# ---------------------------------------------------------------------------


class TestHandlerValidation:
    """Argument validation in the handlers; none of these inputs reach the network."""

    @staticmethod
    def _decode(raw):
        """Parse a handler's JSON string reply into Python objects."""
        return json.loads(raw)

    def test_get_state_missing_entity_id(self):
        reply = self._decode(_handle_get_state({}))
        assert "error" in reply
        assert "entity_id" in reply["error"]

    def test_get_state_empty_entity_id(self):
        reply = self._decode(_handle_get_state({"entity_id": ""}))
        assert "error" in reply

    def test_call_service_missing_domain(self):
        reply = self._decode(_handle_call_service({"service": "turn_on"}))
        assert "error" in reply
        assert "domain" in reply["error"]

    def test_call_service_missing_service(self):
        reply = self._decode(_handle_call_service({"domain": "light"}))
        assert "error" in reply
        assert "service" in reply["error"]

    def test_call_service_missing_both(self):
        reply = self._decode(_handle_call_service({}))
        assert "error" in reply

    def test_call_service_empty_strings(self):
        reply = self._decode(_handle_call_service({"domain": "", "service": ""}))
        assert "error" in reply


# ---------------------------------------------------------------------------
# Security: domain blocklist
# ---------------------------------------------------------------------------


class TestDomainBlocklist:
    """Verify dangerous HA service domains are blocked."""

    @staticmethod
    def _close_dispatched_coroutines(mock_run):
        """Close any coroutine handed to the mocked ``_run_async``.

        ``_run_async`` is patched out, so the coroutine the handler builds is
        never awaited. Closing it avoids a "coroutine ... was never awaited"
        RuntimeWarning.
        """
        import inspect

        for dispatched in mock_run.call_args_list:
            for arg in (*dispatched.args, *dispatched.kwargs.values()):
                if inspect.iscoroutine(arg):
                    arg.close()

    @pytest.mark.parametrize("domain", sorted(_BLOCKED_DOMAINS))
    def test_blocked_domain_rejected(self, domain):
        with patch("tools.homeassistant_tool._run_async") as mock_run:
            result = json.loads(_handle_call_service({
                "domain": domain, "service": "any_service"
            }))
        self._close_dispatched_coroutines(mock_run)
        assert "error" in result
        assert "blocked" in result["error"].lower()
        # A blocked domain must never be dispatched to Home Assistant.
        mock_run.assert_not_called()

    def test_safe_domain_not_blocked(self):
        """Safe domains like 'light' pass the blocklist and get dispatched."""
        # _run_async is mocked so no real HTTP request is made. Asserting the
        # dispatch proves the blocklist let the call through, rather than
        # passing vacuously whenever no error is returned.
        with patch(
            "tools.homeassistant_tool._run_async", return_value={"success": True}
        ) as mock_run:
            result = json.loads(_handle_call_service({
                "domain": "light", "service": "turn_on", "entity_id": "light.test"
            }))
        self._close_dispatched_coroutines(mock_run)
        mock_run.assert_called_once()
        if isinstance(result, dict) and "error" in result:
            assert "blocked" not in result["error"].lower()

    def test_blocked_domains_include_shell_command(self):
        assert "shell_command" in _BLOCKED_DOMAINS

    def test_blocked_domains_include_hassio(self):
        assert "hassio" in _BLOCKED_DOMAINS

    def test_blocked_domains_include_rest_command(self):
        assert "rest_command" in _BLOCKED_DOMAINS


# ---------------------------------------------------------------------------
# Security: entity_id validation
# ---------------------------------------------------------------------------


class TestEntityIdValidation:
    """Verify entity_id format validation prevents path traversal."""

    @staticmethod
    def _close_dispatched_coroutines(mock_run):
        """Close any coroutine handed to the mocked ``_run_async``.

        ``_run_async`` is patched out, so the coroutine the handler builds is
        never awaited. Closing it avoids a "coroutine ... was never awaited"
        RuntimeWarning.
        """
        import inspect

        for dispatched in mock_run.call_args_list:
            for arg in (*dispatched.args, *dispatched.kwargs.values()):
                if inspect.iscoroutine(arg):
                    arg.close()

    def test_valid_entity_id_accepted(self):
        assert _ENTITY_ID_RE.match("light.bedroom")
        assert _ENTITY_ID_RE.match("sensor.temperature_1")
        assert _ENTITY_ID_RE.match("binary_sensor.motion")
        assert _ENTITY_ID_RE.match("climate.main_thermostat")

    def test_path_traversal_rejected(self):
        assert _ENTITY_ID_RE.match("../../config") is None
        assert _ENTITY_ID_RE.match("light/../../../etc/passwd") is None
        assert _ENTITY_ID_RE.match("../api/config") is None

    def test_special_chars_rejected(self):
        assert _ENTITY_ID_RE.match("light.bed room") is None  # space
        assert _ENTITY_ID_RE.match("light.bed;rm -rf") is None  # semicolon
        assert _ENTITY_ID_RE.match("light.bed/room") is None  # slash
        assert _ENTITY_ID_RE.match("LIGHT.BEDROOM") is None  # uppercase

    def test_missing_domain_rejected(self):
        assert _ENTITY_ID_RE.match(".bedroom") is None
        assert _ENTITY_ID_RE.match("bedroom") is None

    def test_get_state_rejects_invalid_entity_id(self):
        with patch("tools.homeassistant_tool._run_async") as mock_run:
            result = json.loads(_handle_get_state({"entity_id": "../../config"}))
        self._close_dispatched_coroutines(mock_run)
        assert "error" in result
        assert "Invalid entity_id" in result["error"]
        # A rejected entity_id must never be sent to Home Assistant.
        mock_run.assert_not_called()

    def test_call_service_rejects_invalid_entity_id(self):
        with patch("tools.homeassistant_tool._run_async") as mock_run:
            result = json.loads(_handle_call_service({
                "domain": "light",
                "service": "turn_on",
                "entity_id": "../../../etc/passwd",
            }))
        self._close_dispatched_coroutines(mock_run)
        assert "error" in result
        assert "Invalid entity_id" in result["error"]
        mock_run.assert_not_called()

    def test_call_service_allows_no_entity_id(self):
        """Some services (like scene.turn_on) don't need entity_id."""
        # _run_async is mocked so no real HTTP request is made. The dispatch
        # assertion proves a missing entity_id passes validation, rather than
        # passing vacuously whenever no error is returned.
        with patch(
            "tools.homeassistant_tool._run_async", return_value={"success": True}
        ) as mock_run:
            result = json.loads(_handle_call_service({
                "domain": "scene", "service": "turn_on"
            }))
        self._close_dispatched_coroutines(mock_run)
        mock_run.assert_called_once()
        if isinstance(result, dict) and "error" in result:
            assert "Invalid entity_id" not in result["error"]


# ---------------------------------------------------------------------------
# String-data deserialization (XML tool calling workaround)
# ---------------------------------------------------------------------------


class TestCallServiceStringData:
    """data param may arrive as a JSON string (XML tool calling mode)."""

    @staticmethod
    def _close_dispatched_coroutines(mock_run):
        """Close any coroutine handed to the mocked ``_run_async``.

        ``_run_async`` is patched out, so the coroutine the handler builds is
        never awaited. Closing it avoids a "coroutine ... was never awaited"
        RuntimeWarning.
        """
        import inspect

        for dispatched in mock_run.call_args_list:
            for arg in (*dispatched.args, *dispatched.kwargs.values()):
                if inspect.iscoroutine(arg):
                    arg.close()

    @patch("tools.homeassistant_tool._run_async", return_value={"success": True})
    def test_string_data_deserialized(self, mock_run):
        """JSON string data is parsed into a dict before dispatch."""
        raw = _handle_call_service({
            "domain": "climate",
            "service": "set_hvac_mode",
            "entity_id": "climate.living_room",
            "data": '{"hvac_mode": "heat"}',
        })
        self._close_dispatched_coroutines(mock_run)
        # Reaching _run_async means the JSON string was accepted and all
        # validation passed. No error may be reported back.
        mock_run.assert_called_once()
        assert "error" not in json.loads(raw)

    @patch("tools.homeassistant_tool._run_async", return_value={"success": True})
    def test_dict_data_passthrough(self, mock_run):
        """Dict data (JSON tool calling mode) still works unchanged."""
        _handle_call_service({
            "domain": "light",
            "service": "turn_on",
            "entity_id": "light.bedroom",
            "data": {"brightness": 255},
        })
        self._close_dispatched_coroutines(mock_run)
        mock_run.assert_called_once()

    def test_invalid_json_string_returns_error(self):
        """Malformed JSON string in data returns a clear error."""
        with patch("tools.homeassistant_tool._run_async") as mock_run:
            result = json.loads(_handle_call_service({
                "domain": "light",
                "service": "turn_on",
                "entity_id": "light.bedroom",
                "data": "{not valid json}",
            }))
        self._close_dispatched_coroutines(mock_run)
        assert "error" in result
        assert "Invalid JSON" in result["error"]
        # Malformed input must be rejected before anything is dispatched.
        mock_run.assert_not_called()

    @patch("tools.homeassistant_tool._run_async", return_value={"success": True})
    def test_empty_string_data_becomes_none(self, mock_run):
        """Empty/whitespace string data is treated as None."""
        _handle_call_service({
            "domain": "light",
            "service": "turn_on",
            "entity_id": "light.bedroom",
            "data": "   ",
        })
        self._close_dispatched_coroutines(mock_run)
        mock_run.assert_called_once()


# ---------------------------------------------------------------------------
# Security: domain/service name format validation
# ---------------------------------------------------------------------------


class TestServiceNameValidation:
    """Verify domain/service format validation prevents path traversal in URL.

    Both values are interpolated into ``/api/services/{domain}/{service}``.
    Accepting arbitrary strings there would allow SSRF via path traversal, or a
    bypass of the blocked-domain list.
    """

    @staticmethod
    def _assert_all_accepted(*names):
        """Every name must match the service-name pattern."""
        for name in names:
            assert _SERVICE_NAME_RE.match(name), name

    @staticmethod
    def _assert_all_rejected(*names):
        """No name may match the service-name pattern."""
        for name in names:
            assert _SERVICE_NAME_RE.match(name) is None, name

    @staticmethod
    def _invoke(args):
        """Run ``_handle_call_service`` and decode its JSON reply."""
        return json.loads(_handle_call_service(args))

    def test_valid_domain_names(self):
        self._assert_all_accepted(
            "light", "switch", "climate", "shell_command", "media_player"
        )

    def test_valid_service_names(self):
        self._assert_all_accepted("turn_on", "turn_off", "set_temperature", "toggle")

    def test_path_traversal_in_domain_rejected(self):
        self._assert_all_rejected("../../api/config", "light/../../../etc", "../config")

    def test_path_traversal_in_service_rejected(self):
        self._assert_all_rejected("../../api/config", "turn_on/../../config")

    def test_blocked_domain_bypass_via_traversal_rejected(self):
        """Ensure shell_command/../light is rejected, not just checked against blocklist."""
        self._assert_all_rejected(
            "shell_command/../light",
            "python_script/../scene",
            "hassio/../automation",
        )

    def test_slashes_rejected(self):
        self._assert_all_rejected("light/turn_on", "a/b/c")

    def test_dots_rejected(self):
        self._assert_all_rejected("light.turn_on", "..")

    def test_uppercase_rejected(self):
        self._assert_all_rejected("LIGHT", "Turn_On")

    def test_special_chars_rejected(self):
        self._assert_all_rejected("light;rm", "light&cmd", "light cmd")

    def test_handler_rejects_traversal_domain(self):
        """_handle_call_service must reject domain with path traversal."""
        reply = self._invoke({"domain": "../../api/config", "service": "turn_on"})
        assert "error" in reply
        assert "Invalid domain" in reply["error"]

    def test_handler_rejects_traversal_service(self):
        """_handle_call_service must reject service with path traversal."""
        reply = self._invoke({"domain": "light", "service": "../../api/config"})
        assert "error" in reply
        assert "Invalid service" in reply["error"]

    def test_handler_rejects_blocklist_bypass_traversal(self):
        """Blocklist bypass via shell_command/../light must be caught by format validation."""
        reply = self._invoke({"domain": "shell_command/../light", "service": "turn_on"})
        assert "error" in reply
        # The format check must reject it as "Invalid domain"; relying on the
        # blocklist alone would let it slip through.
        assert "Invalid domain" in reply["error"]


# ---------------------------------------------------------------------------
# Availability check
# ---------------------------------------------------------------------------


class TestCheckAvailable:
    """``_check_ha_available`` reports availability based on HASS_TOKEN."""

    def test_unavailable_without_token(self, monkeypatch):
        monkeypatch.delenv("HASS_TOKEN", raising=False)
        available = _check_ha_available()
        assert available is False

    def test_available_with_token(self, monkeypatch):
        monkeypatch.setenv("HASS_TOKEN", "eyJ0eXAiOiJKV1Q")
        available = _check_ha_available()
        assert available is True

    def test_empty_token_is_unavailable(self, monkeypatch):
        # An empty token counts as "not configured", not as a usable credential.
        monkeypatch.setenv("HASS_TOKEN", "")
        available = _check_ha_available()
        assert available is False


# ---------------------------------------------------------------------------
# Auth headers
# ---------------------------------------------------------------------------


class TestGetHeaders:
    """``_get_headers`` builds bearer-auth JSON headers from the module token."""

    def test_bearer_token_format(self, monkeypatch):
        monkeypatch.setattr("tools.homeassistant_tool._HASS_TOKEN", "my-secret-token")
        expected = {
            "Authorization": "Bearer my-secret-token",
            "Content-Type": "application/json",
        }
        headers = _get_headers()
        for header_name, header_value in expected.items():
            assert headers[header_name] == header_value


# ---------------------------------------------------------------------------
# Registry integration
# ---------------------------------------------------------------------------


class TestRegistration:
    """The three HA tools are registered, grouped, and gated on HASS_TOKEN."""

    # Tool names this module is expected to register.
    _HA_TOOLS = ("ha_list_entities", "ha_get_state", "ha_call_service")

    def test_tools_registered_in_registry(self):
        from tools.registry import registry

        registered = registry.get_all_tool_names()
        for tool_name in self._HA_TOOLS:
            assert tool_name in registered

    def test_tools_in_homeassistant_toolset(self):
        from tools.registry import registry

        mapping = registry.get_tool_to_toolset_map()
        for tool_name in self._HA_TOOLS:
            assert mapping[tool_name] == "homeassistant"

    def test_check_fn_gates_availability(self, monkeypatch):
        """Registry should exclude HA tools when HASS_TOKEN is not set."""
        from tools.registry import registry

        monkeypatch.delenv("HASS_TOKEN", raising=False)
        definitions = registry.get_definitions(set(self._HA_TOOLS))
        assert len(definitions) == 0

    def test_check_fn_includes_when_token_set(self, monkeypatch):
        """Registry should include HA tools when HASS_TOKEN is set."""
        from tools.registry import registry

        monkeypatch.setenv("HASS_TOKEN", "test-token")
        definitions = registry.get_definitions(set(self._HA_TOOLS))
        assert len(definitions) == 3
