from __future__ import annotations

import asyncio
from datetime import timedelta
from json.decoder import JSONDecodeError
from typing import (
    TYPE_CHECKING,
    TypeGuard,
)

from aiohttp import ClientConnectionResetError, http, web
from cross_web import AiohttpHTTPRequestAdapter, HTTPException

from strawberry.http.async_base_view import (
    AsyncBaseHTTPView,
    AsyncWebSocketAdapter,
)
from strawberry.http.exceptions import (
    NonJsonMessageReceived,
    NonTextMessageReceived,
    WebSocketDisconnected,
)
from strawberry.http.typevars import (
    Context,
    RootValue,
)
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Callable, Mapping, Sequence

    from strawberry.http import GraphQLHTTPResponse
    from strawberry.http.ides import GraphQL_IDE
    from strawberry.schema import BaseSchema


class AiohttpWebSocketAdapter(AsyncWebSocketAdapter):
    """Adapt an aiohttp ``web.WebSocketResponse`` to strawberry's WebSocket API.

    JSON encoding/decoding is delegated to the owning view (``view.encode_json``
    / ``view.decode_json``) so the view's JSON customisations apply to
    WebSocket traffic as well.
    """

    def __init__(
        self, view: AsyncBaseHTTPView, request: web.Request, ws: web.WebSocketResponse
    ) -> None:
        super().__init__(view)
        self.request = request
        self.ws = ws

    async def iter_json(
        self, *, ignore_parsing_errors: bool = False
    ) -> AsyncGenerator[object, None]:
        """Yield every incoming text frame decoded as JSON.

        Args:
            ignore_parsing_errors: when True, text frames that are not valid
                JSON are skipped instead of aborting the iteration.

        Raises:
            NonJsonMessageReceived: a text frame failed to decode and
                ``ignore_parsing_errors`` is False.
            NonTextMessageReceived: a binary frame was received.
        """
        async for ws_message in self.ws:
            if ws_message.type == http.WSMsgType.TEXT:
                try:
                    decoded = self.view.decode_json(ws_message.data)
                except JSONDecodeError as e:
                    if not ignore_parsing_errors:
                        raise NonJsonMessageReceived from e
                    continue
                # Yield outside the ``try`` so that only decoding is covered by
                # the parse-error handler, not whatever the consumer does.
                yield decoded

            elif ws_message.type == http.WSMsgType.BINARY:
                raise NonTextMessageReceived
            # Any other message type is skipped.

    async def send_json(self, message: Mapping[str, object]) -> None:
        """Encode *message* with the view's encoder and send it.

        ``bytes`` output is sent with ``send_bytes``, ``str`` output with
        ``send_str``.

        Raises:
            WebSocketDisconnected: sending failed with ``RuntimeError`` or
                ``ClientConnectionResetError``, treated here as the peer
                having gone away.
        """
        # Encode before entering the ``try``: ``RecursionError`` (raised when
        # encoding deeply nested payloads) subclasses ``RuntimeError`` and
        # would otherwise be misreported as a disconnect.
        encoded_data = self.view.encode_json(message)
        try:
            if isinstance(encoded_data, bytes):
                await self.ws.send_bytes(encoded_data)
            else:
                await self.ws.send_str(encoded_data)
        except (RuntimeError, ClientConnectionResetError) as exc:
            raise WebSocketDisconnected from exc

    async def close(self, code: int, reason: str) -> None:
        """Close the socket with *code*, sending *reason* UTF-8 encoded."""
        await self.ws.close(code=code, message=reason.encode())


class GraphQLView(
    AsyncBaseHTTPView[
        web.Request,
        web.Response | web.StreamResponse,
        web.Response,
        web.Request,
        web.WebSocketResponse,
        Context,
        RootValue,
    ]
):
    """aiohttp request handler serving a strawberry GraphQL schema.

    An instance is itself the aiohttp handler (see ``__call__``), which
    delegates to the base view's ``run``; the methods below adapt aiohttp's
    request/response objects to the base view's hooks.
    """

    # Mark the view as coroutine so that AIOHTTP does not confuse it with a deprecated
    # bare handler function.
    _is_coroutine = asyncio.coroutines._is_coroutine  # type: ignore[attr-defined]

    allow_queries_via_get = True
    request_adapter_class = AiohttpHTTPRequestAdapter
    websocket_adapter_class = AiohttpWebSocketAdapter  # type: ignore

    def __init__(
        self,
        schema: BaseSchema,
        graphql_ide: GraphQL_IDE | None = "graphiql",
        allow_queries_via_get: bool = True,
        keep_alive: bool = True,
        keep_alive_interval: float = 1,
        subscription_protocols: Sequence[str] = (
            GRAPHQL_TRANSPORT_WS_PROTOCOL,
            GRAPHQL_WS_PROTOCOL,
        ),
        connection_init_wait_timeout: timedelta = timedelta(minutes=1),
        multipart_uploads_enabled: bool = False,
    ) -> None:
        """Store the schema and view configuration on the instance.

        Args:
            schema: the strawberry schema to execute.
            graphql_ide: which IDE to render, or ``None`` to disable it.
            allow_queries_via_get: whether queries may be sent with GET.
            keep_alive: keep-alive flag for WebSocket subscriptions.
            keep_alive_interval: keep-alive interval for WebSocket subscriptions.
            subscription_protocols: WebSocket subprotocols offered to clients.
            connection_init_wait_timeout: how long to wait for connection init.
            multipart_uploads_enabled: whether multipart uploads are accepted.
        """
        self.schema = schema
        self.allow_queries_via_get = allow_queries_via_get
        self.keep_alive = keep_alive
        self.keep_alive_interval = keep_alive_interval
        self.subscription_protocols = subscription_protocols
        self.connection_init_wait_timeout = connection_init_wait_timeout
        self.multipart_uploads_enabled = multipart_uploads_enabled
        self.graphql_ide = graphql_ide

    async def render_graphql_ide(self, request: web.Request) -> web.Response:
        """Return the IDE page (``graphql_ide_html``, from the base view) as HTML."""
        return web.Response(text=self.graphql_ide_html, content_type="text/html")

    async def get_sub_response(self, request: web.Request) -> web.Response:
        """Create the empty response that resolvers may customise."""
        return web.Response()

    def is_websocket_request(self, request: web.Request) -> TypeGuard[web.Request]:
        """Return whether aiohttp reports *request* as a preparable WebSocket."""
        ws = web.WebSocketResponse(protocols=self.subscription_protocols)
        return ws.can_prepare(request).ok

    async def pick_websocket_subprotocol(self, request: web.Request) -> str | None:
        """Return the subprotocol aiohttp negotiates from ``subscription_protocols``."""
        ws = web.WebSocketResponse(protocols=self.subscription_protocols)
        return ws.can_prepare(request).protocol

    async def create_websocket_response(
        self, request: web.Request, subprotocol: str | None
    ) -> web.WebSocketResponse:
        """Complete the WebSocket handshake offering only *subprotocol* (if any)."""
        protocols = [subprotocol] if subprotocol else []
        ws = web.WebSocketResponse(protocols=protocols)
        await ws.prepare(request)
        return ws

    async def __call__(self, request: web.Request) -> web.StreamResponse:
        """aiohttp entry point: run the view, mapping ``HTTPException`` to a response."""
        try:
            return await self.run(request=request)
        except HTTPException as e:
            return web.Response(
                body=e.reason,
                status=e.status_code,
            )

    async def get_root_value(self, request: web.Request) -> RootValue | None:
        """Return the root value for execution; ``None`` unless overridden."""
        return None

    async def get_context(
        self, request: web.Request, response: web.Response | web.WebSocketResponse
    ) -> Context:
        """Return the default context: a dict with ``request`` and ``response``."""
        return {"request": request, "response": response}  # type: ignore

    def create_response(
        self,
        response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse],
        sub_response: web.Response,
    ) -> web.Response:
        """Write *response_data* as JSON into *sub_response* and return it."""
        encoded_data = self.encode_json(response_data)
        # ``text`` requires ``str``; custom encoders may return ``bytes``.
        if isinstance(encoded_data, bytes):
            encoded_data = encoded_data.decode()
        sub_response.text = encoded_data
        sub_response.content_type = "application/json"

        return sub_response

    async def create_streaming_response(
        self,
        request: web.Request,
        stream: Callable[[], AsyncGenerator[str, None]],
        sub_response: web.Response,
        headers: dict[str, str],
    ) -> web.StreamResponse:
        """Send the chunks produced by *stream* to the client as they arrive.

        Status, headers and cookies are taken from *sub_response*; entries in
        *headers* replace same-named headers (compared case-insensitively).
        Each chunk is UTF-8 encoded before writing.
        """
        response = web.StreamResponse(
            status=sub_response.status,
            headers=sub_response.headers,
        )
        # ``response.headers`` is a case-insensitive multidict, so ``update``
        # replaces e.g. a user-set "content-type" instead of duplicating it.
        response.headers.update(headers)
        # aiohttp keeps cookies apart from ``headers`` and only turns them
        # into Set-Cookie headers in ``prepare()``; copy them or they are lost.
        for name, morsel in sub_response.cookies.items():
            response.cookies[name] = morsel

        await response.prepare(request)

        async for data in stream():
            await response.write(data.encode())

        await response.write_eof()

        return response


# Public API of this module: only the view is exported. The WebSocket adapter
# is wired in through ``GraphQLView.websocket_adapter_class`` instead.
__all__ = ["GraphQLView"]
