import casadi as ca
import matplotlib.pyplot as plt
import numpy as np


# Make sure figures can be displayed: when matplotlib has selected the
# non-interactive 'agg' backend (e.g. no display available), try to switch to
# the browser-based 'WebAgg' backend so that plt.show() can still be viewed.
import matplotlib
if matplotlib.get_backend() == 'agg':
    try:
        matplotlib.use('WebAgg')
    except (ImportError, RuntimeError) as err:
        # WebAgg can fail to load when its optional dependencies are missing;
        # keep 'agg' instead of aborting before the solver has even run.
        print(f'could not switch to WebAgg backend ({err}); keeping agg')
print(f'backend: {matplotlib.get_backend()}')


# Implementation of a Gauss-Newton SQP solver using CasADi

# Problem definition: two decision variables x = (x1, x2)
nv = 2
x = ca.MX.sym('x', nv)


def _fun_and_jac(name, jac_name, expr):
    """Return CasADi Functions evaluating ``expr`` and its Jacobian w.r.t. ``x``.

    Both Functions take the module-level symbol ``x`` as their only input.
    """
    value_fun = ca.Function(name, [x], [expr])
    jac_fun = ca.Function(jac_name, [x], [ca.jacobian(expr, x)])
    return value_fun, jac_fun


# Objective: squared Euclidean distance to the point (4, 4)
f = (x[0] - 4)**2 + (x[1] - 4)**2
F, Jf = _fun_and_jac('F', 'Jf', f)

# Residual vector with f = 0.5 * ||r||^2; its Jacobian gives the
# Gauss-Newton Hessian approximation Jr' Jr
r = ca.sqrt(2) * ca.vertcat(x[0] - 4, x[1] - 4)
R, Jr = _fun_and_jac('R', 'Jr', r)

# Equality constraint g(x) = 0
g = ca.sin(x[0]) - x[1]**2
G, Jg = _fun_and_jac('G', 'Jg', g)

# Inequality constraint h(x) <= 0 (closed disk of radius 2 around the origin)
h = x[0]**2 + x[1]**2 - 4
H, Jh = _fun_and_jac('H', 'Jh', h)

# Symbolic QP parameter: the current SQP iterate (linearization point)
xk = ca.MX.sym('xk', nv)
# Symbolic QP decision variable: the SQP step
delta_x = ca.MX.sym('delta_x', nv)

# First-order models of the constraints around xk; their bounds
# (== 0 for g, <= 0 for h) are supplied when the solver is called.
g_l = G(xk) + Jg(xk) @ delta_x
h_l = H(xk) + Jh(xk) @ delta_x
lin_constraints = ca.vertcat(g_l, h_l)

# Gauss-Newton model of the objective:
#   0.5 * dx' Bk dx + grad f(xk)' dx,  with  Bk = Jr(xk)' Jr(xk)
Jr_k = Jr(xk)
Bk = Jr_k.T @ Jr_k
f_gn = .5 * delta_x.T @ Bk @ delta_x + Jf(xk) @ delta_x

# Build the parametric QP once; it is re-solved at every SQP iterate via p=xk
qp = dict(x=delta_x, f=f_gn, g=lin_constraints, p=xk)
solver = ca.qpsol('solver', 'qpoases', qp)


# SQP solver: repeatedly solve the Gauss-Newton QP linearized at the current
# iterate and take a backtracked step along its solution.
max_it = 100
step_tol = 1.0e-10  # stop once the QP step norm falls below this value
xk = np.array([-2.0, 4.0])
# xk = np.array([-2.0, -4.0])
z_iter = np.zeros((nv, max_it))
z_iter[:, 0] = xk
n_iter = 1  # number of columns of z_iter that hold computed iterates

# Bounds on vertcat(g_l, h_l): linearized equality == 0, linearized inequality <= 0
lbg = [0, -ca.inf]
ubg = [0, 0]

# Backtracking line-search parameters
kappa = 0.7      # step-length contraction factor
alpha = 1.1      # accept a trial point if F(trial) <= alpha * F(xk)
t_min = 1.0e-12  # stop shrinking below this step length (accept the tiny step)

for i in range(1, max_it):
    # Solve the QP subproblem linearized at xk
    sol = solver(lbg=lbg, ubg=ubg, p=xk)
    if not solver.stats().get('success', True):
        print(f'QP subproblem failed at iteration {i}; stopping')
        break
    step = sol['x'].full().flatten()
    if np.linalg.norm(step) < step_tol:
        break

    # Line search: shrink t only while the trial point is rejected, so the
    # step that is taken is exactly the one that passed the acceptance test.
    prev_cost = F(xk).full().item()
    t = 1.0
    trial = xk + t * step
    while F(trial).full().item() > alpha * prev_cost and t > t_min:
        t *= kappa
        trial = xk + t * step
    xk = trial
    z_iter[:, i] = xk
    n_iter = i + 1

# Keep only the iterates that were actually computed
z_iter = z_iter[:, :n_iter]


# Evolution of each primal coordinate over the SQP iterations
fig, ax = plt.subplots()
trajectory_lines = ax.plot(z_iter.T)
ax.set(xlabel='iterations', ylabel='primal solution')
ax.legend(trajectory_lines, [r'$x_1$', r'$x_2$'])
ax.grid()


# Plot the constraint curves, objective level sets, and the SQP iterates
fig, ax = plt.subplots()
xy1d = np.linspace(-2*np.pi, 2*np.pi)
X, Y = np.meshgrid(xy1d, xy1d)
# Zero level set of the equality constraint g(x) = sin(x1) - x2^2
gxy = np.sin(X) - Y**2
cs_g = ax.contour(X, Y, gxy, [0], colors='red')
# Zero level set (boundary) of the inequality constraint h(x) = x1^2 + x2^2 - 4
hxy = X**2 + Y**2 - 4
cs_h = ax.contour(X, Y, hxy, [0], colors='blue')
# Level sets of the objective f(x) = (x1 - 4)^2 + (x2 - 4)^2
fxy = (X-4)**2 + (Y-4)**2
ax.contour(X, Y, fxy, 10)
iters_line, = ax.plot(z_iter[0, :], z_iter[1, :], 'ko--', mfc='none')
ax.set_title('Iterations in the primal space')
ax.grid()
# ContourSet.collections was deprecated in Matplotlib 3.8 and removed in 3.10,
# so use the proxy artists from legend_elements() for the contour entries.
ax.legend(
    [cs_g.legend_elements()[0][0], cs_h.legend_elements()[0][0], iters_line],
    [r'$g(x)$', r'$h(x)$', 'iters'],
)


plt.show()
