import casadi as ca
import matplotlib.pyplot as plt
import numpy as np

from rk4step import rk4step
from animate_pendulum import animate_pendulum

# If plots do not show up interactively (e.g. the backend is 'agg'), uncomment the following lines to switch to the browser-based WebAgg backend
# import matplotlib
# if matplotlib.get_backend() == 'agg':
#     matplotlib.use('WebAgg')
# print(f'backend: {matplotlib.get_backend()}')



# problem parameters
nx, nu = 2, 1        # state / control dimensions
N = 50               # horizon: number of discretization intervals
dt = 0.1             # length of one discretization interval
N_rk4 = 10           # RK4 sub-steps per discretization interval
x0_bar = np.array([np.pi, 0])   # initial state [theta, omega]

# build integrator
def dynamics(x, u):
    """Continuous-time pendulum dynamics xdot = f(x, u).

    The state is x = [theta, omega]; theta' = omega and
    omega' = sin(theta) + u.
    """
    return ca.vertcat(x[1], ca.sin(x[0]) + u)


h = dt / N_rk4  # length of a single RK4 sub-step
x = ca.SX.sym('x', nx, 1)
u = ca.SX.sym('u', nu, 1)

# discrete dynamics: chain N_rk4 RK4 sub-steps of length h, so that one
# evaluation of F advances the state by N_rk4 * h = dt. Each sub-step must
# start from the previous sub-step's result (x_next, not x) and use the
# sub-step length h (not dt).
# NOTE(review): assumes rk4step(x, u, f, step) performs one RK4 step of the
# given length -- consistent with how it is called here; confirm in rk4step.py.
x_next = x
for _ in range(N_rk4):
    x_next = rk4step(x_next, u, dynamics, h)

# integrator: maps (x_k, u_k) to x_{k+1}
F = ca.Function('F', [x, u], [x_next])

## cost helpers
# stage cost:    l(x, u) = x'x + 2 * sum(u_i^2)
# terminal cost: V(x)    = 10 * x'x
stage_cost = ca.Function(
    'stage_cost',
    [x, u],
    [x.T @ x + 2 * ca.sum1(u**2)],
)
terminal_cost = ca.Function(
    'terminal_cost',
    [x],
    [10 * x.T @ x],
)


## NLP formulation (single shooting: the states are eliminated as functions of U)
u = ca.SX.sym('U', N*nu)   # all controls stacked as [u_0; u_1; ...; u_{N-1}], nu entries each
u_guess = .1 * np.ones((N*nu, 1))  # initial guess for U (one entry per decision variable)

x = []  # state trajectory x_0, ..., x_N, each a symbolic function of U
L = 0  # accumulated cost

# note: this loop could also be expressed with the CasADi function fold
x.append(x0_bar)
for k in range(N):
    x_k = x[k]
    # k-th control block; slicing by nu keeps this correct for nu > 1
    # (for nu == 1 it is equivalent to u[k])
    u_k = u[k*nu:(k+1)*nu]
    L += stage_cost(x_k, u_k)

    x_next = F(x_k, u_k)
    x.append(x_next)
L += terminal_cost(x[N])


# Plot the nonzero pattern of the Hessian of the cost w.r.t. the controls U,
# evaluated at the initial guess. The NLP has no constraints, so this is also
# the Hessian of the Lagrangian. Because every state depends on all earlier
# controls (single shooting), the pattern is expected to be dense.
hess, grad = ca.hessian(L, u)
hess_lagrangian = ca.Function('hess_lagrangian', [u], [hess])
hess_at_guess = hess_lagrangian(u_guess).full()

plt.figure(1)
plt.spy(hess_at_guess, markersize=3)


# create nlp solver: unconstrained problem min_U L(U)
nlp = {'x': u, 'f': L}
solver = ca.nlpsol('solver','ipopt', nlp)

# solve nlp, starting from the initial guess
sol = solver(x0=u_guess)

# nlpsol does not raise when IPOPT fails to converge; report it explicitly
# instead of silently plotting a possibly non-optimal trajectory.
solver_stats = solver.stats()
if not solver_stats.get('success', True):
    print(f"warning: IPOPT did not converge "
          f"({solver_stats.get('return_status', 'unknown status')})")


## visualize solution
u_opt = sol['x'].full()   # optimal controls, shape (N*nu, 1)

# re-evaluate the (symbolic) state trajectory at the optimal controls
x = ca.horzcat(*x)
FX = ca.Function('FX', [u], [x])
x_opt = FX(u_opt).full().T   # one row per time step k = 0..N

plt.figure(2)
plt.subplot(2,1,1)
plt.plot(x_opt)
plt.title('state trajectory')
plt.legend([r'$\theta$', r'$\omega$'])

plt.subplot(2,1,2)
# u_k is held constant on [k, k+1). With where='post' the last sample would
# only mark a point, so repeat it to draw the final interval [N-1, N] as well.
u_traj = u_opt.reshape(N, nu)   # row k is the control block u_k
plt.step(range(N + 1), np.vstack([u_traj, u_traj[-1:]]), where='post')
plt.title('control trajectory')
plt.legend([r'$\tau$'])
plt.xlabel('discrete time $k$')


ani = animate_pendulum(x_opt[:,0])   # animate the angle trajectory theta

# ani.save('pendulum.gif')

plt.show()
