import casadi as ca
import matplotlib.pyplot as plt
import numpy as np

from ballistic_dynamics_RK4 import ballistic_dynamics_RK4
from newton_type import newton_type
from ode import ode
from rk4step import rk4step


# Interactive-plot fallback: if matplotlib has selected the non-interactive
# 'agg' backend, switch to the browser-based 'WebAgg' backend so that the
# plt.show() call at the end of the script still displays the figures.
# Delete these lines to keep whatever backend matplotlib chose by default.
import matplotlib
# Compare case-insensitively so both 'agg' and 'Agg' spellings trigger it.
if matplotlib.get_backend().lower() == 'agg':
    matplotlib.use('WebAgg')
print(f'backend: {matplotlib.get_backend()}')


# --- optimization problem ----------------------------------------------------
# Decision variables: the two angles, stacked as w = [alpha1; alpha2].
alpha1 = ca.SX.sym('alpha1', 1)
alpha2 = ca.SX.sym('alpha2', 1)
w = ca.vertcat(alpha1, alpha2)

# Magnitude of both velocity vectors.
vbar = 15


def v(alpha):
    """Return the 2-D velocity vbar * [cos(alpha); sin(alpha)] for angle alpha."""
    return vbar * ca.vertcat(ca.cos(alpha), ca.sin(alpha))


# Objective: squared Euclidean distance between the two positions produced by
# ballistic_dynamics_RK4 (entries 0:2 -> ball 1, entries 2:4 -> ball 2).
positions = ballistic_dynamics_RK4(ca.vertcat(v(alpha1), v(alpha2)))
p1 = positions[0:2]
p2 = positions[2:4]
gap = p1 - p2
f_expr = gap.T @ gap

# Wrap objective, gradient and Hessian as CasADi functions of w.
# ca.hessian returns the pair (Hessian, gradient) in that order.
f = ca.Function('f', [w], [f_expr])
H_expr, g_expr = ca.hessian(f_expr, w)
grad_f = ca.Function('grad_f', [w], [g_expr])
hess_f = ca.Function('hess_f', [w], [H_expr])

# --- solve with the Newton-type method -----------------------------------------
numit = 20
# Initial guess for [alpha1, alpha2].
w0 = np.array([np.pi/4, 3 * np.pi/4])

# newton_type receives NumPy callbacks: the gradient as a flat vector and the
# Hessian as a dense matrix.
W = newton_type(
    w0,
    lambda x: grad_f(x).full().flatten(),
    lambda x: hess_f(x).full().squeeze(),
    numit
)
# W is indexed as (component, iterate); presumably the initial guess is the
# first column, followed by one column per iteration. The last column is
# taken as the solution.
wopt = W[:, -1]
print('objective value at sol: ', f(wopt).full().item())

# Plot the iterates. The iteration axis is derived from W itself instead of
# assuming numit+1 columns, so the plot stays consistent even if newton_type
# returns a different number of iterates.
iterations = np.arange(W.shape[1])
plt.figure(1)
plt.plot(iterations, W[0, :], label=r'$\alpha_1$')
plt.plot(iterations, W[1, :], label=r'$\alpha_2$')
plt.xlabel(r'iteration $k$')
plt.legend()

# --- simulate the optimal solution ----------------------------------------------
v1opt = v(wopt[0])
v2opt = v(wopt[1])

T = 0.5      # simulation horizon
M = 100      # number of RK4 steps
DT = T/M     # RK4 step size

# State layout: [p1 (2), v1 (2), p2 (2), v2 (2)]; ball 1 starts at (0, 0) and
# ball 2 at (10, 0).
# NOTE(review): T and these start positions are assumed to match the ones used
# inside ballistic_dynamics_RK4 -- confirm, otherwise the simulated trajectories
# do not correspond to the optimized objective.
p1_init = np.array([0.0, 0.0])
p2_init = np.array([10.0, 0.0])
X0 = np.concatenate((
    p1_init,
    np.asarray(v1opt).ravel(),
    p2_init,
    np.asarray(v2opt).ravel(),
))
X = np.zeros((X0.size, M+1))
X[:, 0] = X0


def dynamics(t, x):
    """Adapt the time-invariant ode(x) to the (t, x) callback passed to rk4step."""
    return ode(x)


# Integrate with fixed-step RK4. The callback is built once, outside the loop.
# The trailing 0 is forwarded to rk4step as in the original call (presumably the
# time argument, unused by the time-invariant dynamics -- confirm).
for j in range(M):
    X[:, j+1] = rk4step(dynamics, DT, X[:, j], 0).full().flatten()


# Trajectories of both balls in the plane (position components 0:2 and 4:6).
plt.figure(2)
plt.plot(X[0, :], X[1, :], label=r'$p_1$')
plt.plot(X[4, :], X[5, :], label=r'$p_2$')
plt.xlabel('first position component')
plt.ylabel('second position component')
plt.legend()

plt.show()
