import casadi as ca
from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np

# If no interactive plot window appears (e.g. the backend is 'agg'), uncomment the following lines
# import matplotlib
# if matplotlib.get_backend() == 'agg':
#     matplotlib.use('WebAgg')
# print(f'backend: {matplotlib.get_backend()}')


# Hessian strategy for the SQP loop. rho is None for 'exact' and is paired with
# a numeric scale for 'identity'; swap in one of the commented pairs to compare.
HESSIAN_APPROXIMATION = 'exact'
rho = None
# HESSIAN_APPROXIMATION, rho = 'identity', 600
# HESSIAN_APPROXIMATION, rho = 'identity', 100


nv = 2  # number of primal variables (x_1, x_2)
x = ca.MX.sym('x', nv, 1)  # CasADi symbolic column vector of size nv x 1

# Insert your code here (define the objective function and its gradient and Hessian)
# ...
# ...
# ...

# Insert your code here (define the constraint function g(x) and its Jacobian and Hessian)
# ...
# ...
# ...

# Task 1d)
# SQP solver.
# x_iter holds one iterate per column: rows 0..nv-1 are the primal variables x,
# row nv is the (scalar) multiplier lambda.
max_it = 100
x_iter = np.zeros((nv+1, max_it))
x_iter[:, 0] = [1, -1, 1]  # Initial guess [x_1, x_2, lambda]

for i in range(max_it-1):

    x_k = x_iter[0:nv, i]
    lambda_k = x_iter[nv, i]
    if HESSIAN_APPROXIMATION == 'exact':
        # Insert your code here (exact Hessian at (x_k, lambda_k))
        # ...
        # A Python block must contain at least one statement; a comment-only
        # branch is a syntax error. Replace this raise with your implementation.
        raise NotImplementedError('Exact Hessian is not implemented yet.')
    elif HESSIAN_APPROXIMATION == 'identity':
        # Insert your code here (scaled identity approximation; presumably
        # scaled by rho -- confirm against the exercise statement)
        # ...
        raise NotImplementedError('Scaled identity Hessian approximation is not implemented yet.')
    else:
        raise ValueError(f'HESSIAN_APPROXIMATION has an undefined value {HESSIAN_APPROXIMATION}.')

    # Build and solve the KKT system, then store the next iterate
    # [x_{k+1}; lambda_{k+1}] in x_iter[:, i+1].
    # ...
    # ...
    # ...


# Surface log(1 + 1/2 (x_1 - 1)^2 + 1/2 (10 (x_2 - x_1^2))^2 + 1/2 x_2^2)
# sampled on a 61 x 61 grid covering [-1.5, 1.5]^2.
xy1d = np.linspace(-1.5, 1.5, num=61)
[X,Y] = np.meshgrid(xy1d, xy1d)
Z = np.log(1 + 1/2*(X -1)**2 + 1/2*(10*(Y -X**2))**2 + 1/2*Y**2)

# Zero level set g(x) = 0, parameterised by x_2: x_1 = -(1 - x_2)^2.
y_g = np.linspace(-0.25, 1.5, num=20)
x_g = -(1 - y_g)**2

# Figure 1: 3D surface (left) and contour plot (right), each overlaid with the
# SQP trajectory and the constraint curve.
fig = plt.figure(1)
ax = fig.add_subplot(1,2,1, projection='3d')
ax.plot_surface(X,Y,Z, cmap=cm.coolwarm)
# No zs is given, so these 3D lines are drawn in the z = 0 plane.
ax.plot(x_iter[0,:], x_iter[1,:], 'ko-', label='solution trajectory')
ax.plot(x_g, y_g, 'r', label=r'$g(x) = 0$')
ax.set_xlim([-1.5,1.5])
ax.set_ylim([-1.5,1.5])
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
ax.legend()

ax = fig.add_subplot(1,2,2)
# Label the artists here as well; otherwise ax.legend() finds nothing to show
# and matplotlib warns "No artists with labels found to put in legend".
ax.plot(x_iter[0,:], x_iter[1,:], 'ko-', label='solution trajectory')
ax.plot(x_g, y_g, 'r', label=r'$g(x) = 0$')
ax.contour(X,Y,Z)
ax.set_xlim([-1.5,1.5])
ax.set_ylim([-1.5,1.5])
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
ax.legend()

# Figure 2: primal iterates x_1 and x_2 versus iteration index.
fig = plt.figure(2)
ax = plt.gca()
ax.plot(x_iter[0:2,:].T)
ax.grid()
ax.set_xlabel('iterations')
ax.set_ylabel('primal solution')
ax.legend([r'$x_1$', r'$x_2$'])


plt.show()
