import pytest

from functools import partial

import attr

import trio
from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP
import trio.testing
from .test_ssl import client_ctx, SERVER_CTX

from .._highlevel_ssl_helpers import (
    open_ssl_over_tcp_stream,
    open_ssl_over_tcp_listeners,
    serve_ssl_over_tcp,
)


async def echo_handler(stream):
    async with stream:
        try:
            while True:
                data = await stream.receive_some(10000)
                if not data:
                    break
                await stream.send_all(data)
        except trio.BrokenResourceError:
            pass


# Resolver that always returns the given sockaddr, no matter what host/port
# you ask for.
@attr.s
class FakeHostnameResolver(trio.abc.HostnameResolver):
    sockaddr = attr.ib()

    async def getaddrinfo(self, *args):
        return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]

    async def getnameinfo(self, *args):  # pragma: no cover
        raise NotImplementedError


# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
# noqa is needed because flake8 doesn't understand how pytest fixtures work.
async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx):  # noqa: F811
    async with trio.open_nursery() as nursery:
        (listener,) = await nursery.start(
            partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1")
        )
        async with listener:
            sockaddr = listener.transport_listener.socket.getsockname()
            hostname_resolver = FakeHostnameResolver(sockaddr)
            trio.socket.set_custom_hostname_resolver(hostname_resolver)

            # We don't have the right trust set up
            # (checks that ssl_context=None is doing some validation)
            stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
            async with stream:
                with pytest.raises(trio.BrokenResourceError):
                    await stream.do_handshake()

            # We have the trust but not the hostname
            # (checks custom ssl_context + hostname checking)
            stream = await open_ssl_over_tcp_stream(
                "xyzzy.example.org", 80, ssl_context=client_ctx
            )
            async with stream:
                with pytest.raises(trio.BrokenResourceError):
                    await stream.do_handshake()

            # This one should work!
            stream = await open_ssl_over_tcp_stream(
                "trio-test-1.example.org", 80, ssl_context=client_ctx
            )
            async with stream:
                assert isinstance(stream, trio.SSLStream)
                assert stream.server_hostname == "trio-test-1.example.org"
                await stream.send_all(b"x")
                assert await stream.receive_some(1) == b"x"

            # Check https_compatible settings are being passed through
            assert not stream._https_compatible
            stream = await open_ssl_over_tcp_stream(
                "trio-test-1.example.org",
                80,
                ssl_context=client_ctx,
                https_compatible=True,
                # also, smoke test happy_eyeballs_delay
                happy_eyeballs_delay=1,
            )
            async with stream:
                assert stream._https_compatible

            # Stop the echo server
            nursery.cancel_scope.cancel()


async def test_open_ssl_over_tcp_listeners():
    (listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
    async with listener:
        assert isinstance(listener, trio.SSLListener)
        tl = listener.transport_listener
        assert isinstance(tl, trio.SocketListener)
        assert tl.socket.getsockname()[0] == "127.0.0.1"

        assert not listener._https_compatible

    (listener,) = await open_ssl_over_tcp_listeners(
        0, SERVER_CTX, host="127.0.0.1", https_compatible=True
    )
    async with listener:
        assert listener._https_compatible
