import casadi as ca
import matplotlib.pyplot as plt
import numpy as np

from newton_type import newton_type


# If you cannot see an interactive plot (e.g. the backend is 'agg'), uncomment the following lines
# import matplotlib
# if matplotlib.get_backend() == 'agg':
#     matplotlib.use('WebAgg')
# print(f'backend: {matplotlib.get_backend()}')


def f(x, y):
    """Evaluate the objective, the Rosenbrock function (1 - x)^2 + 100 (y - x^2)^2.

    Only arithmetic operators are used, so the evaluation is elementwise when
    ``x`` and ``y`` are NumPy arrays (e.g. the meshgrid used for the contour
    plot). Both squared terms vanish only at (x, y) = (1, 1), which is
    therefore the minimizer with value 0.
    """
    dx = 1 - x        # distance from the minimizer in x
    dy = y - x**2     # deviation from the parabolic valley floor y = x^2
    return dx**2 + 100*dy**2
w0 = np.array([1, 1.1])  # solution guess
numit = 1000  # max num of iters

# Figure 1: contour lines of f on the square [0.8, 1.2]^2, with w0 marked.
plt.figure(1)
x = np.linspace(.8, 1.2, 100)
X, Y = np.meshgrid(x, x)
CS = plt.contour(X, Y, f(X, Y), levels=50)
plt.clabel(CS)
plt.scatter(w0[0], w0[1], marker='*', color='c', label=r'$w_0$')

plt.xlabel(r'$x$')
plt.ylabel(r'$y$')
# Same range on both axes, taken from the sampling grid.
plt.axis([x.min(), x.max(), x.min(), x.max()])
plt.legend()


# (b): implement the gradient and Hessian of f
# TODO: write your own functions


# (c): test with two different Hessian approximations.
# Each run's iterate path is drawn into figure 1, which is still the current
# figure at this point.
# TODO: fill in the correct parameters (replace "...")
# NOTE(review): the signature of newton_type is defined in newton_type.py,
# which is not shown here. Each result is indexed as W[0, :] / W[1, :], so it
# is expected to be a 2-row array holding one iterate per column.
# W_gd500 and W_eh are reused by the convergence plot (figure 2), so keep
# these names.
# i) gradient descent, rho = 100
# presumably rho scales a constant Hessian approximation (rho * I) -- TODO confirm
W_gd100 = newton_type(...)
plt.plot(W_gd100[0,:], W_gd100[1,:], 'rx-', label=r'GD, $\rho=100$' )

# i) gradient descent, rho = 500
W_gd500 = newton_type(...)
plt.plot( W_gd500[0,:], W_gd500[1,:], 'bo-', label=r'GD, $\rho=500$' )

# i) gradient descent, rho = 550
W_gd550 = newton_type(...)
plt.plot( W_gd550[0,:], W_gd550[1,:], 'ms-', label=r'GD, $\rho=550$' )

# ii) exact Hessian
W_eh = newton_type(...)
plt.plot( W_eh[0,:], W_eh[1,:], 'k^-', label=r'exact Hessian' )
# Refresh the legend so it includes all runs plotted above.
plt.legend()

# plot of convergence speed
plt.figure(2)
xopt = np.array([1, 1])[:, np.newaxis]  # true minimizer, as a (2, 1) column for broadcasting


def _plot_log_error(W, label):
    """Plot log10 of the infinity-norm error ||w_k - xopt||_inf over the iterations.

    W holds one iterate per column (row 0 = x, row 1 = y), as plotted in
    figure 1. The iteration axis is taken from W itself instead of assuming
    exactly numit + 1 iterates, so runs of any length plot correctly.
    """
    err = np.amax(np.abs(W - xopt), axis=0)
    # An iterate that lands exactly on xopt gives err == 0, and log10 -> -inf.
    # Silence the divide-by-zero warning; the non-finite point is simply not drawn.
    with np.errstate(divide='ignore'):
        log_err = np.log10(err)
    plt.plot(np.arange(W.shape[1]), log_err, label=label)


_plot_log_error(W_gd500, r'GD, $\rho=500$')
_plot_log_error(W_eh, r'exact hessian')
plt.xlabel(r'iteration $k$')
# The plotted values are base-10 logarithms; label them accordingly.
plt.ylabel(r'$\log_{10} \Vert w_k - w^* \Vert_\infty$')
plt.legend()


# (d): now use CasADi
# TODO: calculate the gradient and Hessian via CasADi


# exact Hessian, computed with CasADi
# TODO: fill in the correct parameters (replace "...")
# HINT: remember to flatten/squeeze the gradient for a proper shape
# NOTE(review): presumably newton_type expects plain 1-D NumPy arrays, so
# CasADi results likely need converting before use -- confirm.
W_cas = newton_type(...)

# Figure 2 (convergence plot) is current at this point; switch back to
# figure 1 so the CasADi path is overlaid on the contour plot and the
# hand-coded runs.
plt.figure(1)
plt.plot(W_cas[0,:], W_cas[1,:], 'gv-', label=r'exact Hessian, casadi')
plt.legend()

# Display the figures created above.
plt.show()
