import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, CheckButtons

# --- Reference signal ---
# Frequency of the sine wave, in Hz.
fs = 5.0
# Length of the displayed time window, in seconds.
T_total = 2.0

# Dense time grid (10000 points) standing in for the continuous-time signal.
t_cont = np.linspace(0.0, T_total, num=10000)
s_cont = np.sin(2 * np.pi * fs * t_cont)

# --- Figure and visibility toggles ---
fig, ax = plt.subplots(figsize=(12, 5))
# Leave room under the main axes for the slider.
fig.subplots_adjust(bottom=0.25)
ax_check = fig.add_axes([0.02, 0.7, 0.07, 0.2])
checks = CheckButtons(
    ax_check,
    labels=['Signal', 'Points', 'Reliés'],
    actives=[True, False, False],
)

# Kept as the list returned by plot(); the artist itself is original[0].
original = ax.plot(
    t_cont,
    s_cont,
    'b-',
    alpha=0.25,
    linewidth=1.5,
    label=f'Signal original ({fs} Hz)',
)

# --- Sampled points ---
# Initial sampling frequency, expressed as a multiple of the signal frequency.
fe_fs_init = 10
fe_init = fe_fs_init * fs
t_ech = np.arange(0, T_total, 1 / fe_init)
s_ech = np.sin(2 * np.pi * fs * t_ech)

# The same samples are drawn twice: once joined by segments, once as markers.
(line,) = ax.plot(
    t_ech, s_ech, 'r-', markersize=6, linewidth=1.5, label='Échantillons'
)
(dots,) = ax.plot(
    t_ech, s_ech, 'ro', markersize=6, linewidth=1.5, label=''
)

# Axes cosmetics.
ax.set(xlim=(0, T_total), ylim=(-1.6, 1.6))
ax.set_xlabel('Temps (s)', fontsize=12)
ax.set_ylabel('Amplitude', fontsize=12)
ax.legend(loc='upper right', fontsize=10)
ax.grid(True, alpha=0.3)


# --- Slider: ratio of sampling frequency to signal frequency ---
ax_slider = fig.add_axes([0.15, 0.08, 0.7, 0.04])
slider = Slider(
    ax_slider,
    '$f_e/f_s$ (-)',
    0.0,
    10.0,
    valinit=fe_fs_init,
    valstep=0.05,
    color='steelblue',
)

# Text box (in axes coordinates) whose content is filled in by update_all().
fe_text = ax.text(
    0.15,
    0.95,
    '',
    transform=ax.transAxes,
    fontsize=11,
    ha='right',
    va='top',
    bbox={'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.8},
)

def update_all(val=None):
    """Resample the sine at the slider-selected rate and refresh the plot.

    Serves as the callback for both the slider and the check buttons, and
    can also be called directly with no argument.

    Parameters
    ----------
    val : optional
        Value handed over by the widget that fired the callback. It is
        ignored: the current state is always re-read from ``slider`` and
        ``checks``.
    """
    fe = slider.val * fs  # sampling frequency in Hz

    if fe > 0:
        # np.arange excludes its stop value; the 1e-10 padding lets a sample
        # that falls on t == T_total stay inside the window.
        sample_t = np.arange(0, T_total + 1e-10, 1 / fe)
        sample_y = np.sin(2 * np.pi * fs * sample_t)
    else:
        # The slider goes down to 0, where the sampling period 1/fe is
        # undefined. Show no samples instead of dividing by zero.
        sample_t = sample_y = np.empty(0)

    # Check-button order: [Signal, Points, Reliés]
    show_signal, show_dots, show_line = checks.get_status()

    # The reference curve is hidden by making it fully transparent.
    original[0].set_alpha(0.8 if show_signal else 0.0)

    if show_dots:
        dots.set_data(sample_t, sample_y)
    else:
        dots.set_data([], [])

    if show_line:
        line.set_data(sample_t, sample_y)
    else:
        line.set_data([], [])

    fe_text.set_text(f'$f_e$ = {fe:.1f} Hz')

    fig.canvas.draw_idle()

# Both widgets share the same refresh routine.
for _connect in (slider.on_changed, checks.on_clicked):
    _connect(update_all)

# Draw once so the plot matches the initial widget state.
update_all()

plt.show()