import numpy as np
import matplotlib as mpl
import matplotlib.pylab as plt
import matplotlib.animation as animation

def plot_circle(x, y, r):
    """
    Draw the outline of a circle on the current pyplot axes.

    Params:
        x, y: coordinates of the circle centre
        r: circle radius

    The outline is approximated by 100 points sampled uniformly in angle
    over [0, 2*pi] and drawn in "steelblue".
    """
    angles = np.linspace(0, 2*np.pi, 100)
    outline_x, outline_y = x + r * np.sin(angles), y + r * np.cos(angles)
    plt.plot(outline_x, outline_y, color="steelblue")


def plot_plate(X, Y, R, a):
    """
    Show a square plate of side length a together with a set of circles.

    Params:
        X, Y: x- and y-coordinates of the circle centres
        R: radii of the circles
        a: side length of the square plate, whose lower-left corner is the origin

    Each circle is drawn with its centre marked by an "x" and labelled
    $s_1$, $s_2$, ... (1-based) slightly up and to the right of the centre.
    The figure is opened at 6x6 inches and displayed with plt.show().
    """
    plt.figure(figsize=(6, 6))

    # outline of the plate: (0,0) -> (0,a) -> (a,a) -> (a,0) -> back to (0,0)
    border_x = [0, 0, a, a, 0]
    border_y = [0, a, a, 0, 0]
    plt.plot(border_x, border_y, 'k-')

    for number, (cx, cy, radius) in enumerate(zip(X, Y, R), start=1):
        plot_circle(cx, cy, radius)
        plt.plot(cx, cy, color="steelblue", marker="x")
        plt.text(cx + 0.2, cy + 0.2, f"$s_{number}$")

    plt.grid(True)
    plt.xlabel('$x$')
    plt.ylabel("$y$")
    plt.axis("equal")
    plt.show()


def animate_pendulum(Theta, dt=0.03):
    '''
    Create an animation of a unit-length pendulum pivoting at the origin.

    Params:
        Theta: trajectory of the pendulum angle in radians, one entry per frame.
               Any array-like (list, tuple, 1-D or (N, 1) array) is accepted.
               An angle of 0 draws the pendulum pointing straight up; positive
               angles move the tip towards negative x.
        dt: time gap (in seconds) between two successive entries of Theta.

    Returns:
        The FuncAnimation object. plt.show() is called before returning.
    '''
    # Normalise the input so plain lists/tuples (which have no ``.size``) and
    # column vectors work; each frame then indexes a scalar angle.
    theta = np.asarray(Theta, dtype=float).reshape(-1)

    # Tip positions of the unit-length rod for every frame, computed once.
    tip_x = -np.sin(theta)
    tip_y = np.cos(theta)

    fig = plt.figure()
    ax = fig.add_subplot(111, autoscale_on=False, xlim=(-1.2, 1.2), ylim=(-1.2, 1.2))
    ax.set_aspect('equal')
    ax.axis('off')

    # empty line artist that every frame updates (pivot and tip as markers)
    line, = ax.plot([], [], 'o-', lw=2)

    def init():
        # start from an empty rod; required by blitting
        line.set_data([], [])
        return line,

    def animate(i):
        # rod from the pivot (0, 0) to the tip of the i-th frame
        line.set_data([0, tip_x[i]], [0, tip_y[i]])
        return line,

    ani = animation.FuncAnimation(fig, animate, theta.size,
                                  interval=dt*1000, repeat_delay=500,
                                  blit=True, init_func=init)
    plt.show()
    return ani


def plot_cartpole(t_grid, u_max, U, X, saveas=None, plt_show=False):
    """
    Plot the control input and the state trajectories on the current figure.

    The figure is split into nx+1 stacked subplots: the control u on top
    (as a piecewise-constant step plot with dashed lines at +-u_max), then
    one subplot per state column of X. Only the bottom subplot shows time
    tick labels.

    Params:
        t_grid: time values of the discretization, length N_sim
        u_max: maximum absolute value of u
        U: array with shape (N_sim-1, nu) or (N_sim, nu); it is flattened
           into a single control signal, so presumably nu == 1
        X: array with shape (N_sim, nx)
        saveas: filename to save the plot
        plt_show: whether to execute plt.show() before returning
    """
    t_grid = np.asarray(t_grid, dtype=float)
    X = np.asarray(X)
    u = np.asarray(U, dtype=float).reshape(-1)
    nx = X.shape[1]
    t_start, t_end = t_grid[0], t_grid[-1]

    # --- control input ---
    plt.subplot(nx + 1, 1, 1)
    if u.size == t_grid.size - 1:
        # One value per interval: prepend the first value so that, with the
        # default 'pre' steps, u[k] is shown on (t_k, t_{k+1}].
        line, = plt.step(t_grid, np.append(u[0], u))
    else:
        # One value per grid point: hold u[k] on [t_k, t_{k+1}). Prepending
        # here would give N_sim+1 values for N_sim time points.
        line, = plt.step(t_grid, u, where='post')
    line.set_color('r')
    plt.ylabel('$u$')
    plt.hlines([u_max, -u_max], t_start, t_end, linestyles='dashed', alpha=0.7)
    plt.ylim([-1.2*u_max, 1.2*u_max])
    plt.xlim([t_start, t_end])
    plt.gca().set_xticklabels([])
    plt.grid()

    # --- states ---
    # raw strings: '\m' is not a valid Python escape sequence
    state_labels = [r'$p_\mathrm{x}$', r'$\theta$', r'$v_\mathrm{x}$', r'$\omega$']

    for i in range(nx):
        plt.subplot(nx + 1, 1, i + 2)
        plt.plot(t_grid, X[:, i])
        # generic label for systems with more states than named above
        plt.ylabel(state_labels[i] if i < len(state_labels) else f'$x_{{{i + 1}}}$')
        plt.xlim([t_start, t_end])
        if i < nx - 1:
            # only the bottom subplot keeps its time tick labels
            plt.gca().set_xticklabels([])
        plt.grid()

    plt.xlabel(r'time $t$ in s')

    plt.subplots_adjust(hspace=0.4)

    if saveas is not None:
        plt.savefig(saveas)
    if plt_show:
        plt.show()


def animate_cartpole(Xtraj, dt=0.03, saveas=None, plt_show=False):
    '''
    Create an animation of the cart-pole system.

    Params:
        Xtraj: matrix of size N x 4, where N is the number of time steps,
               with the columns px, theta, v, omega. Only px (cart position)
               and theta (pole angle, 0 = pointing up) are used for drawing.
        dt: time gap (in seconds) between two successive frames
        saveas: if given, filename the animation is saved to (dpi=100)
        plt_show: whether to execute plt.show() before returning

    Returns:
        The FuncAnimation object; it repeats with a 500 ms pause.
    '''
    n_frames = Xtraj.shape[0]
    pendulum_length = .8
    cart_width = .2
    cart_height = .1

    cart_x = Xtraj[:, 0]
    angle = Xtraj[:, 1]

    # pole tip for every frame, relative to the cart on the line y = 0
    tip_x = cart_x - pendulum_length * np.sin(angle)
    tip_y = pendulum_length * np.cos(angle)

    # horizontal view spans both the cart and the tip, plus a fixed margin
    margin = 5 * cart_width / 2
    xmin = min(np.min(cart_x), np.min(tip_x)) - margin
    xmax = max(np.max(cart_x), np.max(tip_x)) + margin

    fig, ax = plt.subplots()

    def draw_frame(k):
        # each frame redraws the whole scene from scratch
        ax.clear()

        # dashed rail the cart moves along
        ax.plot([xmin, xmax], [0, 0], 'k--')

        # cart body centred on (cart_x[k], 0)
        corner = (cart_x[k] - cart_width / 2, -cart_height / 2)
        ax.add_patch(mpl.patches.Rectangle(corner, cart_width, cart_height, facecolor='C0'))

        # pole from the cart centre to its tip
        ax.add_line(mpl.lines.Line2D([cart_x[k], tip_x[k]], [0, tip_y[k]], color='k', linewidth=2))

        # path of the pole tip over the frames before k
        ax.plot(tip_x[:k], tip_y[:k], color='lightgray', linewidth=1)

        ax.set_xlim([xmin, xmax])
        ax.set_ylim([-1.2, 1.2])
        ax.set_aspect('equal')

    ani = animation.FuncAnimation(fig, draw_frame, n_frames, interval=dt * 1000,
                                  repeat_delay=500, repeat=True)

    if saveas is not None:
        ani.save(saveas, dpi=100)
    if plt_show:
        plt.show()

    return ani
