import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

def _final_state_on_grid(F_finalstate, V1, V2):
    """Evaluate the first two outputs of ``F_finalstate`` on a velocity grid.

    ``F_finalstate`` is called once per grid node with a 2x1 column
    ``[V1[idx], V2[idx]]``.  Components 0 and 1 of the result are converted
    to plain floats via ``.full()`` (the returned object is presumably a
    CasADi ``DM`` -- TODO confirm against the caller).

    Returns:
        Two arrays shaped like ``V1`` holding component 0 and component 1.
    """
    state0 = np.zeros(V1.shape)
    state1 = np.zeros(V1.shape)
    for idx in np.ndindex(*V1.shape):
        x_final = F_finalstate(np.vstack((V1[idx], V2[idx])))
        # .item() turns the 1x1 result into a scalar; assigning the raw 1x1
        # array to a single element is deprecated in recent NumPy releases.
        state0[idx] = x_final[0].full().item()
        state1[idx] = x_final[1].full().item()
    return state0, state1


def _unit_direction(jac_row):
    """Return the first two entries of ``jac_row`` divided by its 2-norm.

    A zero row yields ``(0.0, 0.0)`` instead of NaNs so that the plotted
    arrow simply collapses to the origin.
    """
    row = np.ravel(jac_row)
    norm = np.linalg.norm(row)
    if norm == 0:
        return 0.0, 0.0
    return row[0] / norm, row[1] / norm


def _contour_artists(contour_set):
    """Return the drawable artists of a contour set across matplotlib versions.

    Newer matplotlib makes ``ContourSet`` itself an ``Artist``; older
    releases only expose the per-level artists through ``.collections``.
    """
    from matplotlib.artist import Artist

    if contour_set is None:
        return []
    if isinstance(contour_set, Artist):
        return [contour_set]
    return list(contour_set.collections)


def ball_animate(alphas, w_opt_list, lambdas, Jacobians, v_max, F_finalstate):
    """Animate constraint Jacobians, multipliers and constraints over ``alphas``.

    Three subplots are created on figure 1:

    * the normalized rows 0..2 of the constraint Jacobian for the current
      frame, drawn as arrows from the origin,
    * the absolute values of the three Lagrange multipliers up to the
      current frame on a log scale,
    * the constraints in the initial-velocity plane: the circle of radius
      ``v_max``, the zero level of the ground constraint, the zero level of
      the alpha constraint for the current frame, and the current optimum.

    Args:
        alphas: 1-D sequence of alpha values, one per frame.
        w_opt_list: Per-frame optimizer; entries ``[0]`` and ``[1]`` are
            plotted as the optimal point.
        lambdas: 2-D array indexed as ``lambdas[k, frame]`` for ``k`` in 0..2.
        Jacobians: Per-frame 2-D arrays; rows 0..2 and columns 0..1 are used.
        v_max: Radius of the velocity-norm circle.
        F_finalstate: Callable taking a 2x1 array and returning an object
            whose components ``[0]`` and ``[1]`` support ``.full()``
            (presumably a CasADi function -- TODO confirm).  It is assumed
            to be deterministic, since it is evaluated only once per node.

    Returns:
        The ``FuncAnimation`` object.  Keep a reference to it, otherwise the
        animation may be garbage-collected and stop.

    Raises:
        ValueError: If ``alphas`` is empty.
    """
    alphas = np.asarray(alphas)
    if alphas.size == 0:
        raise ValueError("alphas must contain at least one value")

    alpha_min = np.min(alphas)
    alpha_max = np.max(alphas)
    n_points = alphas.shape[0]

    # plotting preparations
    # circle marking the bound on the norm of v0
    theta = np.linspace(0, 2*np.pi, 100)
    Vcirc = v_max * np.array([np.sin(theta), np.cos(theta)])

    # The final state does not depend on alpha, so it is evaluated on the
    # grid once and reused for both the ground and the alpha constraint.
    vv = np.linspace(-15, 15, 20)  # plot constraints in this range
    V1, V2 = np.meshgrid(vv, vv)
    state0, state1 = _final_state_on_grid(F_finalstate, V1, V2)
    G_ground = -state1

    fig = plt.figure(1)

    # plot the normalized gradients
    ax_jac = fig.add_subplot(2, 2, 1)
    jac_lines = [ax_jac.plot([], [])[0] for _ in range(3)]
    ax_jac.set_title(r'constraint Jacobians (normalized)')
    ax_jac.set_xlabel(r'$y$')
    ax_jac.set_ylabel(r'$z$')
    ax_jac.grid()
    ax_jac.set_xlim([-1, 1])
    ax_jac.set_ylim([-1, 1])
    ax_jac.legend(
        [r'ground constr', r'$\alpha$ constr', r'$\vert v_0 \vert$ constr'],
        loc='lower left',
    )

    # plot the lagrange multipliers
    ax_mult = fig.add_subplot(2, 2, 2)
    mult_lines = [ax_mult.semilogy([], [])[0] for _ in range(3)]
    ax_mult.set_title(r'Lagrange multipliers')
    ax_mult.set_xlabel(r'$\alpha$')
    ax_mult.set_ylabel(r'$\lambda$')
    ax_mult.grid()
    ax_mult.legend(
        [r'ground constr', r'$\alpha$ constr', r'$\Vert v_0 \Vert$ constr'],
        loc='upper left',
    )
    ax_mult.set_xlim([alpha_min, alpha_max])
    ax_mult.set_ylim([1e-10, 1e4])

    # plot the constraints
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    ax_constr = fig.add_subplot(2, 2, 3)
    ax_constr.plot(Vcirc[0, :], Vcirc[1, :], color=colors[2],
                   label=r'$\Vert v_0 \Vert$ constr')
    ax_constr.contour(V1, V2, G_ground, [0], colors=colors[0])
    # Empty proxy lines provide the legend entries for the two contours;
    # this avoids labelling through ContourSet.collections, which newer
    # matplotlib versions no longer provide.
    ax_constr.plot([], [], color=colors[0], label=r'ground constr')
    ax_constr.plot([], [], color=colors[1], label=r'$\alpha$ constr')
    [opt_marker] = ax_constr.plot([], [], 'm*', label=r'$v_0^*$')
    ax_constr.set_xlabel(r'$v_{0,y}$')
    ax_constr.set_ylabel(r'$v_{0,z}$')
    ax_constr.legend(loc='lower left')
    ax_constr.axis('equal')

    # Contour set of the alpha constraint for the frame currently shown.
    alpha_contour = None

    def dynamic_artists():
        # Only artists that change between frames are handed to the blitting
        # machinery; the circle and ground contour stay in the background.
        return (jac_lines + mult_lines
                + _contour_artists(alpha_contour) + [opt_marker])

    def init():
        return dynamic_artists()

    def animate(i):
        nonlocal alpha_contour

        alpha_val = alphas[i]
        Jg_eval = Jacobians[i]
        w_opt = w_opt_list[i]

        for k, line in enumerate(jac_lines):
            dir_x, dir_y = _unit_direction(Jg_eval[k, :])
            line.set_data([0, dir_x], [0, dir_y])

        for k, line in enumerate(mult_lines):
            line.set_data(alphas[:i + 1], np.abs(lambdas[k, :i + 1]))

        # Drop the previous frame's contour so contours do not pile up.
        for artist in _contour_artists(alpha_contour):
            artist.remove()
        G_alpha = -alpha_val * (state0 - 10) - state1  # alpha constraint
        alpha_contour = ax_constr.contour(
            V1, V2, G_alpha, [0], colors=colors[1]
        )

        # set_data expects sequences, so wrap the single optimum coordinates.
        opt_marker.set_data(np.ravel(w_opt[0]), np.ravel(w_opt[1]))

        return dynamic_artists()

    ani = FuncAnimation(
        fig, animate, frames=range(n_points), interval=1000, repeat_delay=3000,
        blit=True, init_func=init
    )
    return ani