import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from exo_skryer.rate_jax import load_nasa9_cache
from exo_skryer.vert_Tp import Modified_Milne
from exo_skryer.vert_chem import quench_approx
from exo_skryer.data_constants import bar

# Locate the "NASA9" data directory by walking upwards from this script's
# directory. When __file__ is undefined (e.g. the code is pasted into an
# interactive session or notebook) the current working directory is used
# as the starting point instead.
try:
    root = Path(__file__).resolve().parent
except NameError:
    root = Path.cwd().resolve()
_search_start = root
# Check the starting directory and up to four of its ancestors.
for _ in range(5):
    if (root / "NASA9").is_dir():
        break
    root = root.parent
else:
    # Without this guard the search would fall through silently and hand a
    # non-existent path (under an ancestor that was never checked) to the
    # loader, producing a much less obvious failure later on.
    raise FileNotFoundError(
        f"Could not find a 'NASA9' directory in {_search_start} "
        "or any of its 4 nearest parent directories"
    )
# Pass the located directory (as a plain string) to the project loader.
load_nasa9_cache(str(root / "NASA9"))

# Pressure level grid: `nlev` points spaced evenly in log10(pressure), running
# from 10**p_bot = 1000 down to 10**p_top = 1e-8, then scaled by the imported
# `bar` constant (presumably a bar -> internal-unit conversion factor;
# confirm against exo_skryer.data_constants).
nlev = 100
p_bot, p_top = np.log10(1000.0), np.log10(1e-8)
p_lev = np.logspace(start=p_bot, stop=p_top, num=nlev) * bar

# Parameter set handed to Modified_Milne. The meaning and units of each key
# are defined by exo_skryer.vert_Tp and are not restated here.
params_tp = {
    "T_int": 1200.0,
    "T_ratio": 0.333,
    "log_10_g": 4.5,
    "log_10_k_ir": -2.0,
    "log_10_p_t": 0.0,
    "beta": 0.55,
}

# Modified_Milne returns a pair, unpacked here as level and layer temperatures.
T_lev, T_lay = Modified_Milne(p_lev, params_tp)

# Layer pressures as the logarithmic mean of each pair of adjacent levels:
# delta(p) / delta(ln p). The result has one element fewer than p_lev.
p_lay = np.diff(p_lev) / np.log(p_lev[1:] / p_lev[:-1])

# Inputs for the quench-approximation chemistry call. Key semantics are
# defined by exo_skryer.vert_chem and are not restated here.
params = {
    "M_to_H": 0.0,
    "C_to_O": 0.55,
    "log_10_Kzz": 8.0,
    "log_10_g": 4.5,
}

# The last argument is the number of layers, i.e. one fewer than the number
# of levels. The result is used below as a container keyed by species name.
vmr_lay = quench_approx(p_lay, T_lay, params, nlev - 1)

# Plot the volume mixing ratio of each available species against layer
# pressure on log-log axes, with the pressure axis inverted.
fig, ax = plt.subplots(figsize=(10, 5))

# Layer pressures divided by the `bar` constant, matching the axis label.
p_lay_bar = p_lay / bar

for species in ("H2O", "CO", "CH4", "NH3", "HCN", "CO2"):
    # Skip species the chemistry result does not provide.
    if species not in vmr_lay:
        continue
    ax.semilogy(vmr_lay[species], p_lay_bar, label=species)

# semilogy only makes the y axis logarithmic; do the same for x.
ax.set_xscale('log')
ax.invert_yaxis()

ax.set_title("Quench Approx Chemistry", fontsize=14)
ax.set_xlabel("VMR", fontsize=16)
ax.set_ylabel("pressure [bar]", fontsize=16)
ax.legend(fontsize=10)

plt.tight_layout()