import time
from collections.abc import Iterator
from dataclasses import dataclass

import wsproto.extensions
import wsproto.frame_protocol
import wsproto.utilities
from wsproto import ConnectionState
from wsproto.frame_protocol import Opcode

from mitmproxy import connection
from mitmproxy import http
from mitmproxy import websocket
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.events import MessageInjected
from mitmproxy.proxy.utils import expect


@dataclass
class WebsocketStartHook(StartHook):
    """
    A WebSocket connection has commenced.

    Triggered by the WebSocket layer when it starts, before any message is relayed.
    """

    # The HTTP flow the WebSocket connection belongs to.
    flow: http.HTTPFlow


@dataclass
class WebsocketMessageHook(StartHook):
    """
    Called when a WebSocket message is received from the client or
    server. The most recent message will be flow.websocket.messages[-1]. The
    message is user-modifiable. Currently there are two types of
    messages, corresponding to the BINARY and TEXT frame types.

    The hook is triggered once the message is complete, i.e. after all of its
    fragments have been received. After the hook has run, the (possibly modified)
    message content is forwarded unless the message has been marked as dropped.
    """

    # The HTTP flow the WebSocket connection belongs to.
    flow: http.HTTPFlow


@dataclass
class WebsocketEndHook(StartHook):
    """
    A WebSocket connection has ended.
    You can check `flow.websocket.close_code` to determine why it ended.

    `flow.websocket.close_reason`, `flow.websocket.closed_by_client` and
    `flow.websocket.timestamp_end` are set as well before this hook is triggered.
    """

    # The HTTP flow the WebSocket connection belongs to.
    flow: http.HTTPFlow


class WebSocketMessageInjected(MessageInjected[websocket.WebSocketMessage]):
    """
    The user has injected a custom WebSocket message.

    The WebSocket layer treats the message as if it had been sent by the side
    indicated by `message.from_client` and marks the recorded message as injected.
    """


class WebsocketConnection(wsproto.Connection):
    """
    wsproto.Connection with a few small conveniences on top:

     - `conn` references the underlying connection so that it is always at hand.
     - `frame_buf` buffers the payload of a message that is not complete yet, one entry per frame.
     - `send2()` wraps `.send()` and returns a command that can be yielded directly.
    """

    conn: connection.Connection
    frame_buf: list[bytes]

    def __init__(self, *args, conn: connection.Connection, **kwargs):
        """Forward all arguments except `conn` to wsproto and start with an empty frame buffer."""
        super().__init__(*args, **kwargs)
        self.frame_buf = [b""]
        self.conn = conn

    def send2(self, event: wsproto.events.Event) -> commands.SendData:
        """Pass `event` to `.send()` and wrap the returned data in a SendData command for `conn`."""
        return commands.SendData(self.conn, self.send(event))

    def __repr__(self):
        state_name = self.state.name
        underlying = self.conn
        return f"WebsocketConnection<{state_name}, {underlying}>"


class WebsocketLayer(layer.Layer):
    """
    WebSocket layer that intercepts and relays messages.

    Each side of the connection is tracked by its own `WebsocketConnection`. Received frames are
    collected until a message is complete; the message is then recorded on the flow, passed to
    `WebsocketMessageHook` and, unless it has been marked as dropped, sent to the other side.
    """

    # The HTTP flow this layer operates on; messages and close metadata are stored in `flow.websocket`.
    flow: http.HTTPFlow
    # WebSocket state for the connection to the client (set up in `start`).
    client_ws: WebsocketConnection
    # WebSocket state for the connection to the server (set up in `start`).
    server_ws: WebsocketConnection

    def __init__(self, context: Context, flow: http.HTTPFlow):
        """Create the layer for `flow`; the WebSocket connections are only created in `start`."""
        super().__init__(context)
        self.flow = flow

    @expect(events.Start)
    def start(self, _) -> layer.CommandGenerator[None]:
        """
        Set up both WebSocket connections, trigger `WebsocketStartHook` and switch to relaying.

        The extensions listed in the response's `Sec-WebSocket-Extensions` header are applied to
        both connections. Only PerMessageDeflate is supported; other extensions are logged and ignored.
        """
        client_extensions = []
        server_extensions = []

        # Parse extension headers. We only support deflate at the moment and ignore everything else.
        assert self.flow.response  # satisfy type checker
        ext_header = self.flow.response.headers.get("Sec-WebSocket-Extensions", "")
        if ext_header:
            for ext in wsproto.utilities.split_comma_header(
                ext_header.encode("ascii", "replace")
            ):
                # The extension name is everything before the first parameter separator.
                ext_name = ext.split(";", 1)[0].strip()
                if ext_name == wsproto.extensions.PerMessageDeflate.name:
                    # Each connection gets its own extension instance,
                    # both are finalized with the same header value.
                    client_deflate = wsproto.extensions.PerMessageDeflate()
                    client_deflate.finalize(ext)
                    client_extensions.append(client_deflate)
                    server_deflate = wsproto.extensions.PerMessageDeflate()
                    server_deflate.finalize(ext)
                    server_extensions.append(server_deflate)
                else:
                    yield commands.Log(
                        f"Ignoring unknown WebSocket extension {ext_name!r}."
                    )

        # Towards the client we take the server role, towards the server we take the client role.
        self.client_ws = WebsocketConnection(
            wsproto.ConnectionType.SERVER, client_extensions, conn=self.context.client
        )
        self.server_ws = WebsocketConnection(
            wsproto.ConnectionType.CLIENT, server_extensions, conn=self.context.server
        )

        yield WebsocketStartHook(self.flow)

        # All subsequent events are handled by relay_messages.
        self._handle_event = self.relay_messages

    # Initial event handler; start() replaces it with relay_messages once setup is complete.
    _handle_event = start

    @expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
    def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
        """
        Feed `event` into the WebSocket connection of the side it originates from
        and process all WebSocket events that result from it:

         - Messages are buffered until they are complete, recorded in `flow.websocket.messages`,
           passed to `WebsocketMessageHook` and then forwarded unless they were dropped.
         - Pings and pongs are logged and forwarded.
         - A close event is recorded on the flow, both connections are closed,
           `WebsocketEndHook` is triggered and the layer switches to `done`.
        """
        assert self.flow.websocket  # satisfy type checker

        # Determine the originating side and whether the message was injected by the user.
        if isinstance(event, events.ConnectionEvent):
            from_client = event.connection == self.context.client
            injected = False
        elif isinstance(event, WebSocketMessageInjected):
            from_client = event.message.from_client
            injected = True
        else:
            raise AssertionError(f"Unexpected event: {event}")

        from_str = "client" if from_client else "server"
        if from_client:
            src_ws = self.client_ws
            dst_ws = self.server_ws
        else:
            src_ws = self.server_ws
            dst_ws = self.client_ws

        if isinstance(event, events.DataReceived):
            src_ws.receive_data(event.data)
        elif isinstance(event, events.ConnectionClosed):
            # None signals to wsproto that no more data will arrive on this connection.
            src_ws.receive_data(None)
        elif isinstance(event, WebSocketMessageInjected):
            # There are no original fragments to reuse, so the content is split by FRAGMENT_SIZE.
            fragmentizer = Fragmentizer([], event.message.type == Opcode.TEXT)
            # NOTE(review): this writes to a private wsproto attribute, presumably so that the
            # injected message is returned by src_ws.events() below like a received one.
            src_ws._events.extend(fragmentizer(event.message.content))
        else:  # pragma: no cover
            raise AssertionError(f"Unexpected event: {event}")

        for ws_event in src_ws.events():
            if isinstance(ws_event, wsproto.events.Message):
                # The buffer always holds bytes, text payloads are stored in encoded form.
                is_text = isinstance(ws_event.data, str)
                if is_text:
                    typ = Opcode.TEXT
                    src_ws.frame_buf[-1] += ws_event.data.encode()
                else:
                    typ = Opcode.BINARY
                    src_ws.frame_buf[-1] += ws_event.data

                if ws_event.message_finished:
                    content = b"".join(src_ws.frame_buf)

                    # Remember the original fragment sizes, then reset the buffer for the next message.
                    fragmentizer = Fragmentizer(src_ws.frame_buf, is_text)
                    src_ws.frame_buf = [b""]

                    message = websocket.WebSocketMessage(
                        typ, from_client, content, injected=injected
                    )
                    self.flow.websocket.messages.append(message)
                    yield WebsocketMessageHook(self.flow)

                    # The content is read after the hook has run, so that modifications are forwarded.
                    if not message.dropped:
                        for msg in fragmentizer(message.content):
                            yield dst_ws.send2(msg)

                elif ws_event.frame_finished:
                    # The frame is complete but the message is not: start a new buffer entry.
                    src_ws.frame_buf.append(b"")

            elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
                yield commands.Log(
                    f"Received WebSocket {ws_event.__class__.__name__.lower()} from {from_str} "
                    f"(payload: {bytes(ws_event.payload)!r})"
                )
                yield dst_ws.send2(ws_event)
            elif isinstance(ws_event, wsproto.events.CloseConnection):
                self.flow.websocket.timestamp_end = time.time()
                self.flow.websocket.closed_by_client = from_client
                self.flow.websocket.close_code = ws_event.code
                self.flow.websocket.close_reason = ws_event.reason

                # Send the close event on every connection that can still take it,
                # then close the underlying connection in any case.
                for ws in [self.server_ws, self.client_ws]:
                    if ws.state in {
                        ConnectionState.OPEN,
                        ConnectionState.REMOTE_CLOSING,
                    }:
                        # response == original event, so no need to differentiate here.
                        yield ws.send2(ws_event)
                    yield commands.CloseConnection(ws.conn)
                yield WebsocketEndHook(self.flow)
                self.flow.live = False
                # Any event arriving from now on is ignored by done().
                self._handle_event = self.done
            else:  # pragma: no cover
                raise AssertionError(f"Unexpected WebSocket event: {ws_event}")

    @expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
    def done(self, _) -> layer.CommandGenerator[None]:
        """Event handler used after the WebSocket connection has ended; ignores all events."""
        # Empty generator: no commands are emitted.
        yield from ()


class Fragmentizer:
    """
    Theory (RFC 6455):
       Unless specified otherwise by an extension, frames have no semantic
       meaning.  An intermediary might coalesce and/or split frames, [...]

    Practice:
        Some WebSocket servers reject large payload sizes.
        Other WebSocket servers reject CONTINUATION frames.

    As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks.
    If one deals with web servers that do not support CONTINUATION frames, addons need to monkeypatch FRAGMENT_SIZE
    if they need to modify the message.
    """

    # A bit less than 4kb to accommodate for headers.
    FRAGMENT_SIZE = 4000

    def __init__(self, fragments: list[bytes], is_text: bool):
        self.fragment_lengths = [len(x) for x in fragments]
        self.is_text = is_text

    def msg(self, data: bytes, message_finished: bool):
        if self.is_text:
            data_str = data.decode(errors="replace")
            return wsproto.events.TextMessage(
                data_str, message_finished=message_finished
            )
        else:
            return wsproto.events.BytesMessage(data, message_finished=message_finished)

    def __call__(self, content: bytes) -> Iterator[wsproto.events.Message]:
        if len(content) == sum(self.fragment_lengths):
            # message has the same length, we can reuse the same sizes
            offset = 0
            for fl in self.fragment_lengths[:-1]:
                yield self.msg(content[offset : offset + fl], False)
                offset += fl
            yield self.msg(content[offset:], True)
        else:
            offset = 0
            total = len(content) - self.FRAGMENT_SIZE
            while offset < total:
                yield self.msg(content[offset : offset + self.FRAGMENT_SIZE], False)
                offset += self.FRAGMENT_SIZE
            yield self.msg(content[offset:], True)
