import casadi as ca
import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import null_space


# For people who cannot see an interactive plot: if matplotlib selected the
# non-interactive Agg backend, switch to the browser-based WebAgg backend so
# the figures shown at the end of the script can still be viewed.
# Backend names are compared case-insensitively because matplotlib reports
# 'Agg' or 'agg' depending on version and on how the backend was configured.
import matplotlib
if matplotlib.get_backend().lower() == 'agg':
    matplotlib.use('WebAgg')
print(f'backend: {matplotlib.get_backend()}')


# Implementation of an interior point method 

# Problem definition:
#   minimize   f(x) = (x1 - 4)^2 + (x2 - 4)^2
#   subject to g(x) = sin(x1) - x2^2   = 0
#              h(x) = x1^2 + x2^2 - 4 <= 0
nv = 2  # number of decision variables
ne = 1  # number of equality constraints
ni = 1  # number of inequality constraints

x = ca.MX.sym('x', nv)

x_test = np.array([2, 3])  # point at which every function below is self-checked
numtol = 1e-12             # absolute tolerance of the self-checks


def _check(name, fun, expected):
    """Raise ValueError('error in <name>') unless fun(x_test) matches expected.

    The LARGEST absolute deviation over all entries is compared against
    numtol, so every entry has to match; comparing only the smallest
    deviation would let a single coincidentally-correct entry hide errors
    in all the others.
    """
    deviation = np.abs(fun(x_test).full() - np.asarray(expected, dtype=float))
    if np.max(deviation) >= numtol:
        raise ValueError(f'error in {name}')


# Objective
F = ca.Function('F', [x], [(x[0] - 4)**2 + (x[1] - 4)**2])
F_test = 5
_check('F', F, F_test)

# Equality constraints
G = ca.Function('G', [x], [ca.sin(x[0]) - x[1]**2])
G_test = -8.090702573174319
_check('G', G, G_test)

# Inequality constraints
H = ca.Function('H', [x], [x[0]**2 + x[1]**2 - 4])
h_test = 9.0
_check('H', H, h_test)

# Jacobian of f: [2 (x1 - 4), 2 (x2 - 4)]
Jf = ca.Function('Jf', [x], [ca.jacobian(F(x), x)])
Jf_test = np.array([-4.0, -2.0])
_check('Jf', Jf, Jf_test)

# Jacobian of g: [cos(x1), -2 x2]  (exact values, not rounded)
Jg = ca.Function('Jg', [x], [ca.jacobian(G(x), x)])
Jg_test = np.array([np.cos(2.0), -6.0])
_check('Jg', Jg, Jg_test)

# Jacobian of h: [2 x1, 2 x2]
Jh = ca.Function('Jh', [x], [ca.jacobian(H(x), x)])
Jh_test = np.array([4.0, 6.0])
_check('Jh', Jh, Jh_test)

# Hessian of f
Hf = ca.Function('Hf', [x], [ca.hessian(F(x), x)[0]])
Hf_test = np.array([
    [2.0, 0.0],
    [0.0, 2.0],
])
_check('Hf', Hf, Hf_test)

# Hessian of g: diag(-sin(x1), -2)  (exact values, not rounded)
Hg = ca.Function('Hg', [x], [ca.hessian(G(x), x)[0]])
Hg_test = np.array([
    [-np.sin(2.0), 0.0],
    [0.0, -2.0],
])
_check('Hg', Hg, Hg_test)

# Hessian of h
Hh = ca.Function('Hh', [x], [ca.hessian(H(x), x)[0]])
Hh_test = np.array([
    [2.0, 0.0],
    [0.0, 2.0],
])
_check('Hh', Hh, Hh_test)


# Interior point solver
#
# Newton's method on the perturbed KKT conditions of
#   min f(x)  s.t.  g(x) = 0,  h(x) + s = 0,  s >= 0
# with Lagrangian f + lambda*g + nu*h, i.e. on the root of
#   r(z) = [grad f + Jg^T lambda + Jh^T nu;  g;  h + s;  nu*s - tau]
# for z = (x, lambda, nu, s), while the barrier parameter tau is driven to 0.
max_it = 100
# alternative initial guess: xk = np.array([-2, 4])
xk = np.array([-2, -4])  # decision variables
lk = 10*np.ones(ne)      # equality multipliers (lambda)
vk = 10*np.ones(ni)      # inequality multipliers (nu)
sk = 10*np.ones(ni)      # slack variables

# save iteration history; column k holds (x, lambda, nu, s) after k iterations
z_iter = np.zeros((nv + ne + ni + ni, max_it))
z_iter[:, 0] = np.hstack((xk, lk, vk, sk))

# algorithm parameters
tau = 2        # initial value of tau
k_b = 1/3      # reduction factor for tau
th_1 = 1.0e-8  # decrease tau if the norm of the KKT residual is smaller than this
th_2 = 1.0e-8  # stop if tau is smaller than this

# line-search parameters (identical in every iteration)
max_ls = 100       # maximum number of backtracking steps
k_ls = 0.9         # backtracking reduction factor for alpha
min_step = 1.0e-8  # give up once the scaled (nu, s) step is smaller than this

n_saved = max_it   # number of filled columns of z_iter (reduced on convergence)
for i in range(1, max_it):
    # Evaluate functions, Jacobians and Hessians at the current iterate
    g_e = G(xk)
    h_e = H(xk)
    Jg_e = Jg(xk)
    Jh_e = Jh(xk)
    Jf_e = Jf(xk)
    Hf_e = Hf(xk)
    Hg_e = Hg(xk)
    Hh_e = Hh(xk)
    # Hessian of the Lagrangian (assumes scalar g and h, i.e. ne = ni = 1)
    Hl = Hf_e + Hg_e*lk + Hh_e*vk

    # Dense numpy copies used to assemble the Newton system
    Hl_n = Hl.full()    # (nv, nv)
    Jg_n = Jg_e.full()  # (ne, nv)
    Jh_n = Jh_e.full()  # (ni, nv)

    # KKT matrix: Jacobian of r(z) w.r.t. z = (x, lambda, nu, s).
    # Explicit zero blocks keep the layout valid for any nv, ne, ni.
    M = np.block([
        [Hl_n,               Jg_n.T,             Jh_n.T,             np.zeros((nv, ni))],
        [Jg_n,               np.zeros((ne, ne)), np.zeros((ne, ni)), np.zeros((ne, ni))],
        [Jh_n,               np.zeros((ni, ne)), np.zeros((ni, ni)), np.eye(ni)],
        [np.zeros((ni, nv)), np.zeros((ni, ne)), np.diag(sk),        np.diag(vk)],
    ])

    # Right-hand side: negative KKT residual -r(z)
    rhs = -np.concatenate((
        Jf_e.full().ravel() + Jg_n.T @ lk + Jh_n.T @ vk,
        g_e.full().ravel(),
        h_e.full().ravel() + sk,
        vk*sk - tau,
    ))

    # Termination condition / barrier update
    if np.linalg.norm(rhs) < th_1:
        if tau < th_2:
            print('Solution found!')
            n_saved = i  # columns 0 .. i-1 of z_iter hold the history
            break
        tau = tau*k_b

    # Compute Newton step
    z_step = np.linalg.solve(M, rhs)
    x_step = z_step[:nv]
    l_step = z_step[nv:nv + ne]
    v_step = z_step[nv + ne:nv + ne + ni]
    s_step = z_step[nv + ne + ni:]

    # Backtracking line search: shrink alpha until nu and s stay strictly
    # positive (complementarity nu*s = tau > 0 requires both to be > 0).
    alpha = 1
    for _ in range(max_ls):
        v_t = vk + alpha*v_step
        s_t = sk + alpha*s_step
        if np.all(v_t > 0) and np.all(s_t > 0):
            break
        alpha = k_ls*alpha
        if np.linalg.norm(alpha*np.concatenate((v_step, s_step))) < min_step:
            raise ValueError('Line search failed! Could not find dual feasible step.')
    else:
        # All max_ls candidates were rejected; the last reduced alpha was
        # never checked, so it must not be used.
        raise ValueError('Line search failed! Could not find dual feasible step.')

    # actual step
    xk = xk + alpha*x_step
    lk = lk + alpha*l_step
    vk = vk + alpha*v_step
    sk = sk + alpha*s_step
    # save for later processing
    z_iter[:, i] = np.hstack((xk, lk, vk, sk))

    # Print some results
    if i % 20 == 1:  # every now and then reprint header
        print('-'*49, '\n')
        print('it \t tau \t\t ||rhs|| \t alpha\n')
        print('-'*49, '\n')
    print('%d \t %e \t %e \t %e\n' % (i, tau, np.linalg.norm(rhs), alpha))
else:
    # loop ran out of iterations without meeting the termination condition
    print('Maximum number of iterations reached without convergence.')

# keep only the filled part of the iteration history
z_iter = z_iter[:, :n_saved]


# Plot the evolution of all primal-dual variables over the iterations
fig, ax = plt.subplots()
ax.plot(z_iter.T)
ax.grid()
ax.set_xlabel('iterations')
ax.set_ylabel('solution')
ax.legend([r'$x_1$', r'$x_2$', r'$\lambda$', r'$\nu$', r'$s$'])


# Plot feasible set, objective contours and the iterates in the primal space
fig, ax = plt.subplots()
xy1d = np.linspace(-2*np.pi, 2*np.pi)
X, Y = np.meshgrid(xy1d, xy1d)
gxy = np.sin(X) - Y**2
cs = ax.contour(X, Y, gxy, [0], colors='red')
artists, _ = cs.legend_elements()
labels = [r'$g(x) = 0$']
hxy = X**2 + Y**2 - 4
cs = ax.contour(X, Y, hxy, [0], colors='blue')
artists += cs.legend_elements()[0]
labels += [r'$h(x) = 0$']
fxy = (X - 4)**2 + (Y - 4)**2
ax.contour(X, Y, fxy, 10)
iter_line, = ax.plot(z_iter[0, :], z_iter[1, :], 'ko--', mfc='none', label='iters')
# The legend is built from explicit handles, so the iterates must be added
# explicitly; otherwise their 'iters' label would never be shown.
artists.append(iter_line)
labels.append('iters')
ax.set_title('Iterations in the primal space')
ax.grid()
ax.legend(artists, labels)


## check SOSC (second-order sufficient conditions) at the final iterate
tol = 1e-6                 # tolerance for when sth is zero

strict_complementarity = True
# check for activeness of h
if np.all(np.abs(h_e.full()) <= tol):  # h is active
    # an active inequality is treated like an additional equality constraint
    Jg_tilde = np.vstack((Jg_e.full(), Jh_e.full()))  # extended equality constr jacobian
    if np.all(vk >= tol):
        print('h is strictly active')
    else:
        print('h is weakly active')
        strict_complementarity = False
else:
    print('h is inactive')
    Jg_tilde = Jg_e.full()

# compute reduced Hessian on the null space of the active constraint Jacobian
Z = null_space(Jg_tilde)    # orthonormal basis of the nullspace
redH = Z.T @ Hl.full() @ Z  # reduced Hessian
# Symmetrize away round-off and use the symmetric eigensolver: it always
# returns real eigenvalues, whereas the general solver may return complex
# values for which the comparisons below would raise a TypeError.
redH = 0.5*(redH + redH.T)
eigs = np.linalg.eigvalsh(redH) if redH.size else np.empty(0)
mineig = np.amin(eigs) if eigs.size else np.inf  # empty null space -> no curvature condition

if not strict_complementarity:
    print('strict complentarity does not hold.')
    print('The conditions for the the theorem second order optimality conditions are not fulfilled.')
elif np.size(Z) == 0 or mineig > tol:
    print('redH > 0. SOSC (and SONC) is fullfilled')
    print('The solution is a local minimizer.')
elif mineig >= -tol:
    print('redH >= 0. SONC is fullfilled')
    print('The solution might be a local minimizer.')
else:
    print('redH not PSD. Neither SONC nor SOSC hold.')
    print('The solution is not a local minimizer.')


plt.show()
