"""
Internal utilities to manipulate connection strings
"""

# Copyright (C) 2024 The Psycopg Team

from __future__ import annotations

import os
from typing import Any
from functools import lru_cache
from ipaddress import ip_address
from dataclasses import dataclass
from typing_extensions import TypeAlias

from . import pq
from . import errors as e

ConnDict: TypeAlias = "dict[str, Any]"


def split_attempts(params: ConnDict) -> list[ConnDict]:
    """
    Split connection parameters with a sequence of hosts into separate attempts.
    """

    def split_val(key: str) -> list[str]:
        val = get_param(params, key)
        return val.split(",") if val else []

    hosts = split_val("host")
    hostaddrs = split_val("hostaddr")
    ports = split_val("port")

    if hosts and hostaddrs and len(hosts) != len(hostaddrs):
        raise e.OperationalError(
            f"could not match {len(hosts)} host names"
            f" with {len(hostaddrs)} hostaddr values"
        )

    nhosts = max(len(hosts), len(hostaddrs))

    if 1 < len(ports) != nhosts:
        raise e.OperationalError(
            f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
        )

    # A single attempt to make. Don't mangle the conninfo string.
    if nhosts <= 1:
        return [params]

    if len(ports) == 1:
        ports *= nhosts

    # Now all lists are either empty or have the same length
    rv = []
    for i in range(nhosts):
        attempt = params.copy()
        if hosts:
            attempt["host"] = hosts[i]
        if hostaddrs:
            attempt["hostaddr"] = hostaddrs[i]
        if ports:
            attempt["port"] = ports[i]
        rv.append(attempt)

    return rv


def get_param(params: ConnDict, name: str) -> str | None:
    """
    Return a value from a connection string.

    The value may be also specified in a PG* env var.
    """
    if name in params:
        return str(params[name])

    # TODO: check if in service

    paramdef = get_param_def(name)
    if not paramdef:
        return None

    env = os.environ.get(paramdef.envvar)
    if env is not None:
        return env

    return None


@dataclass
class ParamDef:
    """
    Information about defaults and env vars for connection params
    """

    keyword: str
    envvar: str
    compiled: str | None


def get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
    """
    Return the ParamDef of a connection string parameter.
    """
    if not _cache:
        defs = pq.Conninfo.get_defaults()
        for d in defs:
            cd = ParamDef(
                keyword=d.keyword.decode(),
                envvar=d.envvar.decode() if d.envvar else "",
                compiled=d.compiled.decode() if d.compiled is not None else None,
            )
            _cache[cd.keyword] = cd

    return _cache.get(keyword)


@lru_cache()
def is_ip_address(s: str) -> bool:
    """Return True if the string represent a valid ip address."""
    try:
        ip_address(s)
    except ValueError:
        return False
    return True
