import pytest

import attr

from ..abc import SendStream, ReceiveStream
from .._highlevel_generic import StapledStream


@attr.s
class RecordSendStream(SendStream):
    record = attr.ib(factory=list)

    async def send_all(self, data):
        self.record.append(("send_all", data))

    async def wait_send_all_might_not_block(self):
        self.record.append("wait_send_all_might_not_block")

    async def aclose(self):
        self.record.append("aclose")


@attr.s
class RecordReceiveStream(ReceiveStream):
    record = attr.ib(factory=list)

    async def receive_some(self, max_bytes=None):
        self.record.append(("receive_some", max_bytes))

    async def aclose(self):
        self.record.append("aclose")


async def test_StapledStream():
    send_stream = RecordSendStream()
    receive_stream = RecordReceiveStream()
    stapled = StapledStream(send_stream, receive_stream)

    assert stapled.send_stream is send_stream
    assert stapled.receive_stream is receive_stream

    await stapled.send_all(b"foo")
    await stapled.wait_send_all_might_not_block()
    assert send_stream.record == [
        ("send_all", b"foo"),
        "wait_send_all_might_not_block",
    ]
    send_stream.record.clear()

    await stapled.send_eof()
    assert send_stream.record == ["aclose"]
    send_stream.record.clear()

    async def fake_send_eof():
        send_stream.record.append("send_eof")

    send_stream.send_eof = fake_send_eof
    await stapled.send_eof()
    assert send_stream.record == ["send_eof"]

    send_stream.record.clear()
    assert receive_stream.record == []

    await stapled.receive_some(1234)
    assert receive_stream.record == [("receive_some", 1234)]
    assert send_stream.record == []
    receive_stream.record.clear()

    await stapled.aclose()
    assert receive_stream.record == ["aclose"]
    assert send_stream.record == ["aclose"]


async def test_StapledStream_with_erroring_close():
    # Make sure that if one of the aclose methods errors out, then the other
    # one still gets called.
    class BrokenSendStream(RecordSendStream):
        async def aclose(self):
            await super().aclose()
            raise ValueError

    class BrokenReceiveStream(RecordReceiveStream):
        async def aclose(self):
            await super().aclose()
            raise ValueError

    stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())

    with pytest.raises(ValueError) as excinfo:
        await stapled.aclose()
    assert isinstance(excinfo.value.__context__, ValueError)

    assert stapled.send_stream.record == ["aclose"]
    assert stapled.receive_stream.record == ["aclose"]
