import numpy as np


def ode(x, u):
    """Evaluate the continuous-time dynamics x' = f(x, u).

    The system is second order, written in first-order form:
        x[0]' = x[1]
        x[1]' = sin(x[0]) + u[0]

    Args:
        x: state; indexed as x[0], x[1] (a (2, 1) column vector is presumably
           the intended layout so the result matches x's shape in RK4 — confirm).
        u: input; only u[0] is read.

    Returns:
        The time derivative, stacked vertically via np.vstack (shape (2, 1)
        for scalar / length-1 entries).
    """
    first, second = x[0], x[1]
    return np.vstack((second, np.sin(first) + u[0]))


def rk4_step(x, u, h, ode_fun):
    """Advance the state by one classical 4th-order Runge-Kutta step.

    Args:
        x: current state (any array shape that ``ode_fun`` accepts and returns).
        u: input, held constant over the step.
        h: step length.
        ode_fun: callable ``ode_fun(x, u)`` returning the state derivative.

    Returns:
        The state after one step of length ``h``.
    """
    k1 = ode_fun(x         , u)
    k2 = ode_fun(x + h/2*k1, u)
    k3 = ode_fun(x + h/2*k2, u)
    k4 = ode_fun(x +   h*k3, u)
    return x + h/6*(k1 + 2*k2 + 2*k3 + k4)


def RK4_integrator(x0:np.ndarray, u0:np.ndarray, Ts:float, nSteps:int, ode_fun, compute_sensitivities:bool=False):
    """Integrate ``ode_fun`` over ``Ts`` using ``nSteps`` equal RK4 steps.

    The input ``u0`` is held constant over the whole interval.

    Args:
        x0: initial state. Presumably an (nx, 1) column vector whose shape
            matches what ``ode_fun`` returns — TODO confirm with callers.
        u0: input; a scalar is promoted to a length-1 array.
        Ts: total integration time.
        nSteps: number of RK4 steps (step length ``Ts / nSteps``).
        ode_fun: callable ``ode_fun(x, u)`` returning the state derivative.
        compute_sensitivities: if True, also return the Jacobians of the final
            state with respect to ``x0`` and ``u0``.

    Returns:
        ``(x_end, A, B)`` where ``x_end`` is the state after ``Ts``.
        ``A`` is d x_end / d x0 with shape (nx, nx), and ``B`` is
        d x_end / d u0 with shape (nx, nu). Both are None when
        ``compute_sensitivities`` is False.

    Note:
        Sensitivities use complex-step differentiation. ``ode_fun`` must
        therefore accept complex inputs and be complex-analytic: no abs(),
        np.real(), comparisons, etc. on the state/input.
    """
    h = Ts / nSteps
    # Promote a scalar input in BOTH modes so ode functions that index u[0]
    # work, and both modes see the same u.
    u0 = np.atleast_1d(u0)
    x = np.asarray(x0)

    if not compute_sensitivities:
        for _ in range(nSteps):
            x = rk4_step(x, u0, h, ode_fun)
        return (x, None, None)

    # Complex-step derivative: d f / d z_k = Im f(z + i*p*e_k) / p. It has no
    # subtractive cancellation, so p can be tiny.
    perturb = 1e-100
    nx = len(x)
    nu = len(u0)
    A = np.eye(nx)
    B = np.zeros((nx, nu))
    for _ in range(nSteps):
        # Jacobians of this single step, evaluated at the CURRENT state x_k.
        sens_x = np.zeros((nx, nx))
        for k in range(nx):
            x_c = x.astype(complex)
            x_c[k] += perturb * 1j
            sens_x[:, k] = (np.imag(rk4_step(x_c, u0, h, ode_fun)) / perturb).flatten()
        sens_u = np.zeros((nx, nu))
        for k in range(nu):
            u_c = u0.astype(complex)
            u_c[k] += perturb * 1j
            sens_u[:, k] = (np.imag(rk4_step(x, u_c, h, ode_fun)) / perturb).flatten()
        # Chain rule: x_{k+1} = phi(x_k, u), so
        # dx_{k+1}/dx0 = Phi_x @ dx_k/dx0 and dx_{k+1}/du = Phi_x @ dx_k/du + Phi_u.
        A = sens_x @ A
        B = sens_x @ B + sens_u
        # Advance the nominal trajectory so the next step's Jacobians are
        # taken at x_{k+1}, not at x0.
        x = rk4_step(x, u0, h, ode_fun)
    return (x, A, B)


def project(value, grid:np.ndarray):
    '''Project a value to the index of its closest entry in a grid.

    The index is computed arithmetically from the grid's min, max and size.
    The grid must therefore be sorted ascending AND evenly spaced; for a
    non-uniform grid the result is not the nearest entry.

    Args:
        value: scalar to project.
        grid: evenly spaced, ascending 1D array.

    Returns:
        The nearest index as a NumPy integer. Exact halfway cases follow
        np.around, i.e. round half to even. Returns None when the value
        falls outside the grid by more than half a spacing, or is NaN.
        For a degenerate grid (single point / all entries equal) only an
        exact match maps to index 0.
    '''
    N = grid.size
    min_val = np.amin(grid)
    max_val = np.amax(grid)
    span = max_val - min_val
    if span == 0:
        # Zero spacing: avoid 0/0; only an exact hit lies "on" the grid.
        position = 0.0 if value == min_val else np.nan
    else:
        position = (value - min_val) / span * (N - 1)
    index = np.around(position)
    # Range-check the float BEFORE casting: casting NaN/inf to int is
    # undefined and platform-dependent. NaN fails every comparison -> None.
    if 0 <= index < N:
        return index.astype(int)
    return None
