import matplotlib.pyplot as plt
import numpy as np

from LQR_design import LQR_design
from utils import RK4_integrator, ode, project
from plot_utils import plot_dynamic_programming


# time grid and integrator settings
Ts = 0.1        # length of one interval of the time discretization grid
nSteps = 1      # integration steps taken within each interval of length Ts

# weighting matrices of the quadratic cost
Q = np.array([[100.0, 0.0],
              [0.0, 0.01]])     # state weights (diagonal)
R = np.array([[0.001]])         # control weight


# box bounds of the state and control space
x1_min = -0.5*np.pi
x1_max = 2.0*np.pi
x2_max = 10     # x2 is bounded symmetrically: -x2_max <= x2 <= x2_max
u_max = 10      # u is bounded symmetrically:  -u_max  <= u  <= u_max

# horizon length and grid resolutions
N = 20          # control horizon
N_x1 = 200      # grid points along state x1
N_x2 = 40       # grid points along state x2
N_u = 20        # grid points along control u


# equidistant grids covering the state and control space
x1_values = np.linspace(start=x1_min, stop=x1_max, num=N_x1)
x2_values = np.linspace(start=-x2_max, stop=x2_max, num=N_x2)
u_values = np.linspace(start=-u_max, stop=u_max, num=N_u)


# design LQR controller
# TODO: complete LQR_design.py
K, P = LQR_design(Q, R, Ts, nSteps)
# LQR cost-to-go and LQR control evaluated at every point of the state grid
LQR_cost = np.zeros((N_x1, N_x2))
LQR_u = np.zeros((N_x1, N_x2))
for (i, j) in np.ndindex((N_x1, N_x2)):
    # state as a column vector
    x = np.vstack((x1_values[i], x2_values[j]))
    # quadratic cost-to-go x' P x; .item() turns the size-1 result into a
    # plain scalar so the element assignment is unambiguous
    LQR_cost[i, j] = (x.T @ P @ x).item()
    # linear state feedback
    # NOTE(review): assumes LQR_design returns K for the law u = -K x
    # (standard convention) -- confirm against LQR_design.py
    LQR_u[i, j] = (-K @ x).item()
# clip the controls at the lower and upper bound
LQR_u = np.clip(LQR_u, -u_max, u_max)


# the DP operator
def _nearest_grid_index(value, grid):
    '''index of the grid point closest to value, or None if value is off the grid

    value: scalar or size-1 array
    grid:  1-D array sorted in ascending order (e.g. from np.linspace)

    A value outside [grid[0], grid[-1]] (or NaN) yields None.
    NOTE(review): utils.project presumably serves the same purpose; its
    signature is not visible here, so a local helper is used instead.
    '''
    value = np.asarray(value).item()
    # the chained comparison is also False for NaN, so a diverged
    # integration result is rejected like any other off-grid point
    if not (grid[0] <= value <= grid[-1]):
        return None
    # first index with grid[upper] >= value; always a valid index here
    # because value <= grid[-1]
    upper = int(np.searchsorted(grid, value))
    if upper == 0:
        return 0
    lower = upper - 1
    # pick whichever neighbour is closer (ties go to the lower one)
    if value - grid[lower] <= grid[upper] - value:
        return lower
    return upper


def DP_operator(J_map):
    '''dynamic programming operator, i.e. J[k] = DP_operator(J[k+1])

    For every grid state x and every gridded control u, the successor
    state is simulated with RK4_integrator and snapped to the nearest grid
    point. Controls that lead off the grid are discarded. Among the
    remaining ones the control minimizing

        x' Q x + u' R u + J_map[successor]

    is selected.

    J_map: 2-D array of shape (x1_values.size, x2_values.size) with the
           cost-to-go of the next stage

    Returns (new_J_map, new_u_map), both of the same shape as J_map.
    States without any admissible control keep cost inf and control NaN.

    Raises ValueError if J_map is not 2-D or does not match the grid.
    '''
    J_map = np.asarray(J_map)
    if np.ndim(J_map) != 2:
        raise ValueError("J_map must be rectangular")
    grid_shape = (x1_values.size, x2_values.size)
    if J_map.shape != grid_shape:
        raise ValueError(f"J_map must have shape {grid_shape}")
    # initialize cost-to-go as infinity (float dtype so inf is representable
    # even if an integer-valued map was passed in)
    new_J_map = np.full_like(J_map, np.inf, dtype=float)
    # initialize optimal control as NaN (Not A Number)
    new_u_map = np.full_like(J_map, np.nan, dtype=float)

    # loop through all state combinations
    for i1, i2 in np.ndindex(grid_shape):
        x = np.vstack((x1_values[i1], x2_values[i2]))
        # the state part of the stage cost does not depend on the control
        state_cost = (x.T @ Q @ x).item()
        # loop through all control
        for u in u_values:
            if np.ndim(u) == 0:
                u = np.array([u])
            # apply control
            next_x, _, _ = RK4_integrator(x, u, Ts, nSteps, ode)

            # project on discretization grid
            next_x = np.asarray(next_x).ravel()
            next_i1 = _nearest_grid_index(next_x[0], x1_values)
            next_i2 = _nearest_grid_index(next_x[1], x2_values)

            # if not on the grid, skip this control
            if (next_i1 is None) or (next_i2 is None):
                continue
            # else (on the grid)
            # stage cost plus cost-to-go of the (projected) successor state
            # NOTE(review): stage cost x'Qx + u'Ru is assumed to be the one
            # LQR_design uses for the terminal cost -- confirm
            candidate_J = (
                state_cost + (u @ R @ u).item() + J_map[next_i1, next_i2]
            )

            # if the cost-to-go is better
            if candidate_J < new_J_map[i1, i2]:
                # update the optimal control
                new_J_map[i1, i2] = candidate_J
                new_u_map[i1, i2] = u.item()
    return (new_J_map, new_u_map)


# design via dynamic programming:
# J_maps stores N+1 cost-to-go maps, u_maps the N control maps
J_maps = np.empty((N+1, N_x1, N_x2))
u_maps = np.empty((N, N_x1, N_x2))
# the LQR cost serves as terminal cost of the OCP (last of the N+1 maps)
J_maps[-1] = LQR_cost
# backward recursion from stage N-1 down to stage 0
for k in range(N - 1, -1, -1):
    J_maps[k], u_maps[k] = DP_operator(J_maps[k+1])


# collect the variables to be stored, then write them to file
results = dict(
    Ts=Ts, nSteps=nSteps,
    Q=Q, R=R, K=K,
    x1_values=x1_values, x2_values=x2_values,
    u_max=u_max, u_map=u_maps[0],
)
np.savez('data.npz', **results)

# show the dynamic plots
plot = plot_dynamic_programming(
    x1_values,
    x2_values,
    u_values,
    LQR_cost,
    LQR_u,
    J_maps,
    u_maps,
)
plt.show()
