import casadi as ca
import matplotlib.pyplot as plt
import numpy as np

from rk4step import rk4step


# If matplotlib selected the non-interactive 'agg' backend (e.g. on a machine
# without a display), switch to WebAgg so the figures can be viewed in a
# browser instead. The comparison is case-insensitive because older matplotlib
# versions report the name as 'Agg'.
import matplotlib
if matplotlib.get_backend().lower() == 'agg':
    matplotlib.use('WebAgg')
print(f'backend: {matplotlib.get_backend()}')


## OCP parameters
# problem dimensions: state and control
nx, nu = 2, 1

# time discretization
N = 50       # horizon: number of discretization intervals
dt = 0.1     # length of one discretization interval
N_rk4 = 10   # RK4 substeps taken inside each interval

# fixed initial state (theta, omega) = (pi, 0)
x0_bar = np.array([np.pi, 0.0])

## build the integrator
# continuous-time dynamics (ODE right-hand side)
def dynamics(x, u):
    """Return the state derivative xdot = f(x, u) as a CasADi column vector.

    The state is x = (theta, omega) and the control u is the torque tau.
    The result is (omega, sin(theta) + u).
    """
    theta, omega = x[0], x[1]
    return ca.vertcat(omega, ca.sin(theta) + u)

# discrete dynamics: one evaluation of F advances the state by one interval dt,
# using N_rk4 chained RK4 substeps of length dt / N_rk4 each.
# NOTE(review): assumes rk4step(x, u, ode, h) performs one RK4 step of length h.
x = ca.SX.sym('x', nx, 1)
u = ca.SX.sym('u', nu, 1)
dt_rk4 = dt / N_rk4  # substep length, so that N_rk4 substeps span exactly dt
x_next = x
for _ in range(N_rk4):
    # each substep must start where the previous one ended; the control u is
    # held constant over the whole interval
    x_next = rk4step(x_next, u, dynamics, dt_rk4)

# integrator: F(x_k, u_k) -> x_{k+1}
F = ca.Function('F', [x, u], [x_next])


## helpers to calculate cost
# stage cost l(x, u) = x^T x + 2 * sum_i u_i^2
stage_cost = ca.Function('stage_cost', [x, u], [x.T@x + 2*ca.sum1(u**2)])
# terminal cost V_f(x) = 10 * x^T x
terminal_cost = ca.Function('terminal_cost', [x], [10 * x.T@x])

roh = .01  # barrier weight (rho)
# Logarithmic barrier for the box constraint -1 < u < 1. It is finite only
# strictly inside the box. Summed over the control components so the result is
# a scalar for any nu, like the control term of stage_cost.
barrier_term = ca.Function(
    'barrier', [u], [-roh*ca.sum1(ca.log(1-u) + ca.log(1+u))]
)


## NLP formulation
# Multiple-shooting transcription:
# - every control U_k and every node state X_{k+1} is a decision variable;
# - the dynamics enter as equality constraints F(X_k, U_k) - X_{k+1} = 0.
w = []  # decision variables, interleaved as [U_0, X_1, U_1, X_2, ..., U_{N-1}, X_N]
g = []  # one continuity residual F(X_k, U_k) - X_{k+1} per interval
L = 0   # objective, accumulated interval by interval

x_k = x0_bar  # the initial state is fixed data, not a decision variable
for k in range(N):
    u_k = ca.MX.sym(f'U_{k}', nu, 1)
    L = L + stage_cost(x_k, u_k) + barrier_term(u_k)

    x_next = F(x_k, u_k)                # state predicted by the integrator
    x_k = ca.MX.sym(f'X_{k+1}', nx, 1)  # free state variable at the next node

    w += [u_k, x_k]
    g.append(x_next - x_k)

L = L + terminal_cost(x_k)


## lagrangian function
# Lagrangian of the equality-constrained NLP: L(w) + sum_k lam_k^T g_k, with one
# multiplier vector lam_k per continuity constraint. Primal and dual variables
# are stacked stage by stage as (U_k, lam_k, X_{k+1}). The Hessian with respect
# to this stacked vector therefore shows the per-stage structure of the KKT matrix.
lagrangian = L
z = []
for k in range(N):
    # primal and dual variables in the required order
    u_k = w[2*k]
    lam_k = ca.MX.sym(f'lam_{k}', nx, 1)
    x_next = w[2*k+1]
    lagrangian += lam_k.T@g[k]
    z.extend([u_k, lam_k, x_next])

z = ca.vertcat(*z)
hess, _ = ca.hessian(lagrangian, z)  # the gradient output is not needed
hess_lagrangian = ca.Function('hess_lagrangian', [z], [hess])

# Plot the *structural* sparsity pattern of the Hessian. Evaluating it
# numerically at z = 1 would place every U_k on the barrier boundary u = 1.
# There log(1 - u) is singular, so those entries become non-finite, and NaN
# entries are not drawn by spy.
plt.figure(1)
plt.spy(ca.DM.ones(hess.sparsity()).full(), markersize=3)


## nlp solver
w = ca.vertcat(*w)
g = ca.vertcat(*g)
nlp = {'x': w, 'f': L, 'g': g}
solver = ca.nlpsol('solver', 'ipopt', nlp)

# solve nlp; all continuity constraints are equalities (lbg = ubg = 0)
w0 = 0.1  # initial guess for every decision variable; keeps |u| < 1 for the barrier
sol = solver(
    x0=w0,
    lbg=np.zeros(g.shape),
    ubg=np.zeros(g.shape),
)

# The solver call does not raise when IPOPT fails. Report the failure instead of
# silently plotting a non-converged iterate.
stats = solver.stats()
if not stats['success']:
    print(f"warning: solver did not converge (return_status: {stats['return_status']})")

## visualize solution
# w is laid out stage by stage as [U_k (nu entries), X_{k+1} (nx entries)].
# Reshaping to (N, nu + nx) therefore puts one stage per row, for any nx and nu.
w_opt = sol['x'].full()
w_stages = w_opt.reshape(N, nu + nx)
u_opt = w_stages[:, :nu]
x_opt = np.vstack((x0_bar, w_stages[:, nu:]))  # prepend the fixed initial state

plt.figure(2)
plt.subplot(2, 1, 1)
plt.plot(x_opt)
plt.title('state trajectory')
plt.legend([r'$\theta$', r'$\omega$'])

plt.subplot(2, 1, 2)
# Repeat the last control value so the hold over the final interval
# [N-1, N] is drawn too; with where='post' it would otherwise be cut off.
plt.step(range(N + 1), np.vstack((u_opt, u_opt[-1:])), where='post')
plt.title('control trajectory')
plt.legend([r'$\tau$'])
plt.xlabel('discrete time $k$')


plt.show()
