"""
Separate connection attempts from a connection string.
"""

# Copyright (C) 2024 The Psycopg Team

from __future__ import annotations

import socket
import logging
from random import shuffle

from . import errors as e
from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
from ._conninfo_utils import split_attempts

logger = logging.getLogger("psycopg")


def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
    """Split a set of connection params into the single attempts to perform.

    A set of connection params can result in more than one attempt if more
    than one ``host`` is provided.

    Also resolve each hostname into ``hostaddr``. Because a host can resolve
    to more than one address, this can yield further attempts too.

    Because the libpq async function doesn't honour the timeout, we need to
    reimplement the repeated attempts.

    :param params: The connection parameters to split.
    :return: The list of attempts to make, shuffled if the
        ``load_balance_hosts`` parameter is ``random``.
    :raise OperationalError: if no host could be resolved. The last
        resolution error is attached as the exception cause.
    """
    last_exc: OSError | None = None
    attempts: list[ConnDict] = []
    for attempt in split_attempts(params):
        try:
            attempts.extend(_resolve_hostnames(attempt))
        except OSError as ex:
            # A host failing to resolve is not fatal as long as another one
            # produces at least one attempt: remember the error and go on.
            logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
            last_exc = ex

    if not attempts:
        assert last_exc
        # We couldn't resolve anything: report the last failure, keeping the
        # original error as the cause so it shows up in the traceback.
        raise e.OperationalError(str(last_exc)) from last_exc

    if get_param(params, "load_balance_hosts") == "random":
        shuffle(attempts)

    return attempts


def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
    """
    Perform DNS lookup of the host and return a list of connection attempts.

    If a ``host`` param is present but not ``hostaddr``, resolve the host
    addresses using `socket.getaddrinfo()`.

    :param params: The input parameters, for instance as returned by
        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
        a single entry for host, hostaddr because it is designed to further
        process the input of split_attempts().

    :return: A list of attempts to make (to include the case of a hostname
        resolving to more than one IP).

    :raise OSError: if `socket.getaddrinfo()` fails to resolve the host.
    """
    host = get_param(params, "host")
    if not host or host.startswith("/"):
        # Local path, or no host to resolve
        return [params]

    if host[1:2] == ":" and not is_ip_address(host):
        # A ":" in second position is taken as a local path too (presumably
        # a drive-letter path). Short IPv6 literals such as "::1" also match
        # the pattern, so exclude them here and handle them as addresses.
        return [params]

    hostaddr = get_param(params, "hostaddr")
    if hostaddr:
        # Already resolved
        return [params]

    if is_ip_address(host):
        # If the host is already an ip address don't try to resolve it
        return [{**params, "hostaddr": host}]

    port = get_param(params, "port")
    if not port:
        # Fall back on the default port known for the parameter, or on the
        # PostgreSQL standard port if no default is available.
        port_def = get_param_def("port")
        port = (port_def.compiled if port_def else None) or "5432"

    ans = socket.getaddrinfo(
        host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
    )
    # item[4] is the socket address; its first element is the IP address.
    return [{**params, "hostaddr": item[4][0]} for item in ans]
