import numpy as np
import matplotlib.animation as animation
import matplotlib.pylab as plt


# # If you cannot see an interactive plot (e.g. the backend is 'agg'), uncomment the following lines
# import matplotlib
# if matplotlib.get_backend() == 'agg':
#     matplotlib.use('WebAgg')
# print(f'backend: {matplotlib.get_backend()}')


def plot_dynamic_programming(
    x1_values, x2_values, u_values, LQR_cost, LQR_u, J_maps, u_maps
):
    """Animate the dynamic-programming iterations next to the LQR solution.

    Left panel: cost-to-go surfaces ``J_maps[k]`` drawn over an LQR cost
    wireframe. Right panel: feedback-control surfaces ``u_maps[k]`` drawn
    over an LQR control wireframe. Frames run from k = N down to k = 0,
    where N = ``u_maps.shape[0]``. The first frame (k = N) has no control
    surface.

    Parameters
    ----------
    x1_values, x2_values : array_like
        Grid coordinates along the theta (x1) and omega (x2) axes.
    u_values : array_like
        Control values; their min/max fix the control z-range and the
        control colour scale.
    LQR_cost, LQR_u : array_like
        LQR cost-to-go and control on the grid. They must have shape
        (len(x1_values), len(x2_values)) to match the plotted grid.
    J_maps : array_like
        Cost-to-go per stage (at least N+1 entries), each entry shaped like
        ``LQR_cost``. Infinite entries are masked so they do not distort
        the colour scale.
    u_maps : array_like
        Control per stage (N entries), each shaped like ``LQR_u``. NaN
        entries are masked.

    Returns
    -------
    matplotlib.animation.ArtistAnimation
        The animation. As a side effect, the figure is also saved to
        'DP_iterations.pdf' in the current working directory.
    """
    # 'ij' indexing yields grids of shape (len(x1), len(x2)): the same values
    # as the transposed default ('xy') meshgrid, matching Z indexed [x1, x2].
    X1, X2 = np.meshgrid(x1_values, x2_values, indexing='ij')
    J_maps = np.ma.array(J_maps, mask=np.isinf(J_maps))  # for correct color
    u_maps = np.ma.array(u_maps, mask=np.isnan(u_maps))  # for correct color

    # Colour and axis limits are identical for every frame: compute them once
    # instead of re-reducing the whole J_maps stack on each iteration.
    J_min, J_max = np.nanmin(J_maps), np.nanmax(J_maps)
    u_min, u_max = np.amin(u_values), np.amax(u_values)

    fig = plt.figure(figsize=(12, 6))
    ax1 = fig.add_subplot(1, 2, 1, projection='3d')
    ax2 = fig.add_subplot(1, 2, 2, projection='3d')

    N = u_maps.shape[0]
    ax1.plot_wireframe(
        X1, X2, LQR_cost,
        color="lightseagreen", cstride=1, rstride=0, linestyle="-", linewidths=0.8
    )
    ax1.set_title(r'Cost-to-go function')
    ax1.set_xlabel(r'$\theta$')
    ax1.set_ylabel(r'$\omega$')
    ax1.set_zlabel(r'$J_{cost}$')

    ax2.plot_wireframe(
        X1, X2, LQR_u,
        color="lightseagreen", cstride=1, rstride=0, linestyle="-", linewidths=0.8
    )
    ax2.set_title(r'Optimal feedback control')
    ax2.set_xlabel(r'$\theta$')
    ax2.set_ylabel(r'$\omega$')
    ax2.set_zlabel(r'$\tau$')
    ax2.set_zlim([u_min, u_max])

    artists = []
    for k in reversed(range(N + 1)):
        surface1 = ax1.plot_surface(
            X1, X2, J_maps[k],
            cmap='coolwarm', vmax=J_max, vmin=J_min
        )
        if k == N:
            # u_maps has only N entries, so stage N has no control map. Draw
            # an all-NaN placeholder so every frame still holds two artists.
            surface2 = ax2.plot_surface(X1, X2, np.full_like(u_maps[0], np.nan))
        else:
            surface2 = ax2.plot_surface(
                X1, X2, u_maps[k],
                cmap='coolwarm', vmax=u_max, vmin=u_min
            )
        artists.append([surface1, surface2])

    ani = animation.ArtistAnimation(
        fig, artists, blit=False, interval=1000, repeat_delay=2000
    )
    fig.savefig('DP_iterations.pdf', dpi=300)
    return ani

def plot_closed_loop_simulation(
    t_grid, cost_LQR, state_LQR, control_LQR, cost_DP, state_DP, control_DP
):
    """Animate LQR (red) vs. DP (blue) closed-loop trajectories over time.

    Frame i shows every signal up to time index i. The panels are:
    state row 0 (top-left), state row 1 (bottom-left), control (top-right)
    and closed-loop cost (bottom-right).

    Parameters
    ----------
    t_grid : array_like
        Time grid; one animation frame per entry.
    cost_LQR, cost_DP : array_like
        Closed-loop cost, indexed like ``t_grid``.
    state_LQR, state_DP : array_like
        States, indexed [component, time]. Rows 0 and 1 are plotted.
    control_LQR, control_DP : array_like
        Control inputs. At frame i only the first i inputs are shown.

    Returns
    -------
    matplotlib.animation.ArtistAnimation
        The animation. As a side effect, the figure is also saved to
        'closed_loop_simulation.pdf' in the current working directory.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 6), sharex=True)
    fig.suptitle('Closed-loop simulation')
    ax_state0, ax_state1 = axes[0, 0], axes[1, 0]
    ax_control, ax_cost = axes[0, 1], axes[1, 1]

    # Axis labels are identical for every frame, so set them once up front.
    ax_state0.set_ylabel(r'state $\phi$')
    ax_state1.set_xlabel(r'time $t$')
    ax_state1.set_ylabel(r'state $\omega$')
    ax_control.set_ylabel(r'control $\tau$')
    ax_cost.set_xlabel(r'time $t$')
    ax_cost.set_ylabel(r'closed-loop cost $L$')

    artists = []
    for i in range(len(t_grid)):
        t = t_grid[0:i + 1]

        [line1] = ax_state0.plot(t, state_LQR[0, 0:i + 1], 'r')
        [line2] = ax_state0.plot(t, state_DP[0, 0:i + 1], 'b')

        [line3] = ax_state1.plot(t, state_LQR[1, 0:i + 1], 'r')
        [line4] = ax_state1.plot(t, state_DP[1, 0:i + 1], 'b')

        # Only the first i inputs are shown, padded with a NaN so the y-data
        # matches len(t) (presumably control has one entry fewer than t_grid
        # — confirm against the caller).
        [line5] = ax_control.plot(t, np.append(control_LQR[0:i], np.nan), 'r')
        [line6] = ax_control.plot(t, np.append(control_DP[0:i], np.nan), 'b')

        [line7] = ax_cost.plot(t, cost_LQR[0:i + 1], 'ro--', markerfacecolor='none')
        [line8] = ax_cost.plot(t, cost_DP[0:i + 1], 'bo--', markerfacecolor='none')

        artists.append([line1, line2, line3, line4, line5, line6, line7, line8])

    # Build the legend once rather than on every frame. It labels the first
    # two control lines (red, blue), which are the same handles the per-frame
    # calls used.
    if artists:
        ax_control.legend([r'LQR', r'DP'], loc='upper right')

    ani = animation.ArtistAnimation(fig, artists, blit=False, interval=50, repeat=False)
    fig.savefig('closed_loop_simulation.pdf', dpi=300)
    return ani

def animate_pendulum(Theta, dt=0.03):
    '''
    Create an animation of a pendulum, where Theta contains the trajectory of
    its angle. dt defines the time gap (in seconds) between two successive
    entries and is used as the frame interval.

    Theta should be a list or 1D-numpy array. It is flattened with np.ravel,
    so both work. The pendulum is drawn as a unit-length rod from the origin
    to (-sin(theta), cos(theta)), so theta = 0 points straight up.

    Returns the FuncAnimation. Keep a reference to it while it is displayed;
    otherwise it can be garbage-collected and the animation stops.
    '''
    # Normalize to a 1D array so plain lists (which have no `.size`) work too.
    theta = np.ravel(Theta)

    fig = plt.figure()
    ax = fig.add_subplot(111, autoscale_on=False, xlim=(-1.2, 1.2), ylim=(-1.2, 1.2))
    ax.set_aspect('equal')
    ax.axis('off')

    # One line artist: pivot and bob as markers joined by the rod.
    line, = ax.plot([], [], 'o-', lw=2)

    def init():
        # Clear frame drawn before the first animation frame.
        line.set_data([], [])
        return line,

    def animate(i):
        # Rod from the pivot at the origin to the bob at angle theta[i].
        line.set_data([0, -np.sin(theta[i])], [0, np.cos(theta[i])])
        return line,

    ani = animation.FuncAnimation(fig, animate, theta.size,
                                  interval=dt*1000, repeat_delay=500,
                                  blit=True, init_func=init)
    return ani