from acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
import numpy as np
import casadi as ca
from plot_helper import plot_cartpole, animate_cartpole

# Initial state, ordered [px, theta, vx, omega]: cart at the origin and at rest,
# pole angle theta = pi.
X0 = np.array([0.0, np.pi, 0.0, 0.0])

# Length of the prediction horizon in continuous time [seconds].
T_horizon = 2.0

# Number of discrete time steps (shooting intervals) covering that horizon.
N_horizon = 100

# Bound on the magnitude of the force that may be applied to the cart.
Fmax = 80


def export_cartpole_ode_model() -> AcadosModel:
    """Build the symbolic cart-pole ODE model used by acados.

    State ``x = [px, theta, vx, omega]`` (cart position, pole angle, cart
    velocity, pole angular velocity); control ``u = [F]`` (force on the cart).
    The model carries both the explicit dynamics ``xdot = f_expl(x, u)`` and
    the implicit residual ``xdot - f_expl(x, u)``.

    Returns:
        An ``AcadosModel`` named ``'cartpole_ode'`` with ``x``, ``xdot``,
        ``u``, ``f_expl_expr`` and ``f_impl_expr`` set.
    """
    # physical parameters
    cart_mass = 1.          # mass of the cart [kg]
    ball_mass = 0.1         # mass of the ball [kg]
    grav = 9.81             # gravity constant [m/s^2]
    pole_length = 0.8       # length of the pole [m]

    # differential states
    px = ca.SX.sym('px')
    theta = ca.SX.sym('theta')
    vx = ca.SX.sym('vx')
    omega = ca.SX.sym('omega')
    state = ca.vertcat(px, theta, vx, omega)

    # control input
    F = ca.SX.sym('F')
    control = ca.vertcat(F)

    # symbols for the state derivative; acados needs these for the implicit ODE form
    state_dot = ca.vertcat(
        *(ca.SX.sym(sym_name) for sym_name in ('px_dot', 'theta_dot', 'vx_dot', 'omega_dot'))
    )

    # shared trigonometric terms and the common denominator of both accelerations
    c = ca.cos(theta)
    s = ca.sin(theta)
    denom = cart_mass + ball_mass - ball_mass*c**2

    vx_ddot = (-ball_mass*pole_length*s*omega**2 + ball_mass*grav*c*s + F) / denom
    omega_ddot = (-ball_mass*pole_length*c*s*omega**2 + F*c + (cart_mass + ball_mass)*grav*s) / (pole_length*denom)

    # explicit ODE xdot = f_expl(x, u) and its implicit residual form
    f_expl = ca.vertcat(vx, omega, vx_ddot, omega_ddot)
    f_impl = state_dot - f_expl

    model = AcadosModel()
    model.f_impl_expr = f_impl
    model.f_expl_expr = f_expl
    model.x = state
    model.xdot = state_dot
    model.u = control
    model.name = 'cartpole_ode'
    return model

def create_ocp_solver_description() -> AcadosOcp:
    """Formulate the cart-pole optimal control problem for acados.

    The problem uses the cart-pole ODE model and a horizon of ``N_horizon``
    steps over ``T_horizon`` seconds. It has a nonlinear-least-squares cost
    tracking a zero reference, a box constraint ``|F| <= Fmax`` on the
    control, and the initial state fixed to ``X0``.

    Returns:
        The configured ``AcadosOcp`` (not yet turned into a solver).
    """
    ocp = AcadosOcp()

    model = export_cartpole_ode_model()
    ocp.model = model

    # number of discrete shooting intervals
    ocp.dims.N = N_horizon

    nx = model.x.shape[0]
    nu = model.u.shape[0]

    # Weights. The factor 2 offsets the 1/2 in the least-squares cost below.
    Q = 2 * np.diag([1e3, 1e3, 1e-2, 1e-2])
    R = 2 * np.diag([1e-1])

    # Cost in "nonlinear least-squares" form, so acados can use the
    # Gauss-Newton Hessian approximation:
    #   stage:    .5 * (y - y_ref)^T W (y - y_ref),  y = (x, u), W = blkdiag(Q, R)
    #   terminal: .5 * (y - y_ref)^T W_e (y - y_ref), y = x,     W_e = Q
    # with y_ref = 0 in both cases.
    cost = ocp.cost
    cost.cost_type = 'NONLINEAR_LS'
    cost.cost_type_e = 'NONLINEAR_LS'
    ocp.model.cost_y_expr = ca.vertcat(model.x, model.u)
    ocp.model.cost_y_expr_e = model.x
    cost.W = ca.diagcat(Q, R).full()
    cost.W_e = Q
    cost.yref = np.zeros((nx + nu,))
    cost.yref_e = np.zeros((nx,))

    # box constraint on the single control input, plus the fixed initial state
    cons = ocp.constraints
    cons.lbu = np.array([-Fmax])
    cons.ubu = np.array([+Fmax])
    cons.idxbu = np.array([0])
    cons.x0 = X0

    # Solver configuration. Other QP solvers: FULL_CONDENSING_QPOASES,
    # FULL_CONDENSING_HPIPM, PARTIAL_CONDENSING_QPDUNES, PARTIAL_CONDENSING_OSQP,
    # FULL_CONDENSING_DAQP. Hessian alternatives: 'EXACT'. NLP alternatives: SQP_RTI.
    opts = ocp.solver_options
    opts.qp_solver = 'PARTIAL_CONDENSING_HPIPM'
    opts.hessian_approx = 'GAUSS_NEWTON'
    opts.integrator_type = 'IRK'
    opts.nlp_solver_type = 'SQP'
    opts.nlp_solver_max_iter = 400

    # prediction horizon length in seconds
    opts.tf = T_horizon

    return ocp


def solve_single_ocp():
    """Solve the cart-pole OCP once from ``X0`` and visualize the optimum.

    Builds the problem with ``create_ocp_solver_description``, creates an
    acados solver (JSON description ``acados_ocp_<model name>.json``), solves
    it for the initial state ``X0`` and prints solver statistics. It then
    reads the optimal state and control trajectories, saves a plot to
    ``cartpole.pdf`` and saves/shows an animation as ``cartpole.gif``.
    """
    ocp = create_ocp_solver_description()
    acados_ocp_solver = AcadosOcpSolver(ocp, json_file=f'acados_ocp_{ocp.model.name}.json')

    # Solve with X0 as the initial state. solve_for_x0 returns u_0, the only
    # input an MPC loop would apply to the real system before re-solving.
    # Here the full trajectory is read back instead, so the return value is
    # not needed.
    acados_ocp_solver.solve_for_x0(X0)
    acados_ocp_solver.print_statistics()

    # Collect the solution stage by stage: N_horizon + 1 states (including
    # the terminal one) and N_horizon controls. Building the arrays from the
    # retrieved values avoids allocating uninitialized memory with the
    # low-level np.ndarray constructor.
    Xopt = np.array([acados_ocp_solver.get(i, "x") for i in range(N_horizon + 1)])
    Uopt = np.array([acados_ocp_solver.get(i, "u") for i in range(N_horizon)])

    # the time grid is uniform: N_horizon intervals over T_horizon seconds
    plot_cartpole(np.linspace(0, T_horizon, N_horizon + 1), Fmax, Uopt, Xopt, saveas='cartpole.pdf')
    animate_cartpole(Xopt, dt=T_horizon / N_horizon, saveas='cartpole.gif', plt_show=True)

if __name__ == "__main__":
    # Script entry point: solve the OCP once, then plot and animate the result.
    solve_single_ocp()
