import timeit

import casadi as ca
import matplotlib.pyplot as plt
import numpy as np

from forw_AD import forw_AD
from back_AD import back_AD
from J_FAD import J_FAD
from J_BAD import J_BAD

# If no interactive plot window appears (e.g. the active backend is the
# non-interactive 'agg'), uncomment the following lines to switch to the
# browser-based WebAgg backend.
# import matplotlib
# if matplotlib.get_backend() == 'agg':
#     matplotlib.use('WebAgg')
# print(f'backend: {matplotlib.get_backend()}')

# construct the dynamics and the state trajectory Phi(U)
x0 = 0.5  # initial state
T = 5  # time horizon
N = 100  # number of integration steps
h = T / N  # step size

x = ca.SX.sym('x', 1)
u = ca.SX.sym('u', 1)
U = ca.SX.sym('U', N)

# One explicit Euler step of x' = (1 - x) * x + u with step size h.
f = ca.Function('f', [x, u], [x + h * ((1 - x) * x + u)])

# Roll out N steps starting from x0; Phi_expr[k] is the state after k+1 steps
# (x0 itself is not included). The states are collected in a list and
# concatenated once: calling vertcat on the growing vector every step would
# rebuild it each iteration (quadratic in N).
states = []
xk = ca.SX(x0)
for i in range(N):
    xk = f(xk, U[i])
    states.append(xk)
Phi_expr = ca.vertcat(*states)

Phi = ca.Function('Phi', [U], [Phi_expr])
J = ca.Function('J', [U], [ca.jacobian(Phi_expr, U)])

## tests
# Compare the derivatives computed by the hand-coded AD routines with the
# Jacobian obtained from CasADi. Agreement shows up as a mismatch near
# machine precision.


def _max_abs_mismatch(ref, approx):
    """Return the largest absolute entrywise difference between two arrays.

    ``approx`` is reshaped to ``ref``'s shape when the element counts agree,
    so that e.g. a column returned as (N, 1) is compared entrywise with an
    (N,) slice instead of broadcasting into an (N, N) difference matrix.

    Raises:
        ValueError: if ``ref`` and ``approx`` hold different numbers of
            elements.
    """
    ref = np.asarray(ref, dtype=float)
    approx = np.asarray(approx, dtype=float)
    if approx.shape != ref.shape:
        if approx.size != ref.size:
            raise ValueError(
                f"shape mismatch: reference {ref.shape} vs computed {approx.shape}"
            )
        approx = approx.reshape(ref.shape)
    # abs() matters: the max of the *signed* difference would report an
    # error where the computed value is too large as a negative, "tiny" mismatch.
    return float(np.max(np.abs(ref - approx)))


rng = np.random.default_rng(0)  # fixed seed -> reproducible test input
utest = rng.random((N, 1))
Jref = J(utest).full()

m = 0
print("Mismatch between hand-coded forward AD and CasADi (column %d):" % m)
print(f"\t{_max_abs_mismatch(Jref[:, m], forw_AD(utest, m, x0, h)):e}\n")

m = N - 1
print("Mismatch between hand-coded backward AD and CasADi (row %d):" % m)
print(f"\t{_max_abs_mismatch(Jref[m, :], back_AD(utest, m, x0, h)):e}\n")

# NOTE(review): m is still N-1 from the row test above, so the slices below
# select the last m = N-1 rows (rows 1..N-1), not all N rows, and m is also
# forwarded to J_FAD / J_BAD. Confirm this matches those functions' contract.
print("Mismatch between hand-coded forward AD and CasADi (full jacobian)")
print(f"\t{_max_abs_mismatch(Jref[N - m:None, :], J_FAD(utest, m, x0, h)):e}\n")

print("Mismatch between hand-coded backward AD and CasADi (full jacobian)")
print(f"\t{_max_abs_mismatch(Jref[N - m:None, :], J_BAD(utest, m, x0, h)):e}\n")

## timing
# TODO: implement timing experiments
