import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# ============================================================
# Paramètres physiques
# ============================================================
R = 300e-6  # rayon du capillaire (m) — comme dans Chevalier

fluids = {
    'Eau': {
        'eta': 1e-3,       # viscosité dynamique (Pa.s)
        'rho': 1000,       # masse volumique (kg/m^3)
        'gamma': 72e-3,    # tension de surface (N/m)
        'theta': 0,        # angle de contact (rad)
        'tmax': 3.0,       # temps max de simulation (s)
    },
    'PDMS': {
        'eta': 48e-3,
        'rho': 960,
        'gamma': 20.8e-3,
        'theta': 0,
        'tmax': 45.0,
    },
}

g = 9.81  # pesanteur (m/s^2)

# ============================================================
# Fonctions
# ============================================================

def hauteur_jurin(gamma, theta, rho, R):
    """Hauteur d'équilibre de Jurin."""
    return 2 * gamma * np.cos(theta) / (rho * g * R)


def eq_generale(t, y, gamma, theta, rho, eta, R):
    """
    Équation générale (3) de Chevalier.
    y = [z, zdot]
    z'' = (1/z) * [ 2*gamma*cos(theta)/(rho*R) - 8*eta*z*zdot/(rho*R^2) - z*g - zdot^2 ]
    
    Attention à z -> 0 : on protège avec un z minimum.
    """
    z, zdot = y
    z = max(z, 1e-10)  # éviter division par zéro
    
    terme_cap = 2 * gamma * np.cos(theta) / (rho * R)
    terme_visq = 8 * eta * z * zdot / (rho * R**2)
    terme_grav = z * g
    terme_inertie_convective = zdot**2
    
    zddot = (1.0 / z) * (terme_cap - terme_visq - terme_grav - terme_inertie_convective)
    
    return [zdot, zddot]


def washburn(t, gamma, theta, eta, R):
    """
    Modèle de Washburn (temps courts, quasi-stationnaire, sans gravité).
    z(t) = sqrt(R * gamma * cos(theta) * t / (2 * eta))
    """
    return np.sqrt(R * gamma * np.cos(theta) * t / (2 * eta))


def relaxation_expo(t, gamma, theta, rho, eta, R):
    """
    Modèle temps longs quasi-stationnaire (relaxation exponentielle vers Jurin).
    
    Equation quasi-stationnaire complète (avec gravité) :
        (8*eta / R^2) * z * zdot = rho*g*(h_eq - z)
    
    On linéarise autour de h_eq : z = h_eq*(1 - epsilon)
    => epsilon(t) = epsilon_0 * exp(-t/tau)
    => z(t) = h_eq * (1 - epsilon_0 * exp(-t/tau))
    
    avec tau = 8*eta*h_eq / (rho*g*R^2)
    
    On choisit epsilon_0 = 1 (z(0) = 0) pour raccorder.
    Note : ce modèle n'est valide que pour t >> tau (quand epsilon << 1).
    """
    h_eq = hauteur_jurin(gamma, theta, rho, R)
    tau = 8 * eta * h_eq / (rho * g * R**2)
    return h_eq * (1 - np.exp(-t / tau))


# ============================================================
# Résolution et tracé
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for idx, (name, params) in enumerate(fluids.items()):
    eta = params['eta']
    rho = params['rho']
    gamma = params['gamma']
    theta = params['theta']
    tmax = params['tmax']
    
    h_eq = hauteur_jurin(gamma, theta, rho, R)
    tau = 8 * eta * h_eq / (rho * g * R**2)
    
    print(f"--- {name} ---")
    print(f"  Hauteur de Jurin : {h_eq*1e3:.2f} mm")
    print(f"  Tau (relaxation) : {tau:.4f} s")
    print(f"  Longueur capillaire : {np.sqrt(gamma/(rho*g))*1e3:.2f} mm")
    print()
    
    # --- Modèle général (résolution numérique) ---
    # Conditions initiales : z(0) très petit, zdot(0) depuis modèle inertiel
    z0 = 1e-6  # 1 micron pour éviter singularité
    zdot0 = np.sqrt(2 * gamma * np.cos(theta) / (rho * R))  # vitesse initiale inertielle
    
    t_span = (0, tmax)
    t_eval = np.linspace(0, tmax, 5000)
    
    sol = solve_ivp(
        eq_generale, t_span, [z0, zdot0],
        args=(gamma, theta, rho, eta, R),
        t_eval=t_eval,
        method='Radau',  # méthode implicite, robuste pour EDO raides
        max_step=tmax/5000,
        rtol=1e-8, atol=1e-12
    )
    
    t_gen = sol.t
    z_gen = sol.y[0] * 1e3  # conversion en mm
    
    # --- Modèle Washburn (temps courts) ---
    z_wash = washburn(t_eval, gamma, theta, eta, R) * 1e3  # mm
    
    # --- Modèle relaxation exponentielle (temps longs) ---
    z_relax = relaxation_expo(t_eval, gamma, theta, rho, eta, R) * 1e3  # mm
    
    # --- Tracé ---
    ax = axes[idx]
    
    ax.plot(t_gen, z_gen, 'k-', linewidth=2, label='Modèle général (eq. 3)')
    ax.plot(t_eval, z_wash, 'b--', linewidth=1.5, label=r'Washburn : $h \propto \sqrt{t}$')
    ax.plot(t_eval, z_relax, 'r-.', linewidth=1.5, label=r'Relaxation expo : $h_{eq}(1-e^{-t/\tau})$')
    
    # Hauteur de Jurin
    ax.axhline(y=h_eq*1e3, color='gray', linestyle=':', linewidth=1, label=f'Jurin = {h_eq*1e3:.1f} mm')
    
    # Limiter l'axe y pour la lisibilité
    ymax = min(h_eq * 1e3 * 1.5, max(z_wash[-1], h_eq * 1e3 * 1.3))
    ax.set_ylim(0, ymax)
    
    ax.set_xlabel('Temps (s)', fontsize=12)
    ax.set_ylabel('Hauteur h (mm)', fontsize=12)
    ax.set_title(f'{name}', fontsize=14, fontweight='bold')
    ax.legend(fontsize=9, loc='best')
    ax.grid(True, alpha=0.3)
    
    # Annotations
    ax.text(0.95, 0.05, 
            f'η = {eta*1e3:.0f} mPa·s\n'
            f'γ = {gamma*1e3:.1f} mN/m\n'
            f'R = {R*1e6:.0f} μm\n'
            f'τ = {tau:.3f} s',
            transform=ax.transAxes, fontsize=9,
            verticalalignment='bottom', horizontalalignment='right',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.suptitle("Dynamique de l'ascension capillaire", fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
#plt.savefig('/mnt/user-data/outputs/dynamique_capillaire.png', dpi=150, bbox_inches='tight')
#plt.savefig('/mnt/user-data/outputs/dynamique_capillaire.pdf', bbox_inches='tight')
#print("Figures sauvegardées.")
plt.show()