import numpy as np
import matplotlib.pyplot as plt

from acados_template import latexify_plot

# Module-level side effect: runs acados' plot-styling helper once at import
# time, before any figure in this file is created.
# NOTE(review): presumably this updates global matplotlib rcParams (LaTeX
# text/fonts), which would affect every later figure in the process — confirm
# in acados_template.
latexify_plot()

def plot_trajectories(s_trajs, a_trajs, labels):
    """Plot state and action trajectories in a three-panel figure.

    The top two panels show rows 0 and 1 of each state trajectory (labelled
    theta_k and omega_k). The bottom panel shows each action trajectory as a
    step plot, with each value held until the next step (``where='post'``).

    Args:
        s_trajs: A single state trajectory or a list/tuple of them. Each is
            indexed as ``x_traj[row, k]``; only rows 0 and 1 are plotted.
            Trajectories may have different lengths.
        a_trajs: A single action trajectory or a list/tuple of them. Each is
            flattened before plotting, so one scalar action per step is
            assumed (shape ``(N,)``, ``(1, N)`` or ``(N, 1)``).
        labels: A legend label, or a sequence with one label per trajectory.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    # A single trajectory may be passed directly. Normalise everything to
    # sequences so the loops below handle both cases. Tuples count as
    # sequences too, so they are not mistaken for one trajectory.
    if not isinstance(s_trajs, (list, tuple)):
        s_trajs = [s_trajs]
        a_trajs = [a_trajs]
        labels = [labels]

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    def color_for(i):
        # Cycle the palette so more trajectories than colours cannot raise
        # IndexError.
        return colors[i % len(colors)]

    fig, axes = plt.subplots(3, 1, figsize=(5, 8))

    # The two state panels differ only in the plotted row and the y-label.
    state_panels = ((0, r'$\theta_k$'), (1, r'$\omega_k$'))
    for ax, (row, ylabel) in zip(axes[:2], state_panels):
        for i, x_traj in enumerate(s_trajs):
            values = np.ravel(x_traj[row, :])
            # Each trajectory gets its own step axis, so different horizons
            # can share a panel.
            ax.plot(np.arange(values.size), values, '-', alpha=0.7,
                    color=color_for(i), label=labels[i])
        ax.set_ylabel(ylabel)
        # Explicit True: a bare grid() toggles, which would hide the grid
        # if rcParams already enable it.
        ax.grid(True)
        ax.legend()

    ax_a = axes[2]
    for i, u_traj in enumerate(a_trajs):
        actions = np.ravel(u_traj)
        ax_a.step(np.arange(actions.size), actions, alpha=0.7,
                  color=color_for(i), label=labels[i], where='post')
    ax_a.grid(True)
    ax_a.set_xlabel(r'time step $k$')
    ax_a.set_ylabel(r'$a_k$')
    ax_a.legend()

    return fig
