import casadi as ca
import matplotlib.pyplot as plt
import numpy as np

from ballistic_dynamics import ballistic_dynamics
from ball_animate import ball_animate

# for people who cannot see an interactive plot, uncomment the following lines
# import matplotlib
# if matplotlib.get_backend() == 'agg':
#     matplotlib.use('WebAgg')
# print(f'backend: {matplotlib.get_backend()}')

# Switch between the two modes of this script: a truthy value (1 / True)
# solves the NLP for a whole sweep of alpha values and plots the gradients of
# the constraints and the Lagrange multipliers; a falsy value solves the NLP
# for a single alpha only.
PLOT_GRADIENTS = False

if not PLOT_GRADIENTS:
    alpha_val = 0.1   # the single alpha for which the NLP is solved
    n_points = 1
else:
    alpha_min, alpha_max = -1, 0.6   # end points of the alpha sweep
    n_points = 25                    # number of alpha values in the sweep
    alphas = np.linspace(alpha_min, alpha_max, num=n_points)


# --- problem data ---------------------------------------------------------
T = 0.5                       # time horizon
N = 100                       # number of discretization steps
p0 = np.vstack(([0], [0]))    # initial position of the ball (column vector)
v_max = 15                    # max velocity

# symbolic 1x1 variable for the alpha parameter
alpha = ca.MX.sym('alpha', 1, 1)

# --- continuous-time model ------------------------------------------------
# TODO: complete file ballistic_dynamics.py
x = ca.MX.sym('x', 4, 1)             # symbolic state
xdot = ballistic_dynamics(x)         # model equations evaluated on the state
f = ca.Function('f', [x], [xdot])    # ODE right-hand side as a CasADi function

# --- discrete-time dynamics -----------------------------------------------
DT = T / N    # length of one integration step

# A single Runge-Kutta 4 integrator step of length DT
k1 = f(x)
k2 = f(x + DT / 2 * k1)
k3 = f(x + DT / 2 * k2)
k4 = f(x + DT * k3)
xnext = x + DT / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
rk4step = ca.Function('rk4step', [x], [xnext])

# --- final state as an expression of the initial velocity -------------------
v0 = ca.MX.sym('v0', 2, 1)     # symbolic initial velocity
X = ca.vertcat(p0, v0)         # initial state: fixed position, free velocity

# apply the integrator step N times to cover the whole horizon
for _ in range(N):
    X = rk4step(X)

# maps an initial velocity to the state reached after N steps
F_finalstate = ca.Function('F', [v0], [X])

## NLP Formulation
## TODO: formulate NLP
# Every `...` in this section is a placeholder that still has to be replaced
# with the actual expression.

# decision variable
# (passed as first input of Jg below, next to the parameter alpha)
w = ...


# cost
...

# constraints (build up a list)
# g collects the constraint expressions; lbg / ubg are presumably meant to
# collect the matching lower / upper bounds -- confirm when filling them in.
g = []
lbg = []
ubg = []

# first constraint (equation (4b))
...

# second constraint (equation (4c))
...

# third constraint (equation (4d))
...

# convert g to a vector (vertical concatenation of all list entries)
g = ca.vertcat(*g)

# Insert your code here (Jacobian of the constraints g)
# Jg is later called as Jg(w_opt, alpha_val) inside solve_prob.
Jg = ca.Function('Jg', [w, alpha], ...)

# Insert your code here (create an NLP solver)
# alpha enters the problem as parameter 'p', so one solver object can be
# reused for different alpha values.
prob = {'f': ..., 'x': ..., 'g': ..., 'p': alpha}
solver = ca.nlpsol('solver', 'ipopt', prob)

# Storage for the constraint multipliers, one column per solved alpha;
# the 3 rows presumably correspond to the three constraints above.
# NOTE: this array is replaced by a NaN-filled one further down when
# PLOT_GRADIENTS is set.
lambdas = np.zeros((3,n_points))
    
def solve_prob(alpha_val):
    """Solve the NLP for one value of alpha.

    Calls the module-level ``solver`` with ``alpha_val`` as its parameter and
    evaluates the constraint Jacobian function ``Jg`` at the returned
    solution.

    TODO: the ``...`` arguments (initial guess and bounds) are placeholders
    that still have to be filled in.

    Args:
        alpha_val: value handed to the solver as parameter ``p``.

    Returns:
        Tuple ``(w_opt, Jg_eval, lam)``: the solver output ``x`` as a dense
        array, ``Jg`` evaluated at ``(w_opt, alpha_val)`` as a dense array,
        and the solver output ``lam_g`` as a flattened array.
    """
    result = solver(
        x0=...,
        lbx=...,
        ubx=...,
        lbg=...,
        ubg=...,
        p=alpha_val,
    )

    w_opt = result["x"].full()
    lam = result["lam_g"].full().flatten()
    Jg_eval = Jg(w_opt, alpha_val).full()
    return w_opt, Jg_eval, lam

if PLOT_GRADIENTS:
    # Sweep mode: solve the NLP for every alpha and collect, per solve, the
    # optimizer, the evaluated constraint Jacobian and the multipliers.
    w_opt_list = []
    Jacobian_list = []
    lambdas = np.full((3, n_points), np.nan)   # one column per alpha
    for i, alpha_val in enumerate(alphas):
        w_opt, Jg_eval, lam = solve_prob(alpha_val)
        w_opt_list.append(w_opt)
        Jacobian_list.append(Jg_eval)
        lambdas[:, i] = lam

    # the returned object is kept in a module-level name
    ani = ball_animate(
        alphas, w_opt_list, lambdas, Jacobian_list, v_max, F_finalstate
    )
else:
    # Single-alpha mode: solve once, then simulate and plot the trajectory.
    w_opt = solve_prob(alpha_val)[0]

    # initial state: fixed start position stacked on the first two
    # entries of the optimizer
    X_traj = np.zeros((4, N + 1))
    X_traj[:, 0] = np.vstack((p0, w_opt[0], w_opt[1])).flatten()
    # step the integrator through the whole horizon
    for k in range(N):
        X_traj[:, k + 1] = rk4step(X_traj[:, k]).full().flatten()

    fig = plt.figure(2)
    ax = fig.add_subplot()
    ax.plot(X_traj[0, :], X_traj[1, :])
    ax.plot([0, 10], [10 * alpha_val, 0])
    ax.set_xlabel(r'$y$')
    ax.set_ylabel(r'$z$')
    ax.legend([r'trajectory', r'$\alpha$ constraint'])
    ax.grid()

plt.show()
