import pylab as pl
import casadi as ca

from plot_glucose import plot_glucose

## -- General setup -- ##

# Symbolic state vector, component order: G, X, I, D

x = ca.SX.sym("x", 4)

# Symbolic control vector, component order: dDdt (change rate of the
# insulin infusion D), dist (nutrition disturbance)

u = ca.SX.sym("u", 2)

# Initial state values

G_init, X_init, I_init, D_init = 4.5, 15.0, 15.0, 0.0

x0 = ca.vertcat(G_init, X_init, I_init, D_init)

# Model parameters

P1, P2, P3 = 0.028735, 0.028344, 5.035e-5
V1 = 12.0
n = 5.0 / 54.0

##  -- Task 3 -- ##

## -->

# Insert the ODE from task 1.
#
# With x = [G, X, I, D] and u = [dDdt, dist] the right-hand side reads
#
#   dG/dt = -P1 * (G - G_init) - (X - X_init) * G + dist
#   dX/dt = -P2 * (X - X_init) + P3 * (I - I_init)
#   dI/dt = -n * I + D / V1
#   dD/dt = dDdt

dxdt = ca.vertcat(
    -P1 * (x[0] - G_init) - (x[1] - X_init) * x[0] + u[1],
    -P2 * (x[1] - X_init) + P3 * (x[2] - I_init),
    -n * x[2] + x[3] / V1,
    u[0],
)

## <--

## -->

# Also, insert the time points and the number of intervals from task 1:
# the horizon [t0, tf] is divided into N equidistant control intervals

t0 = 0.0
tf = 200.0
N = 250

## <--

# Time grid: N + 1 equidistant points spanning [t0, tf]

T = pl.linspace(t0, tf, N + 1)

# Length of a single control interval

dt = (tf - t0) / N


# Control data: the change rate of the insulin infusion is zero everywhere,
# and since D also starts at zero, no insulin is infused during the
# simulation

u_icr_init = pl.zeros(T[:-1].shape)

# "Nutrition disturbance": exponentially decaying input, one value per
# control interval

u_meal = 3.0 * pl.exp(-0.05 * T[:-1])

# Combine both control trajectories side by side; row k holds the controls
# applied on interval k

u_init = ca.horzcat(u_icr_init, u_meal)

# ODE description for the integrator: states x, right-hand side dxdt, and
# u passed as a "parameter". A parameter stays constant during one
# integration, so u can only change between two integration steps.

ode = {"x": x, "p": u, "ode": dxdt}

# CVODES integrator shipped with CasADi, integrating over one interval of
# length dt

cvodes_integrator = ca.integrator(
    "cvodes_integrator", "cvodes", ode, {"t0": 0.0, "tf": dt}
)

# Forward simulation: each interval starts from the state reached at the
# end of the preceding one

x_sim = [x0]

for k in range(N):
    x_sim.append(cvodes_integrator(x0=x_sim[-1], p=u_init[k, :])["xf"])

# Stack the N + 1 simulated states row-wise into an (N + 1) x 4 matrix

x_sim = ca.horzcat(*x_sim).T


# Generate the multiple shooting equality constraints, and structs that contain
# the optimization variables, as well as their bounds and initial values

# Bounds for u_icr (the insulin change rate is left unconstrained)

u_icr_min = -pl.inf
u_icr_max = pl.inf

# State bounds (order: G, X, I, D); only D has a finite upper bound

x_min = ca.DM([0.0, 0.0, 0.0, 5.0])
x_max = ca.DM([pl.inf, pl.inf, pl.inf, 20.0])

# Short names for the number of states and the number of controls

nx = x.numel()
nu = u.numel()

# Lists for the optimization variables, their bounds and their initial
# guesses; the first variable is s0, the state at t0

V = [ca.MX.sym("s0", nx)]

# Fix s0 to x0 via identical lower and upper bounds, so the optimizer
# cannot choose the initial state

V_min = [x0]
V_max = [x0]
V_init = [x0]

# Reference values that G and D are driven towards

G_ref = 5.0
D_ref = 13.0

# Initialize the objective with the first entry of eq. (12) for k = 0 (a
# constant term, since s0 is fixed); further entries are added to f while
# the shooting constraints are set up

f = (V[-1][0] - G_ref)**2 + (V[-1][3] - D_ref)**2

# Lists for the multiple shooting constraints and their bounds

g = []
g_min = []
g_max = []


for k in range(N):

    # Multiple shooting constraints. The solver expects constraints of the
    # form g_min <= g <= g_max, so the continuity condition
    #
    #     s(k+1) = r(s(k), q(k))
    #
    # is written as
    #
    #     0 <= s(k+1) - r(s(k), q(k)) <= 0,
    #
    # i.e. as an equality constraint via identical lower and upper bounds.

    # -->

    # The s(k+1) introduced in the previous iteration (s0 for k = 0) is the
    # current s(k)

    sk = V[-1]

    # Optimization variable for the current q(k)

    qk = ca.MX.sym("q" + str(k), nu)

    # Optimization variable for the current s(k+1)

    skplus1 = ca.MX.sym("s" + str(k + 1), nx)

    # Append the current multiple shooting constraint to g

    g.append(skplus1 - cvodes_integrator(x0=sk, p=qk)["xf"])

    # <--

    g_min.append(pl.zeros(nx))
    g_max.append(pl.zeros(nx))

    # Collect the introduced optimization variables, as well as their
    # bounds and initials

    # Collect controls: u_icr may vary within its bounds, whereas the meal
    # disturbance is pinned to its given value u_meal[k]

    # -->

    q_min_k = ca.DM([u_icr_min, u_meal[k]])
    q_max_k = ca.DM([u_icr_max, u_meal[k]])
    q_init_k = ca.DM([u_icr_init[k], u_meal[k]])

    # <--

    V.append(qk)
    V_min.append(q_min_k)
    V_max.append(q_max_k)
    V_init.append(q_init_k)

    # Collect states. The initial guess for s(k+1) is the simulated state at
    # the END of interval k, i.e. row k + 1 of x_sim (row k is the state at
    # the start of the interval and would shift the whole guess by one step).

    # -->

    s_min_k = x_min
    s_max_k = x_max
    s_init_k = x_sim[k + 1, :].T

    # <--

    V.append(skplus1)
    V_min.append(s_min_k)
    V_max.append(s_max_k)
    V_init.append(s_init_k)

    # Add the G and D entries of s(k+1) to the objective

    # -->

    f += (V[-1][0] - G_ref)**2 + (V[-1][3] - D_ref)**2

    # <--


# Vectorize collected variables

V = ca.veccat(*V)
V_min = ca.veccat(*V_min)
V_max = ca.veccat(*V_max)
V_init = ca.veccat(*V_init)

g = ca.veccat(*g)
g_min = ca.veccat(*g_min)
g_max = ca.veccat(*g_max)


# Assemble the NLP from the decision variables V, the objective f and the
# constraints g, create an IPOPT solver for it and solve it

nlp = {"x": V, "f": f, "g": g}

nlpsolver = ca.nlpsol("nlpsolver", "ipopt", nlp)

solution = nlpsolver(
    x0=V_init, lbx=V_min, ubx=V_max, lbg=g_min, ubg=g_max
)

# Take the optimal decision variables ("x") out of the solution. V is laid
# out as [s0, q0, s1, q1, ..., q(N-1), sN], so with a stride of nx + nu,
# state component j is found at offsets j, j + (nx + nu), ..., and the
# first control component (u_icr) at offsets nx, nx + (nx + nu), ...

V_opt = solution["x"]

G_opt = V_opt[0::nx + nu]
X_opt = V_opt[1::nx + nu]
I_opt = V_opt[2::nx + nu]
D_opt = V_opt[3::nx + nu]
x_opt = ca.horzcat(G_opt, X_opt, I_opt, D_opt)

u_icr_opt = V_opt[nx::nx + nu]
u_opt = ca.horzcat(u_icr_opt, u_meal)

# Plot the optimization results

plot_glucose(T, x_opt, u_opt)
