from typing import Tuple
from acados_template import AcadosOcp, AcadosOcpSolver, AcadosSimSolver
from cartpole_model import export_cartpole_ode_model
from utils import plot_cartpole, animate_cart_pole
import numpy as np
import scipy.linalg
from casadi import vertcat

def main(use_RTI=False):
    """Run a closed-loop MPC simulation of the cart-pole and report/plot it.

    The OCP solver and an integrator are built by ``setup``. At every
    simulation step the OCP is solved with the current simulated state as the
    initial state. The first control of the solution is applied, and the
    integrator advances the simulated ("real") system by one step. Afterwards
    solver timings are printed and the closed-loop trajectories are plotted
    and animated.

    Args:
        use_RTI: if True, use the real-time iteration variant with separate
            preparation and feedback phases. Otherwise ``solve_for_x0`` is
            called at every step.
    """
    # With use_RTI == False (the first MPC task) only the non-RTI path is
    # exercised; everything guarded by "if use_RTI:" belongs to the
    # real-time iteration variant.

    # set initial state
    x0 = np.array([0.0, np.pi, 0.0, 0.0])
    Fmax = 80

    # OCP horizon length
    Tf = .8
    # OCP horizon discretization (number of shooting nodes)
    N_horizon = 40

    # setup ocp solver and integrator
    ocp_solver, integrator = setup(x0, Fmax, N_horizon, Tf, use_RTI)

    nx = ocp_solver.acados_ocp.dims.nx
    nu = ocp_solver.acados_ocp.dims.nu

    # simulation steps for "real" system
    Nsim = 100
    # preallocate arrays for closed loop simulation
    simX = np.zeros((Nsim + 1, nx))
    simU = np.zeros((Nsim, nu))
    simX[0, :] = x0

    # preallocate timing arrays
    if use_RTI:
        t_preparation = np.zeros(Nsim)
        t_feedback = np.zeros(Nsim)
    else:
        t = np.zeros(Nsim)

    # do some initial iterations to start with a good initial guess
    # (presumably more are needed for RTI because each RTI solve performs
    # only a single SQP-type iteration)
    num_iter_initial = 5 if use_RTI else 1
    for _ in range(num_iter_initial):
        ocp_solver.solve_for_x0(x0_bar=x0)

    # closed loop
    for i in range(Nsim):

        # the OCP of this iteration starts at the current state of the
        # simulated system
        x0_bar = simX[i, :]
        if use_RTI:
            # preparation phase (runs before the new initial state is set)
            ocp_solver.options_set('rti_phase', 1)
            ocp_solver.solve()
            t_preparation[i] = ocp_solver.get_stats('time_tot')

            # set initial state
            ocp_solver.set(0, "lbx", x0_bar)
            ocp_solver.set(0, "ubx", x0_bar)

            # feedback phase
            ocp_solver.options_set('rti_phase', 2)
            status = ocp_solver.solve()
            t_feedback[i] = ocp_solver.get_stats('time_tot')
            if status != 0:
                # keep simulating, but do not hide solver failures
                print(f'Warning: acados returned status {status} in closed loop iteration {i}.')

            simU[i, :] = ocp_solver.get(0, "u")

        else:
            # solve ocp and get next control input; a failed solve does not
            # abort the closed-loop simulation
            simU[i, :] = ocp_solver.solve_for_x0(x0_bar=x0_bar, fail_on_nonzero_status=False)
            t[i] = ocp_solver.get_stats('time_tot')

        # simulate system forward one step by applying the MPC control input
        simX[i + 1, :] = integrator.simulate(x=simX[i, :], u=simU[i, :])

    # evaluate timings
    if use_RTI:
        # scale to milliseconds
        t_preparation *= 1000
        t_feedback *= 1000
        print('Computation time in preparation phase in ms: '
              f'min {np.min(t_preparation):.3f} median {np.median(t_preparation):.3f} max {np.max(t_preparation):.3f}')
        print('Computation time in feedback phase in ms:    '
              f'min {np.min(t_feedback):.3f} median {np.median(t_feedback):.3f} max {np.max(t_feedback):.3f}')
    else:
        # scale to milliseconds
        t *= 1000
        print(f'Computation time in ms: min {np.min(t):.3f} median {np.median(t):.3f} max {np.max(t):.3f}')

    # plot results
    model = ocp_solver.acados_ocp.model

    if use_RTI:
        plotfile = "trajectories_closed_loop_RTI.pdf"
        anifile = "cartpole_closed_loop_RTI.gif"
    else:
        plotfile = "trajectories_closed_loop.pdf"
        anifile = "cartpole_closed_loop.gif"
    # the time axis assumes one integrator step equals one OCP interval Tf/N_horizon
    plot_cartpole(np.linspace(0, (Tf / N_horizon) * Nsim, Nsim + 1), Fmax, simU, simX, latexify=False, time_label=model.t_label, x_labels=model.x_labels, u_labels=model.u_labels, saveas=plotfile)
    animate_cart_pole(simX, plt_show=True, saveas=anifile)

def setup(x0:np.ndarray, Fmax:float, N_horizon:int, Tf:float, RTI:bool=False) -> Tuple[AcadosOcpSolver, AcadosSimSolver]:
    """Formulate the cart-pole OCP and build the acados solver and integrator.

    Args:
        x0: initial state, passed to ``ocp.constraints.x0``.
        Fmax: symmetric bound on control component 0: -Fmax <= u[0] <= Fmax.
        N_horizon: number of shooting intervals of the OCP.
        Tf: length of the prediction horizon.
        RTI: if True the NLP solver type is ``SQP_RTI``. Otherwise it is
            ``SQP`` with at most 150 iterations.

    Returns:
        ``(ocp_solver, sim_solver)``: the OCP solver and an integrator, both
        created from the same OCP object and the same JSON file.
    """
    ocp = AcadosOcp()

    # dynamics
    model = export_cartpole_ode_model()
    ocp.model = model

    n_states = model.x.rows()
    n_controls = model.u.rows()

    # nonlinear least-squares cost; stage residual is [x; u], terminal
    # residual is x, both tracked towards zero
    cost = ocp.cost
    cost.cost_type = 'NONLINEAR_LS'
    cost.cost_type_e = 'NONLINEAR_LS'

    state_weight = 2 * np.diag([1e3, 1e3, 1e-2, 1e-2])
    control_weight = 2 * np.diag([1e-2])

    cost.W = scipy.linalg.block_diag(state_weight, control_weight)
    cost.W_e = state_weight

    ocp.model.cost_y_expr = vertcat(model.x, model.u)
    ocp.model.cost_y_expr_e = model.x
    cost.yref = np.zeros((n_states + n_controls, ))
    cost.yref_e = np.zeros((n_states, ))

    # box constraint on the first control and fixed initial state
    constraints = ocp.constraints
    constraints.lbu = np.array([-Fmax])
    constraints.ubu = np.array([+Fmax])
    constraints.x0 = x0
    constraints.idxbu = np.array([0])

    # horizon and solver configuration
    opts = ocp.solver_options
    opts.N_horizon = N_horizon
    opts.tf = Tf

    opts.qp_solver = 'PARTIAL_CONDENSING_HPIPM'  # alternative: FULL_CONDENSING_QPOASES
    opts.hessian_approx = 'GAUSS_NEWTON'
    opts.integrator_type = 'IRK'
    opts.sim_method_newton_iter = 10

    if not RTI:
        opts.nlp_solver_type = 'SQP'
        # globalization can be enabled via opts.globalization = 'MERIT_BACKTRACKING'
        opts.nlp_solver_max_iter = 150
    else:
        opts.nlp_solver_type = 'SQP_RTI'

    # partial-condensing horizon set to the full OCP horizon
    opts.qp_solver_cond_N = N_horizon

    json_path = 'acados_ocp_' + model.name + '.json'
    ocp_solver = AcadosOcpSolver(ocp, json_file = json_path)

    # integrator created from the same OCP description, so it shares the
    # model and the integrator settings configured above
    sim_solver = AcadosSimSolver(ocp, json_file = json_path)

    return ocp_solver, sim_solver


if __name__ == '__main__':
    # first the full-SQP closed loop, then the real-time-iteration variant
    for rti_flag in (False, True):
        main(use_RTI=rti_flag)
