# %%
# STEP 1：ライブラリ・ファイル名・保存先を設定する
# ============================================================
# Pre1 STEP 1
# ライブラリ・ファイル名・保存先を設定する
# ============================================================

from pathlib import Path

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

# Folder that holds the NetCDF file.
# Path(".") is sufficient when the file sits in the working directory.
base_dir = Path(".")

# Input file and figure output folder, both located under base_dir.
ncfile = base_dir.joinpath("GlobColour_log10CHL1_clim12_100km_AV_199709_202603.nc")
figdir = base_dir.joinpath("figures_14p5_chl_pre1")
figdir.mkdir(parents=True, exist_ok=True)

# Bounds of the Pacific analysis domain.
lat_min = -60
lat_max = 60
lon_min = 120   # longitudes follow the 0–360 convention
lon_max = 280   # 280E = 80W

# Echo the resolved settings so a wrong path is easy to spot.
print("Input file:", ncfile)
print("Exists:", ncfile.exists())
print("Figure directory:", figdir)

# Abort here when the input file is missing.
if not ncfile.exists():
    raise FileNotFoundError(f"ファイルが見つかりません: {ncfile}")


def save_current_figure(filename):
    """Write the current matplotlib figure into ``figdir`` as a PNG.

    Parameters
    ----------
    filename : str
        File name (including the ``.png`` suffix) to create inside ``figdir``.

    The destination path is printed after the file has been written.
    """
    target = figdir / filename
    savefig_options = {"dpi": 300, "bbox_inches": "tight", "facecolor": "white"}
    plt.savefig(target, **savefig_options)
    print("Saved:", target)


# %%
# STEP 2：NetCDFを読み、太平洋域を切り出す
# ============================================================
# Pre1 STEP 2
# NetCDFを読み、太平洋域を切り出す
# ============================================================

# Open the dataset and print its summary.
ds = xr.open_dataset(ncfile)
print(ds)

# Pick the variable to analyse.
chl = ds["log10_chl_clim"]

print("Original dims :", chl.dims)
print("Original shape:", chl.shape)
print("Original lat range:", float(chl["lat"].min()), float(chl["lat"].max()))
print("Original lon range:", float(chl["lon"].min()), float(chl["lon"].max()))

# Normalise the grid in one chain:
#   1. latitude ascending,
#   2. longitude wrapped into the 0–360 convention,
#   3. longitude ascending (needed for the slice-based selection below).
chl = (
    chl.sortby("lat")
    .assign_coords(lon=(chl["lon"] + 360) % 360)
    .sortby("lon")
)

# Cut out the Pacific box defined in STEP 1.
chl_pac = chl.sel(
    {
        "lat": slice(lat_min, lat_max),
        "lon": slice(lon_min, lon_max),
    }
)

print("Pacific dims :", chl_pac.dims)
print("Pacific shape:", chl_pac.shape)
print("Pacific lat range:", float(chl_pac["lat"].min()), float(chl_pac["lat"].max()))
print("Pacific lon range:", float(chl_pac["lon"].min()), float(chl_pac["lon"].max()))


# %%
# STEP 3：1月と8月の地図を描く
# ============================================================
# Pre1 STEP 3
# 1月と8月の地図を描く
# ============================================================

# Draw one map per selected month (January and August) with a shared colour range.
for m in (1, 8):
    monthly_map = chl_pac.sel(month=m)

    fig, ax = plt.subplots(figsize=(11, 5))
    mesh = ax.pcolormesh(
        chl_pac["lon"],
        chl_pac["lat"],
        monthly_map,
        shading="auto",
        vmin=-1.6,
        vmax=1.0,
    )
    fig.colorbar(mesh, ax=ax, label="log10(CHL) [log10(mg m$^{-3}$)]")

    ax.set_xlabel("Longitude (0–360)")
    ax.set_ylabel("Latitude")
    ax.set_title(f"Pacific log10 chlorophyll climatology: month = {m}")
    ax.set_xlim(lon_min, lon_max)
    ax.set_ylim(lat_min, lat_max)
    fig.tight_layout()

    # The figure just created is the current one, so the helper saves it.
    save_current_figure(f"pre1_map_log10chl_month{m:02d}.png")
    plt.show()


# %%
# STEP 4：代表地点の12か月時系列を比較する
# ============================================================
# Pre1 STEP 4
# 代表地点の12か月時系列を比較する
# ============================================================

# Representative points as (lat, lon); longitudes use the 0–360 convention.
points = {
    "Kuroshio region": (30.0, 145.0),
    "Equatorial Pacific": (0.0, 210.0),
    "North Pacific subtropical gyre": (25.0, 200.0),
    "South Pacific mid-high latitude": (-45.0, 170.0),
    "Peru upwelling region": (-15.0, 275.0),
}

months = chl_pac["month"].values


def _nearest_valid_series(da, plat, plon):
    """Return the monthly series at the grid cell with data closest to a point.

    A plain ``sel(..., method="nearest")`` returns the geometrically nearest
    cell even when it contains only NaN, which would yield an empty curve
    while the figure title still claims a "valid grid".  Here cells without
    any finite month are excluded before the nearest one is searched, so the
    result equals the plain nearest-neighbour pick whenever that cell has data.

    Parameters
    ----------
    da : xarray.DataArray
        Field with ``month``, ``lat`` and ``lon`` dimensions.
    plat, plon : float
        Requested latitude and longitude (0–360 convention).

    Returns
    -------
    xarray.DataArray
        Series along ``month`` with scalar ``lat``/``lon`` coordinates of the
        grid cell that was actually selected.

    Raises
    ------
    ValueError
        If no grid cell of ``da`` holds a finite value.
    """
    has_data = da.notnull().any(dim="month")
    if not bool(has_data.any()):
        raise ValueError("no grid cell with finite values is available")

    # Scale the zonal offset by cos(latitude) so both offsets are comparable
    # distances; only the ranking of the cells matters, not the unit.
    dlat = da["lat"] - plat
    dlon = (da["lon"] - plon) * np.cos(np.deg2rad(plat))
    dist2 = (dlat ** 2 + dlon ** 2).where(has_data)

    # argmin over a list of dims returns a dict of indexers usable by isel;
    # NaN distances (cells without data) are skipped.
    nearest = dist2.argmin(dim=["lat", "lon"])
    return da.isel(nearest)


# Select every series once and reuse it for the combined and individual plots.
point_series = {
    name: _nearest_valid_series(chl_pac, plat, plon)
    for name, (plat, plon) in points.items()
}

# ---- all points on one figure ----
plt.figure(figsize=(11, 5))

for name, ts in point_series.items():
    actual_lat = float(ts["lat"].values)
    actual_lon = float(ts["lon"].values)

    plt.plot(
        months,
        ts.values,
        marker="o",
        label=f"{name} ({actual_lat:.1f}, {actual_lon:.1f})"
    )

plt.xlabel("Month")
plt.ylabel("log10(CHL)")
plt.title("Comparison of seasonal chlorophyll cycles at selected points")
plt.xticks(months)
plt.grid(True, alpha=0.3)
plt.legend(fontsize=9)
plt.tight_layout()

save_current_figure("pre1_compare_selected_points.png")
plt.show()

# ---- one figure per point ----
for name, ts in point_series.items():
    actual_lat = float(ts["lat"].values)
    actual_lon = float(ts["lon"].values)

    # Build a file-system friendly suffix from the display name.
    safe_name = name.replace(" ", "_").replace("-", "_")

    plt.figure(figsize=(8, 5))
    plt.plot(months, ts.values, marker="o")
    plt.xlabel("Month")
    plt.ylabel("log10(CHL)")
    plt.title(f"{name}\nvalid grid: lat={actual_lat:.1f}, lon={actual_lon:.1f}")
    plt.xticks(months)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    save_current_figure(f"pre1_point_timeseries_{safe_name}.png")
    plt.show()


# %%
# STEP 5：代表海域の平均季節変動を見る
# ============================================================
# Pre1 STEP 5
# 代表海域の平均季節変動を見る
# ============================================================

# Representative regions: (south, north) latitude and (west, east) longitude
# bounds, longitudes in the 0–360 convention.
regions = {
    "Kuroshio": {
        "lat": (25, 35),
        "lon": (135, 155),
    },
    "Equatorial_Pacific": {
        "lat": (-5, 5),
        "lon": (190, 240),
    },
    "North_Pacific_Subtropical_Gyre": {
        "lat": (15, 30),
        "lon": (180, 230),
    },
    "South_Pacific_MidLat": {
        "lat": (-50, -35),
        "lon": (160, 210),
    },
    "Peru_Upwelling": {
        "lat": (-25, -5),
        "lon": (260, 280),
    },
}

# Defined here as well so this cell does not depend on STEP 4 having been run.
months = chl_pac["month"].values

plt.figure(figsize=(10, 5))

for name, reg in regions.items():
    lat1, lat2 = reg["lat"]
    lon1, lon2 = reg["lon"]

    sub = chl_pac.sel(lat=slice(lat1, lat2), lon=slice(lon1, lon2))

    # Cells of a lat/lon grid shrink towards the poles, so an area mean has to
    # weight each row by cos(latitude); an unweighted mean over-represents the
    # poleward rows.  NOTE(review): assumes a regular lat/lon grid — confirm.
    area_weights = np.cos(np.deg2rad(sub["lat"]))

    # Average over latitude and longitude, keeping only the month dimension.
    # Missing cells are ignored (their weight is dropped as well).
    cycle = sub.weighted(area_weights).mean(dim=("lat", "lon"), skipna=True)

    plt.plot(months, cycle.values, marker="o", label=name)

plt.xlabel("Month")
plt.ylabel("Area-mean log10(CHL)")
plt.title("Area-mean seasonal chlorophyll cycles")
plt.xticks(months)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()

save_current_figure("pre1_area_mean_seasonal_cycles.png")
plt.show()


# %%
# STEP 6：年平均・季節振幅・最大月を計算して地図にする
# ============================================================
# Pre1 STEP 6
# 年平均・季節振幅・最大月を計算して地図にする
# ============================================================

# Per-cell statistics over the 12 climatological months (NaN months ignored).
annual_mean = chl_pac.mean(dim="month", skipna=True)
annual_max  = chl_pac.max(dim="month", skipna=True)
annual_min  = chl_pac.min(dim="month", skipna=True)

# Seasonal amplitude = maximum - minimum.
amplitude = annual_max - annual_min

# Standard deviation of the 12 monthly values.
annual_std = chl_pac.std(dim="month", skipna=True)

# Month label at which the maximum occurs.
max_month = chl_pac.idxmax(dim="month", skipna=True)


def _plot_feature_map(field, *, vmin, vmax, cbar_label, title, filename,
                      cmap=None, cbar_ticks=None):
    """Draw one lat/lon feature map over the Pacific domain, save it and show it.

    Parameters
    ----------
    field : xarray.DataArray
        2-D field carrying ``lat`` and ``lon`` coordinates.
    vmin, vmax : float
        Colour-scale limits.
    cbar_label : str
        Colorbar label.
    title : str
        Figure title.
    filename : str
        PNG file name handed to ``save_current_figure``.
    cmap : str or None, optional
        Colormap name; ``None`` keeps matplotlib's default.
    cbar_ticks : sequence of float or None, optional
        Explicit colorbar tick positions; ``None`` keeps the automatic ticks.
    """
    plt.figure(figsize=(11, 5))
    im = plt.pcolormesh(
        field["lon"], field["lat"], field,
        shading="auto",
        vmin=vmin, vmax=vmax,
        cmap=cmap,
    )
    cbar = plt.colorbar(im, label=cbar_label)
    if cbar_ticks is not None:
        cbar.set_ticks(cbar_ticks)
    plt.xlabel("Longitude (0–360)")
    plt.ylabel("Latitude")
    plt.title(title)
    plt.xlim(lon_min, lon_max)
    plt.ylim(lat_min, lat_max)
    plt.tight_layout()
    save_current_figure(filename)
    plt.show()


# ---- annual mean ----
_plot_feature_map(
    annual_mean,
    vmin=-1.6, vmax=0.4,
    cbar_label="mean log10(CHL)",
    title="Annual mean log10(CHL)",
    filename="pre1_feature_annual_mean.png",
)

# ---- seasonal amplitude ----
_plot_feature_map(
    amplitude,
    vmin=0.0, vmax=1.6,
    cbar_label="max - min in log10(CHL)",
    title="Seasonal amplitude of log10(CHL)",
    filename="pre1_feature_amplitude.png",
)

# ---- standard deviation ----
_plot_feature_map(
    annual_std,
    vmin=0.0, vmax=0.7,
    cbar_label="std of monthly log10(CHL)",
    title="Seasonal standard deviation of log10(CHL)",
    filename="pre1_feature_annual_std.png",
)

# ---- month of maximum ----
# A cyclic colormap has the same colour at both ends.  With limits 1..12,
# January and December would land exactly on those ends and become
# indistinguishable.  Limits of 0.5..12.5 give every month its own position
# while keeping December and January neighbours on the colour wheel.
_plot_feature_map(
    max_month,
    vmin=0.5, vmax=12.5,
    cmap="twilight_shifted",
    cbar_ticks=list(range(1, 13)),
    cbar_label="month of maximum CHL",
    title="Month of maximum log10(CHL)",
    filename="pre1_feature_max_month.png",
)


# %%
# STEP 7：特徴量空間と簡単なルール分類を見る
# ============================================================
# Pre1 STEP 7
# 特徴量空間と簡単なルール分類を見る
# ============================================================

# Flatten both feature maps and keep only cells where both are finite.
mean_vals = annual_mean.values.ravel()
amp_vals = amplitude.values.ravel()
valid_feature = np.isfinite(mean_vals) & np.isfinite(amp_vals)

plt.figure(figsize=(8, 6))
plt.scatter(
    mean_vals[valid_feature],
    amp_vals[valid_feature],
    s=8,
    alpha=0.25
)
plt.xlabel("Annual mean log10(CHL)")
plt.ylabel("Seasonal amplitude")
plt.title("Feature space: mean vs seasonal amplitude")
plt.grid(True, alpha=0.3)
plt.tight_layout()
save_current_figure("pre1_feature_scatter_mean_vs_amplitude.png")
plt.show()

# ------------------------------------------------------------
# This is NOT K-means.
# It is an example of what happens when a person places
# thresholds on the mean and the amplitude by hand.
#
# The rules are applied in order and later rules overwrite earlier
# ones, so a cell with both a high mean and a strong seasonality
# ends up in class 3.  Cells matching no rule (annual mean in
# [-0.8, -0.4) with amplitude < 0.45) and cells without data stay
# NaN and are left blank on the map.
# ------------------------------------------------------------

rule_class = xr.full_like(annual_mean, np.nan)

# 1: low CHL and weak seasonality
rule_class = rule_class.where(~((annual_mean < -0.8) & (amplitude < 0.45)), 1)

# 2: high CHL
rule_class = rule_class.where(~(annual_mean >= -0.4), 2)

# 3: strong seasonality
rule_class = rule_class.where(~(amplitude >= 0.45), 3)

# Use exactly three discrete colours.  Stretching the full 10-colour "tab10"
# map over 0.5..3.5 makes the colorbar show colours that belong to no class.
class_cmap = plt.get_cmap("tab10", 3)

plt.figure(figsize=(11, 5))
im = plt.pcolormesh(
    chl_pac["lon"], chl_pac["lat"], rule_class,
    shading="auto",
    cmap=class_cmap,
    vmin=0.5, vmax=3.5   # each class k occupies the band [k - 0.5, k + 0.5)
)
cbar = plt.colorbar(im, label="Rule-based class")
cbar.set_ticks([1, 2, 3])
cbar.set_ticklabels([
    "1: low mean / weak seasonality",
    "2: high mean",
    "3: strong seasonality",
])
plt.xlabel("Longitude (0–360)")
plt.ylabel("Latitude")
plt.title("Simple rule-based classification before K-means")
plt.xlim(lon_min, lon_max)
plt.ylim(lat_min, lat_max)
plt.tight_layout()
save_current_figure("pre1_rule_based_classification.png")
plt.show()

# Close the NetCDF file.
ds.close()
