"""
Lightweight message status publisher for asyncclient.

Publishes MESSAGE_STATUS events to the local proxy which relays them
to the EMQX broker via MQTT.

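Each call to report() submits one HTTP POST to a small thread pool; there
is no batching. At most _MAX_INFLIGHT POSTs may be queued or in flight at
once, and events beyond that cap are dropped.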

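Usage::

    publisher = MessageStatusPublisher()

    message_id_gen = Gen()   # stamps messages with an ID + counter
    reporter_gen = Gen()     # identifies this reporting component

    msg = {...}
    message_id_gen.enrich(msg)   # adds message_reporter_id / message_reporter_increment
    publisher.report(msg, reporter_gen, stage="...")

The module also exposes shared ``publisher`` and ``message_id_gen``
instances.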
"""

import atexit
import concurrent.futures
import json
import logging
import os
import threading
import time
import urllib.error
import urllib.request
import uuid

from defence360agent.internals.feature_flags import is_enabled

logger = logging.getLogger(__name__)

# File holding the iaid that is attached to every published record.
_IAID_PATH = "/var/imunify360/iaid"

# Base URL of the local proxy (override via IMUNIFY_PROXY_URL). A trailing
# slash on the base URL is tolerated when the endpoint is built.
_PROXY_URL = os.environ.get("IMUNIFY_PROXY_URL", "http://127.0.0.1:11234")
_PUBLISH_ENDPOINT = f"{_PROXY_URL.rstrip('/')}/api/v1/mqtt-publish"

# Shared secret sent as X-API-Key to the proxy's APIKey middleware. It must
# equal IMUNIFY_PROXY_API_KEY on the proxy side (src/proxy/auth/jwt.go).
# Left empty (tests, pre-deploy), the header is omitted; the proxy then
# logs a WARN and lets the request through.
_PROXY_API_KEY = os.environ.get("IMUNIFY_PROXY_API_KEY", "")

# Per-request timeout handed to urlopen, in seconds.
_POST_TIMEOUT = 5
# Number of worker threads in the publisher's pool.
_MAX_WORKERS = 4
# Upper bound on POSTs that are queued or running at the same time. Same
# value as statusPublisherQueueSize in the Go publisher: under a slow
# proxy/broker, new events are dropped rather than growing memory.
_MAX_INFLIGHT = 128


class Gen:
    """Pair a random ID with a thread-safe counter.

    ``id`` is a fresh ``uuid4`` hex string per instance, and every instance
    owns an independent counter that starts at 0.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._counter = 0
        self.id = uuid.uuid4().hex

    def _next(self) -> int:
        """Return the current counter value and advance it by one."""
        self._lock.acquire()
        try:
            current = self._counter
            self._counter = current + 1
        finally:
            self._lock.release()
        return current

    def enrich(self, msg: dict) -> None:
        """Stamp ``msg`` in place with this generator's ID and next counter.

        Writes the ``message_reporter_id`` and ``message_reporter_increment``
        keys.
        """
        increment = self._next()
        msg["message_reporter_id"] = self.id
        msg["message_reporter_increment"] = increment


def _read_iaid() -> str:
    try:
        with open(_IAID_PATH) as f:
            return f.read().strip()
    except OSError:
        return ""


class MessageStatusPublisher:
    """Fire-and-forget publisher of MESSAGE_STATUS records.

    Each :meth:`report` call builds one record and submits a JSON POST to
    ``_PUBLISH_ENDPOINT`` on a small thread pool. At most ``_MAX_INFLIGHT``
    POSTs may be queued or running at once; further records are dropped
    with a warning. Delivery failures are logged and otherwise ignored.
    """

    def __init__(self) -> None:
        self._iaid: str = ""
        self._init_lock = threading.Lock()
        self._initialized = False
        # Set once the "iaid not available" message has been logged, so a
        # host without an iaid file logs it once instead of once per event.
        self._iaid_missing_logged = False
        self._pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=_MAX_WORKERS,
            thread_name_prefix="msg-status",
        )
        # One permit per queued or running POST; released when the POST's
        # future completes (or right away if submission fails).
        self._inflight = threading.BoundedSemaphore(_MAX_INFLIGHT)

    def _ensure_initialized(self) -> None:
        """Load the iaid, re-trying on each call until it is non-empty.

        The file read happens under ``_init_lock`` so only one thread reads
        it at a time. Once a non-empty iaid is loaded, this returns
        immediately.
        """
        if self._initialized:
            return
        with self._init_lock:
            if self._initialized:
                return
            self._iaid = _read_iaid()
            if not self._iaid:
                # Called for every record; log only the first miss.
                if not self._iaid_missing_logged:
                    logger.info(
                        "msg-status: iaid not available yet (file %s missing or"
                        " empty), will retry",
                        _IAID_PATH,
                    )
                    self._iaid_missing_logged = True
                return
            self._initialized = True

    def report(self, msg: dict, reporter_gen: Gen, stage: str) -> None:
        """Queue a status record for ``msg`` to be POSTed to the proxy.

        Args:
            msg: The tracked message. It must already carry a truthy
                ``message_reporter_id`` (see ``Gen.enrich``), otherwise
                nothing is reported. Its ``method`` becomes the record's
                ``message_type``.
            reporter_gen: Generator identifying the reporting component; its
                ``id`` and next counter value are stamped on the record.
            stage: Stage name stored in the record's ``stage`` field.

        Returns without waiting for the POST. The record is dropped when
        the ``mqtt_tracking`` flag is off, the in-flight cap is reached, or
        the pool has already been shut down.
        """
        # Gate the feature flag before any allocation: report() is called
        # per message and when tracking is disabled (default) we want zero
        # dict/pool overhead.
        if not is_enabled("mqtt_tracking"):
            return
        method = msg.get("method", "")
        if not msg.get("message_reporter_id"):
            return

        # Bounded queue: drop new events when we're already at capacity so
        # a slow proxy/broker can't grow our memory unboundedly. Mirrors
        # the Go publisher's fire-and-drop channel pattern.
        if not self._inflight.acquire(blocking=False):
            logger.warning(
                "msg-status: queue full, dropping stage=%s method=%s",
                stage,
                method,
            )
            return

        record = {
            "timestamp": time.time(),
            "reporter_id": reporter_gen.id,
            "reporter_increment": reporter_gen._next(),
            "message_reporter_id": msg.get("message_reporter_id", ""),
            "message_reporter_increment": msg.get(
                "message_reporter_increment", 0
            ),
            "message_type": method,
            "stage": stage,
        }
        try:
            future = self._pool.submit(self._do_post, record, stage, method)
        except RuntimeError:
            # Pool already shut down.
            self._inflight.release()
            return
        future.add_done_callback(lambda _: self._inflight.release())

    def _do_post(self, record: dict, stage: str, method: str) -> None:
        """Worker body: attach the iaid and POST ``record`` to the proxy.

        This runs on a pool thread whose future nobody inspects, so an
        escaping exception would vanish silently. Every failure, including
        reading the iaid, is therefore logged as a warning here.
        """
        try:
            self._ensure_initialized()
            if not self._iaid:
                # iaid not available yet: drop this record.
                return
            record["iaid"] = self._iaid
            payload = json.dumps(record).encode()
            headers = {"Content-Type": "application/json"}
            if _PROXY_API_KEY:
                headers["X-API-Key"] = _PROXY_API_KEY
            req = urllib.request.Request(
                _PUBLISH_ENDPOINT,
                data=payload,
                headers=headers,
                method="POST",
            )
            with urllib.request.urlopen(req, timeout=_POST_TIMEOUT) as resp:
                resp.read()
        except Exception as e:
            logger.warning(
                "msg-status: POST failed stage=%s method=%s: %r",
                stage,
                method,
                e,
            )

    def shutdown(self) -> None:
        """Stop accepting new POSTs without blocking the caller.

        After this, ``submit`` raises ``RuntimeError``, which :meth:`report`
        handles by dropping the record. POSTs that are already queued are not
        cancelled. Per the ``concurrent.futures`` docs, the interpreter still
        waits for pending futures before it exits.
        """
        self._pool.shutdown(wait=False)


# Process-wide instances shared by everything that imports this module.
message_id_gen = Gen()

publisher = MessageStatusPublisher()
# On interpreter exit, stop the publisher's pool from accepting new POSTs.
atexit.register(publisher.shutdown)
