import numpy as np

def J_FAD(u, m, x0, h):
    """Assemble a Jacobian with respect to ``u`` using forward-mode AD.

    A trajectory is first generated from ``x0`` with the recurrence
    ``x_next = x + h * ((1 - x) * x + u[k])``.  Then, for every entry of
    ``u``, a unit seed is pushed through the linearised recurrence
    ``t[k] = h * seed[k] + (1 + h * (1 - 2 * x[k])) * t[k-1]`` with
    ``t[0] = h * seed[0]``.  The last ``m`` entries of each resulting
    tangent vector become one column of the output.

    Parameters
    ----------
    u : sequence of numbers
        Input sequence; its length ``N`` fixes the number of steps.
    m : int
        Number of trailing tangent entries kept per column (rows of the
        result).
    x0 : number
        Initial state of the trajectory.
    h : number
        Step factor used in both the trajectory and tangent recurrences.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(m, N)``; column ``n`` holds the tangent obtained
        by seeding ``u[n]``.
    """
    n_steps = len(u)
    jac = np.zeros((m, n_steps))

    # Primal pass: store the trajectory the tangent recurrence is
    # linearised around.
    states = np.zeros(n_steps)
    states[0] = x0
    for k in range(1, n_steps):
        prev = states[k - 1]
        states[k] = prev + h * ((1 - prev) * prev + u[k - 1])

    # The multiplier applied to the previous tangent depends only on the
    # trajectory, not on which input is seeded, so evaluate it once.
    gains = 1 + h * (1 - 2 * states)

    # Index of the first tangent entry copied into the Jacobian.
    first_kept = n_steps - m

    # Tangent pass: one sweep per input component.
    for col in range(n_steps):
        seed = np.zeros(n_steps)
        seed[col] = 1

        tangent = np.zeros(n_steps)
        tangent[0] = h * seed[0]
        for k in range(1, n_steps):
            tangent[k] = h * seed[k] + gains[k] * tangent[k - 1]

        jac[:, col] = tangent[first_kept:]

    return jac
