import warnings

import casadi as ca
import matplotlib.pyplot as plt
import numpy as np


# for people who cannot see an interactive plot, uncomment the following lines
# import matplotlib
# if matplotlib.get_backend() == 'agg':
#     matplotlib.use('WebAgg')
# print(f'backend: {matplotlib.get_backend()}')

# Problem setup: the time horizon and the CasADi symbols from which the
# dynamics further down are built.
T = 10                        # continuous time horizon; the terminal condition is imposed at t = T

# state variables (symbolic SX vectors, each of dimension 2)
x = ca.SX.sym('x', 2)         # state
lam = ca.SX.sym('lam', 2)     # costate
aug = ca.vertcat(x, lam)      # augmented state: x in entries 0-1, lam in entries 2-3

# Boundary values, stored as 2x1 integer column vectors:
# x0bar is the prescribed state at t = 0, lamTbar the prescribed costate at t = T.
x0bar = np.array([[0], [1]])
lamTbar = np.zeros((2, 1), dtype=int)

# Optimal control and augmented dynamics, written in terms of the scalar
# components of the state x and the costate lam.


def ca_clip(val, l_bound, u_bound):
    """Clamp `val` to the interval [l_bound, u_bound] with CasADi's fmin/fmax."""
    return ca.fmax(l_bound, ca.fmin(u_bound, val))


# scalar components, named with 1-based indices as in the plot labels
x1, x2 = x[0], x[1]
lam1, lam2 = lam[0], lam[1]

# optimal control: -lam1/2, saturated to [-1, 1]
uopt = ca_clip(-lam1/2, -1, 1)

# state dynamics, with the control replaced by uopt
xdot = ca.vertcat(
    (1 - x2**2) * x1 - x2 + uopt,
    x1
)

# costate dynamics
lamdot = ca.vertcat(
    -2 * x1 - lam1 * (1 - x2**2) - lam2,
    -2 * x2 + 2 * lam1 * x1 * x2 + lam1
)

# augmented dynamics, stacked in the same order as aug = (x, lam)
augdot = ca.vertcat(xdot, lamdot)

# integrator
# Goal: an integrator function F(x_in, tf) that propagates the given
# dynamics from the state x_in over a time span of length tf and returns
# the state reached at the end of that span.

# length of the integration interval, kept as a symbolic parameter so that
# it can be chosen at every call of F
tf = ca.SX.sym('tf')

# ODE description in the dictionary form required by CasADi.
# The right-hand side is scaled by tf because CVODES by default integrates
# over the time interval [0, 1]: integrating the scaled dynamics from 0 to 1
# is equivalent to integrating the original, unscaled dynamics from 0 to tf.
ode = dict(x=aug, p=tf, ode=tf * augdot)

# solver options: absolute and relative integration error tolerance
opts = dict(abstol=1e-8, reltol=1e-8)

# build the integrator function
F = ca.integrator('F', 'cvodes', ode, opts)

## build dummy NLP
# Shooting problem: find the initial costate lam0 such that integrating the
# augmented dynamics from (x0bar, lam0) over a time span of length T yields
# a costate at time T that equals lamTbar.

# initial value of lambda (to be found)
lam0 = ca.MX.sym('lam0', 2)

# compute the augmented state at time T as a function of lam0
integrator_output = F(x0=ca.vertcat(x0bar, lam0), p=T)
# augmented state at time T
augT = integrator_output["xf"]
# costate at time T (entries 2-3 of the augmented state)
lamT = augT[2:4]

# terminal condition: this residual has to be zero, which is enforced by
# using zero as both lower and upper bound of the constraint g
g = lamT - lamTbar
lbg = 0
ubg = 0

# NLP with dummy objective: only the constraint g = 0 matters
nlp = {'x': lam0, 'f': 0, 'g': g}
solver = ca.nlpsol('solver', 'ipopt', nlp)
# no initial guess is passed, so the solver starts from its default one
sol = solver(lbg=lbg, ubg=ubg)

# Do not silently continue with a non-converged iterate: if Ipopt failed,
# the returned lam0 need not satisfy the terminal condition, and everything
# computed from it afterwards would be misleading. A warning (instead of an
# exception) keeps the script usable for inspecting the failed attempt.
stats = solver.stats()
if not stats['success']:
    warnings.warn(
        f"Ipopt did not converge (return status: {stats['return_status']}); "
        "lam0opt may violate the terminal condition",
        RuntimeWarning,
    )

# numerical value of the initial costate, as a NumPy array
lam0opt = sol["x"].full()

## simulate solution
# Re-integrate from the computed initial value in N steps of length DT so
# that intermediate values along the trajectory are available.
N = 100
DT = T/N

# Column k of AUG holds the augmented state after k integration steps;
# column 0 is the initial value (x0bar, lam0opt) as a flat vector.
AUG0 = np.concatenate((x0bar, lam0opt), axis=0).reshape(-1)
AUG = np.empty((AUG0.size, N + 1))
AUG[:, 0] = AUG0

# integration loop: each step starts from the previous column
for k in range(1, N + 1):
    step = F(x0=AUG[:, k - 1], p=DT)
    AUG[:, k] = step["xf"].full().ravel()

# split into state, costate and control
xopt, lamopt = AUG[:2], AUG[2:]
# same saturated control law as in the dynamics, evaluated numerically
u_opt = np.clip(-lamopt[0] / 2, -1, 1)
# time grid matching the columns of AUG
tvec = np.linspace(0, T, N + 1)

# Plot states, costates and control over time in three stacked panels.
plt.figure(1)

# one entry per panel: (curves to draw, legend entries, y-axis label)
panels = [
    ([xopt[0, :], xopt[1, :]], [r'$x_1(t)$', r'$x_2(t)$'], 'states'),
    ([lamopt[0, :], lamopt[1, :]], [r'$\lambda_1(t)$', r'$\lambda_2(t)$'], 'costates'),
    ([u_opt], [r'$u(t)$'], 'control'),
]

for row, (curves, labels, ylabel) in enumerate(panels, start=1):
    ax = plt.subplot(len(panels), 1, row)
    for curve in curves:
        ax.plot(tvec, curve)
    ax.legend(labels)
    ax.set_ylabel(ylabel)
    ax.set_xlim([0, T])

# only the bottom panel carries the time-axis label
ax.set_xlabel(r'time $t$')

plt.show()
