"""
An experimental support for curvilinear grid.
"""

import functools
from itertools import chain

import numpy as np

import matplotlib as mpl
from matplotlib import _api
from matplotlib.path import Path
from matplotlib.transforms import Affine2D, IdentityTransform
from .axislines import AxisArtistHelper, GridHelperBase
from .axis_artist import AxisArtist
from .grid_finder import GridFinder


class FixedAxisArtistHelper(AxisArtistHelper.Fixed):
    """
    Helper class for an axis attached to one side ("left", "right",
    "bottom" or "top") of an Axes using a curvilinear grid.

    Tick locations, angles and labels come from the grid helper's
    ``get_tick_iterator`` rather than from the Axes' own locators.
    """

    def __init__(self, grid_helper, side, nth_coord_ticks=None):
        """
        Parameters
        ----------
        grid_helper : GridHelperCurveLinear
            Grid helper providing ``update_lim`` and ``get_tick_iterator``.
        side : {"left", "right", "bottom", "top"}
            Side of the Axes this axis is attached to.
        nth_coord_ticks : {0, 1}, optional
            Grid coordinate (0 = first, 1 = second) used for the major ticks.
            Defaults to ``self.nth_coord``, which is set by the base-class
            initializer.
        """

        super().__init__(loc=side)

        self.grid_helper = grid_helper
        if nth_coord_ticks is None:
            nth_coord_ticks = self.nth_coord
        self.nth_coord_ticks = nth_coord_ticks

        self.side = side

    def update_lim(self, axes):
        """Delegate limit updates to the grid helper."""
        self.grid_helper.update_lim(axes)

    @_api.deprecated("3.5")
    def change_tick_coord(self, coord_number=None):
        """
        Set the grid coordinate used for the major ticks.

        With *coord_number* None, toggle between coordinates 0 and 1.
        Otherwise *coord_number* must be 0 or 1.

        Raises
        ------
        ValueError
            If *coord_number* is not None, 0 or 1.
        """
        if coord_number is None:
            self.nth_coord_ticks = 1 - self.nth_coord_ticks
        elif coord_number in [0, 1]:
            self.nth_coord_ticks = coord_number
        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError(f"wrong coord number: {coord_number!r}")

    def get_tick_transform(self, axes):
        """Tick locations yielded by the grid helper are in data coords."""
        return axes.transData

    def get_tick_iterators(self, axes):
        """
        Return ``(major_ticks, minor_ticks)`` iterators.

        *major_ticks* chains the ticks of coordinate ``nth_coord_ticks``
        (with labels) and the ticks of the other coordinate (with empty
        labels). Each item is ``(xy, angle_normal, angle_tangent, label)``.
        *minor_ticks* is always empty.
        """
        # Limits of the direction perpendicular to this axis: the y-limits
        # for a bottom/top axis (nth_coord == 0), else the x-limits.
        v1, v2 = axes.get_ylim() if self.nth_coord == 0 else axes.get_xlim()
        if v1 > v2:  # Inverted limits: the axis is drawn on the opposite side.
            side = {"left": "right", "right": "left",
                    "top": "bottom", "bottom": "top"}[self.side]
        else:
            side = self.side
        g = self.grid_helper
        ti1 = g.get_tick_iterator(self.nth_coord_ticks, side)
        ti2 = g.get_tick_iterator(1-self.nth_coord_ticks, side, minor=True)
        return chain(ti1, ti2), iter([])


class FloatingAxisArtistHelper(AxisArtistHelper.Floating):
    """
    Helper class for a "floating" axis: the curve along which one grid
    coordinate is held at a constant *value* while the other one varies.

    In this module the first grid coordinate is called "lon" and the
    second "lat".
    """
    grid_info = _api.deprecate_privatize_attribute("3.5")

    def __init__(self, grid_helper, nth_coord, value, axis_direction=None):
        """
        Parameters
        ----------
        grid_helper : GridHelperCurveLinear
            Grid helper whose ``grid_finder`` maps grid (curved) coordinates
            to data coordinates.
        nth_coord : {0, 1}
            Index of the grid coordinate held fixed at *value*: 0 fixes the
            first coordinate (the axis runs along the second one), 1 fixes
            the second coordinate (the axis runs along the first one).
        value : float
            Value of the fixed grid coordinate.
        axis_direction : optional
            Currently unused.
        """
        super().__init__(nth_coord, value)
        self.value = value
        self.grid_helper = grid_helper
        # Range of the *varying* coordinate; unbounded by default.
        self._extremes = -np.inf, np.inf
        self._line_num_points = 100  # number of points to create a line

    def set_extremes(self, e1, e2):
        """
        Restrict the varying coordinate to ``[e1, e2]``.

        None means unbounded on that side.
        """
        if e1 is None:
            e1 = -np.inf
        if e2 is None:
            e2 = np.inf
        self._extremes = e1, e2

    def update_lim(self, axes):
        """
        Recompute the cached grid info for the current view limits of *axes*.

        The cache holds the grid extremes (clipped by `set_extremes`), the
        tick levels and labels of both coordinates, and the axis line
        sampled at ``_line_num_points`` points in data coordinates.
        """
        self.grid_helper.update_lim(axes)

        x1, x2 = axes.get_xlim()
        y1, y2 = axes.get_ylim()
        grid_finder = self.grid_helper.grid_finder
        extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
                                              x1, y1, x2, y2)

        lon_min, lon_max, lat_min, lat_max = extremes
        e_min, e_max = self._extremes  # ranges of other coordinates
        if self.nth_coord == 0:
            lat_min = max(e_min, lat_min)
            lat_max = min(e_max, lat_max)
        elif self.nth_coord == 1:
            lon_min = max(e_min, lon_min)
            lon_max = min(e_max, lon_max)

        lon_levs, lon_n, lon_factor = \
            grid_finder.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = \
            grid_finder.grid_locator2(lat_min, lat_max)

        if self.nth_coord == 0:
            xx0 = np.full(self._line_num_points, self.value)
            yy0 = np.linspace(lat_min, lat_max, self._line_num_points)
            xx, yy = grid_finder.transform_xy(xx0, yy0)
        elif self.nth_coord == 1:
            xx0 = np.linspace(lon_min, lon_max, self._line_num_points)
            yy0 = np.full(self._line_num_points, self.value)
            xx, yy = grid_finder.transform_xy(xx0, yy0)

        self._grid_info = {
            "extremes": (lon_min, lon_max, lat_min, lat_max),
            "lon_info": (lon_levs, lon_n, np.asarray(lon_factor)),
            "lat_info": (lat_levs, lat_n, np.asarray(lat_factor)),
            "lon_labels": grid_finder.tick_formatter1(
                "bottom", lon_factor, lon_levs),
            "lat_labels": grid_finder.tick_formatter2(
                "bottom", lat_factor, lat_levs),
            "line_xy": (xx, yy),
        }

    def get_axislabel_transform(self, axes):
        """Label positions are already in display coordinates."""
        return Affine2D()  # axes.transData

    def get_axislabel_pos_angle(self, axes):
        """
        Return ``(xy, angle)`` for the axis label, or ``(None, None)``.

        The label sits at the midpoint of the varying coordinate's range.
        *xy* is in display coordinates, and *angle* is the direction of the
        axis there, in degrees. ``(None, None)`` is returned when that point
        lies outside the Axes. Requires `update_lim` to have been called.
        """

        extremes = self._grid_info["extremes"]

        if self.nth_coord == 0:
            xx0 = self.value
            yy0 = (extremes[2] + extremes[3]) / 2
            dxx = 0
            dyy = abs(extremes[2] - extremes[3]) / 1000
        elif self.nth_coord == 1:
            xx0 = (extremes[0] + extremes[1]) / 2
            yy0 = self.value
            dxx = abs(extremes[0] - extremes[1]) / 1000
            dyy = 0

        grid_finder = self.grid_helper.grid_finder
        (xx1,), (yy1,) = grid_finder.transform_xy([xx0], [yy0])

        data_to_axes = axes.transData - axes.transAxes
        p = data_to_axes.transform([xx1, yy1])

        if 0 <= p[0] <= 1 and 0 <= p[1] <= 1:
            xx1c, yy1c = axes.transData.transform([xx1, yy1])
            (xx2,), (yy2,) = grid_finder.transform_xy([xx0 + dxx], [yy0 + dyy])
            xx2c, yy2c = axes.transData.transform([xx2, yy2])
            return (xx1c, yy1c), np.rad2deg(np.arctan2(yy2c-yy1c, xx2c-xx1c))
        else:
            return None, None

    def get_tick_transform(self, axes):
        """Tick positions are already in display coordinates."""
        return IdentityTransform()  # axes.transData

    def get_tick_iterators(self, axes):
        """
        Return ``(major_ticks, minor_ticks)`` iterators.

        *major_ticks* lazily yields ``[x, y], angle_normal, angle_tangent,
        label`` for each tick of the varying coordinate that lies within the
        extremes and inside the Axes. Positions are in display coordinates
        and angles are in degrees. *minor_ticks* is always empty.
        Requires `update_lim` to have been called.
        """

        grid_finder = self.grid_helper.grid_finder

        lat_levs, lat_n, lat_factor = self._grid_info["lat_info"]
        yy0 = lat_levs / lat_factor
        dy = 0.01 / lat_factor

        lon_levs, lon_n, lon_factor = self._grid_info["lon_info"]
        xx0 = lon_levs / lon_factor
        dx = 0.01 / lon_factor

        # Range of the varying coordinate (see set_extremes).
        e0, e1 = self._extremes

        # Grid (curved) coordinates -> display coordinates.
        trf = grid_finder.get_transform() + axes.transData

        def transform_xy(x, y):
            return trf.transform(np.column_stack([x, y])).T

        # Angles are estimated by finite differences. "Normal" steps the
        # fixed coordinate, which is unbounded here. "Tangent" steps the
        # varying coordinate: where a forward step would leave the extremes,
        # the sample pair is moved back by one step so it stays in range.
        if self.nth_coord == 0:
            mask = (e0 <= yy0) & (yy0 <= e1)
            yy0 = yy0[mask]
            xx0 = np.full_like(yy0, self.value)

            xx1, yy1 = transform_xy(xx0, yy0)

            xx1a, yy1a = xx1, yy1
            xx1b, yy1b = transform_xy(xx0 + dx, yy0)

            # The varying coordinate is the second one, so guard yy (not xx)
            # against the upper extreme.
            yy00 = yy0.copy()
            yy00[yy0 + dy > e1] -= dy
            xx2a, yy2a = transform_xy(xx0, yy00)
            xx2b, yy2b = transform_xy(xx0, yy00 + dy)

            labels = self._grid_info["lat_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        elif self.nth_coord == 1:
            mask = (e0 <= xx0) & (xx0 <= e1)
            xx0 = xx0[mask]
            yy0 = np.full_like(xx0, self.value)

            xx1, yy1 = transform_xy(xx0, yy0)

            xx1a, yy1a = xx1, yy1
            xx1b, yy1b = transform_xy(xx0, yy0 + dy)

            xx00 = xx0.copy()
            xx00[xx0 + dx > e1] -= dx
            xx2a, yy2a = transform_xy(xx00, yy0)
            xx2b, yy2b = transform_xy(xx00 + dx, yy0)

            labels = self._grid_info["lon_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        def f1():
            dd = np.arctan2(yy1b-yy1a, xx1b-xx1a)  # angle normal
            dd2 = np.arctan2(yy2b-yy2a, xx2b-xx2a)  # angle tangent
            mm = (yy1b == yy1a) & (xx1b == xx1a)  # mask where dd not defined
            dd[mm] = dd2[mm] + np.pi / 2

            tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
            in_01 = functools.partial(
                mpl.transforms._interval_contains_close, (0, 1))
            for x, y, d, d2, lab in zip(xx1, yy1, dd, dd2, labels):
                c2 = tick_to_axes.transform((x, y))
                if in_01(c2[0]) and in_01(c2[1]):
                    d1, d2 = np.rad2deg([d, d2])
                    yield [x, y], d1, d2, lab

        return f1(), iter([])

    def get_line_transform(self, axes):
        """The axis line is stored in data coordinates."""
        return axes.transData

    def get_line(self, axes):
        """Refresh the grid info and return the axis line as a `Path`."""
        self.update_lim(axes)
        x, y = self._grid_info["line_xy"]
        return Path(np.column_stack([x, y]))


class GridHelperCurveLinear(GridHelperBase):
    """
    Grid helper for Axes whose grid lines are curves in data space.

    Grid lines, tick locations and tick labels are computed by a
    `GridFinder` built from *aux_trans* and the optional extreme-finder,
    locator and formatter arguments.
    """
    grid_info = _api.deprecate_privatize_attribute("3.5")

    def __init__(self, aux_trans,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        Parameters
        ----------
        aux_trans : `~matplotlib.transforms.Transform` or tuple of 2 callables
            Transform from the source (curved) coordinates to the target
            (rectilinear) coordinates. It is either a Matplotlib Transform
            with a defined inverse, or a ``(transform, inverse)`` pair of
            callables. Each callable takes two arrays of source coordinates
            and returns two target coordinates, e.g.
            ``x2, y2 = trans(x1, y1)``.
        extreme_finder, grid_locator1, grid_locator2, tick_formatter1, \
tick_formatter2
            Passed unchanged to `GridFinder`.
        """
        super().__init__()
        self._grid_info = None
        self._aux_trans = aux_trans
        self.grid_finder = GridFinder(
            aux_trans, extreme_finder, grid_locator1, grid_locator2,
            tick_formatter1, tick_formatter2)

    def update_grid_finder(self, aux_trans=None, **kwargs):
        """
        Update the transform and/or parameters of the grid finder.

        The cached ``_old_limits`` are cleared to force revalidation.
        """
        finder = self.grid_finder
        if aux_trans is not None:
            finder.update_transform(aux_trans)
        finder.update(**kwargs)
        self._old_limits = None  # Force revalidation.

    def new_fixed_axis(self, loc,
                       nth_coord=None,
                       axis_direction=None,
                       offset=None,
                       axes=None):
        """
        Create an `AxisArtist` attached to side *loc* of *axes*.

        *nth_coord* selects the grid coordinate used for major ticks.
        *axis_direction* defaults to *loc*. *offset* is ignored.
        """
        axes = self.axes if axes is None else axes
        direction = loc if axis_direction is None else axis_direction
        helper = FixedAxisArtistHelper(self, loc, nth_coord_ticks=nth_coord)
        # NOTE(review): unlike new_floating_axis (and the
        # floating_axes.GridHelperCurveLinear subclass), no clipping is set
        # on the returned artist; unclear whether that is intentional.
        return AxisArtist(axes, helper, axis_direction=direction)

    def new_floating_axis(self, nth_coord,
                          value,
                          axes=None,
                          axis_direction="bottom"
                          ):
        """
        Create an `AxisArtist` along the curve where grid coordinate
        *nth_coord* equals *value*.

        The artist's line is clipped to the Axes bounding box.
        """
        if axes is None:
            axes = self.axes

        helper = FloatingAxisArtistHelper(
            self, nth_coord, value, axis_direction)
        artist = AxisArtist(axes, helper)

        line = artist.line
        line.set_clip_on(True)
        line.set_clip_box(artist.axes.bbox)
        return artist

    def _update_grid(self, x1, y1, x2, y2):
        # Recompute grid lines / ticks / labels for the given data bounds.
        self._grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2)

    def get_gridlines(self, which="major", axis="both"):
        """
        Return the grid lines as a flat list.

        *axis* selects the lines of the first grid coordinate ("x"), the
        second ("y"), or both ("both"). *which* is currently ignored.
        """
        keys = []
        if axis in ("both", "x"):
            keys.append("lon")
        if axis in ("both", "y"):
            keys.append("lat")
        return [line
                for key in keys
                for group in self._grid_info[key]["lines"]
                for line in group]

    def get_tick_iterator(self, nth_coord, axis_side, minor=False):
        """
        Yield ``(xy, angle_normal, angle_tangent, label)`` for each tick of
        grid coordinate *nth_coord* on *axis_side* of the Axes.

        *angle_tangent* is 0 for "bottom"/"top" and 90 for "left"/"right".
        When *minor* is true, every label is the empty string.
        """
        angle_tangent = {"left": 90, "right": 90,
                         "bottom": 0, "top": 0}[axis_side]
        info = self._grid_info[("lon", "lat")[nth_coord]]
        pairs = zip(info["tick_locs"][axis_side],
                    info["tick_labels"][axis_side])
        for (xy, angle_normal), label in pairs:
            yield xy, angle_normal, angle_tangent, ("" if minor else label)