import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
from acados_template import latexify_plot


def plot_cartpole(t, u_max, U, X_true, latexify=False, plt_show=False, time_label='$t$', x_labels=None, u_labels=None, saveas=None):
    """
    Plot the state trajectories and the control trajectory of the cart-pole system.

    One subplot is drawn per state component, plus a bottom subplot with the
    piecewise-constant control(s) and dashed lines at +/- u_max.

    Params:
        t: time values of the discretization, length N
        u_max: maximum absolute value of u (drawn as dashed bounds, y-limits are +/- 1.2*u_max)
        U: array with shape (N-1, nu) or (N, nu); a 1-D array of length N-1 or N is also accepted.
           Only the first N-1 rows are plotted, since the last one has no interval to act on.
        X_true: array with shape (N, nx)
        latexify: latex style plots
        plt_show: if True, call plt.show() at the end
        time_label: label of the time axis
        x_labels: optional sequence of nx labels for the state subplots (default: '$x_i$')
        u_labels: optional sequence of control labels; only the first entry is used
        saveas: optional file path the figure is saved to
    """

    if latexify:
        latexify_plot()

    t = np.asarray(t)
    nx = X_true.shape[1]
    fig, axes = plt.subplots(nx+1, 1, sharex=True)

    for i in range(nx):
        axes[i].plot(t, X_true[:, i])
        axes[i].grid()
        if x_labels is not None:
            axes[i].set_ylabel(x_labels[i])
        else:
            axes[i].set_ylabel(f'$x_{i}$')

    # Keep U two-dimensional (rows = time steps, columns = controls) so that
    # several controls are not flattened into one series, and keep only the
    # first N-1 rows: u_k acts on the interval between t_k and t_{k+1}.
    U = np.asarray(U)
    U = U.reshape(U.shape[0], -1)[:len(t) - 1]
    # With the default where='pre', step draws y[k+1] on (t_k, t_{k+1}];
    # prepending u_0 aligns the control values with their intervals and
    # makes the number of rows equal len(t).
    u_plot = np.concatenate([U[:1], U], axis=0)
    axes[-1].step(t, u_plot)

    if u_labels is not None:
        axes[-1].set_ylabel(u_labels[0])
    else:
        axes[-1].set_ylabel('$u$')

    axes[-1].hlines(u_max, t[0], t[-1], linestyles='dashed', alpha=0.7)
    axes[-1].hlines(-u_max, t[0], t[-1], linestyles='dashed', alpha=0.7)
    axes[-1].set_ylim([-1.2*u_max, 1.2*u_max])
    axes[-1].set_xlim(t[0], t[-1])
    axes[-1].set_xlabel(time_label)
    axes[-1].grid()

    # Operate on this figure explicitly rather than on pyplot's "current" figure.
    fig.subplots_adjust(left=None, bottom=None, right=None, top=None, hspace=0.4)

    fig.align_ylabels()

    if saveas is not None:
        fig.savefig(saveas)

    if plt_show:
        plt.show()


def animate_cart_pole(Xtraj, dt=0.05, plt_show=False, saveas=None):
    '''
    Create an animation of the cart-pole system.

    Xtraj is a matrix of size N x 4, where N is the number of time steps, with the following columns:
    x1, theta, v1, omega
    Only x1 (cart position) and theta (pole angle) are used for drawing.

    dt defines the time gap (in seconds) between two successive images.
    plt_show: if True, call plt.show() after creating the animation.
    saveas: optional file path the animation is written to (FuncAnimation.save with dpi=100).

    Returns the FuncAnimation object. Keep a reference to it; matplotlib stops
    an animation whose object has been garbage collected.
    '''

    N = Xtraj.shape[0]
    pendulum_length = 1.0
    cart_width = .2
    cart_height = .1

    x1 = Xtraj[:, 0]
    theta = Xtraj[:, 1]

    # x and y position of the tip of the pendulum (theta = 0 points straight up)
    pendu_tip_x = x1 - pendulum_length * np.sin(theta)
    pendu_tip_y = 0 + pendulum_length * np.cos(theta)

    # horizontal view: everything the cart and tip visit, plus some margin
    xmin = min(np.min(x1), np.min(pendu_tip_x)) - 5 * cart_width / 2
    xmax = max(np.max(x1), np.max(pendu_tip_x)) + 5 * cart_width / 2
    # vertical view: the tip never leaves [-pendulum_length, pendulum_length]
    ylim = 1.2 * pendulum_length

    fig, ax = plt.subplots()

    def animate(i):
        """Redraw frame i from scratch."""
        ax.clear()

        # level of cart
        ax.plot([xmin, xmax], [0, 0], 'k--')

        # draw rectangle for cart, centered at (x1[i], 0)
        cart = mpl.patches.Rectangle((x1[i] - cart_width / 2, 0 - cart_height / 2),
                                     cart_width,
                                     cart_height,
                                     facecolor='C0')
        ax.add_patch(cart)

        # draw line for pendulum
        pendu = mpl.lines.Line2D([x1[i], pendu_tip_x[i]], [0, pendu_tip_y[i]],
                                 color='k',
                                 linewidth=2)
        ax.add_line(pendu)

        # trace of pendulum tip up to and including the current frame
        ax.plot(pendu_tip_x[:i + 1], pendu_tip_y[:i + 1], color='lightgray', linewidth=1)

        ax.set_xlim([xmin, xmax])
        ax.set_ylim([-ylim, ylim])
        ax.set_aspect('equal')

    ani = FuncAnimation(fig, animate, N, interval=dt * 1000, repeat_delay=500, repeat=True)
    if saveas is not None:
        ani.save(saveas, dpi=100)

    if plt_show:
        plt.show()

    return ani

