# %%
# STEP 1：ライブラリ・設定・保存先を用意する
# ============================================================
# Pre2 STEP 1
# ライブラリ・設定・保存先を用意する
# ============================================================

from pathlib import Path

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

base_dir = Path(".")
ncfile = base_dir / "GlobColour_log10CHL1_clim12_100km_AV_199709_202603.nc"

figdir = base_dir / "figures_14p5_chl_pre2"
figdir.mkdir(parents=True, exist_ok=True)

lat_min, lat_max = -60, 60
lon_min, lon_max = 120, 280

# ランダム抽出の再現性を固定する
rng = np.random.default_rng(0)

print("Input file:", ncfile)
print("Exists:", ncfile.exists())
print("Figure directory:", figdir)

if not ncfile.exists():
    raise FileNotFoundError(f"ファイルが見つかりません: {ncfile}")


def save_current_figure(filename):
    out_png = figdir / filename
    plt.savefig(out_png, dpi=300, bbox_inches="tight", facecolor="white")
    print("Saved:", out_png)


# %%
# STEP 2：データを読み、太平洋域を切り出す
# ============================================================
# Pre2 STEP 2
# データを読み、太平洋域を切り出す
# ============================================================

ds = xr.open_dataset(ncfile)
chl = ds["log10_chl_clim"]

# latを昇順にする
chl = chl.sortby("lat")

# lonを0–360表記にする
chl = chl.assign_coords(lon=(chl["lon"] + 360) % 360)
chl = chl.sortby("lon")

# 太平洋域を切り出す
chl_pac = chl.sel(
    lat=slice(lat_min, lat_max),
    lon=slice(lon_min, lon_max)
)

print("chl_pac dims :", chl_pac.dims)
print("chl_pac shape:", chl_pac.shape)


# %%
# STEP 3：chl(month, lat, lon) を X(grid, month) に変形する
# ============================================================
# Pre2 STEP 3
# chl(month, lat, lon) を X(grid, month) に変形する
# ============================================================

# K-meansでは、1格子点を1サンプルとして扱いたい。
# そのため、配列を (lat, lon, month) の順番に並べ替える。
chl_for_cluster = chl_pac.transpose("lat", "lon", "month")

lat = chl_for_cluster["lat"].values
lon = chl_for_cluster["lon"].values
months = chl_for_cluster["month"].values

nlat = len(lat)
nlon = len(lon)
nmon = len(months)

print("nlat:", nlat)
print("nlon:", nlon)
print("nmon:", nmon)

A = chl_for_cluster.values              # shape = (lat, lon, month)
X = A.reshape(nlat * nlon, nmon)        # shape = (grid, month)

print("A shape:", A.shape)
print("X shape:", X.shape)
print("Xの1行 = 1つの格子点の12か月時系列")
print("Xの1列 = ある月の全格子点の値")

# 12か月すべてに値がある格子点だけ使う
valid_grid = np.isfinite(X).all(axis=1)
X_valid = X[valid_grid, :]

# validな格子点の緯度経度も取り出す
lon2d, lat2d = np.meshgrid(lon, lat)
lat_flat = lat2d.ravel()
lon_flat = lon2d.ravel()
lat_valid = lat_flat[valid_grid]
lon_valid = lon_flat[valid_grid]

print("Total grid points:", X.shape[0])
print("Valid grid points:", X_valid.shape[0])


# %%
# STEP 4：Xの一部をヒートマップで見る
# ============================================================
# Pre2 STEP 4
# Xの一部をヒートマップで見る
# ============================================================

# 表示するサンプル数
n_show = 50

# 季節変動が少し見えやすい格子点からランダムに選ぶ
amp_each_grid = np.nanmax(X_valid, axis=1) - np.nanmin(X_valid, axis=1)
candidate_indices = np.where(amp_each_grid >= 0.20)[0]

# 候補が少なすぎる場合は、すべての有効格子点から選ぶ
if len(candidate_indices) < n_show:
    candidate_indices = np.arange(X_valid.shape[0])

sample_indices = rng.choice(candidate_indices, size=n_show, replace=False)
X_show = X_valid[sample_indices, :]

plt.figure(figsize=(8, 7))
im = plt.imshow(
    X_show,
    aspect="auto",
    origin="upper",
    interpolation="nearest"
)
plt.colorbar(im, label="log10(CHL)")
plt.xlabel("Month")
plt.ylabel("Sampled grid index")
plt.title("50 sampled rows from X(grid, month)")
plt.xticks(ticks=np.arange(12), labels=months)
plt.tight_layout()
save_current_figure("pre2_X_matrix_heatmap.png")
plt.show()


# %%
# STEP 5：Xの行を折れ線として描く
# ============================================================
# Pre2 STEP 5
# Xの行を折れ線として描く
# ============================================================

# さきほど選んだ50個のうち、最初の5個だけ折れ線で描く
n_line = 5
line_indices = sample_indices[:n_line]

plt.figure(figsize=(9, 5))
for i, idx in enumerate(line_indices):
    plt.plot(
        months,
        X_valid[idx, :],
        marker="o",
        label=f"grid sample {i}"
    )

plt.xlabel("Month")
plt.ylabel("log10(CHL)")
plt.title("Examples of rows in X(grid, month)")
plt.xticks(months)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
save_current_figure("pre2_examples_rows_in_X_amplitude_filtered_random.png")
plt.show()


# %%
# STEP 6：各格子点ごとに標準化して X_shape を作る
# ============================================================
# Pre2 STEP 6
# 各格子点ごとに標準化して X_shape を作る
# ============================================================

# 各行ごとに、12か月平均と12か月標準偏差を計算する
X_mean = np.nanmean(X_valid, axis=1, keepdims=True)
X_std  = np.nanstd(X_valid, axis=1, keepdims=True)

# 標準偏差が0の格子点は除く
valid_std = np.squeeze(X_std > 0)

X_valid2 = X_valid[valid_std, :]
X_mean2 = X_mean[valid_std, :]
X_std2 = X_std[valid_std, :]

lat_valid2 = lat_valid[valid_std]
lon_valid2 = lon_valid[valid_std]

# 標準化
X_shape = (X_valid2 - X_mean2) / X_std2

print("X_valid shape:", X_valid.shape)
print("X_shape shape:", X_shape.shape)
print("X_shape全体の平均:", np.nanmean(X_shape))
print("X_shape全体の標準偏差:", np.nanstd(X_shape))

# STEP 4と同じサンプル番号を、valid_std後の番号へ対応させる
# 簡単のため、ここではX_shapeから改めて50個選ぶ
amp_shape_base = np.nanmax(X_valid2, axis=1) - np.nanmin(X_valid2, axis=1)
candidate_shape = np.where(amp_shape_base >= 0.20)[0]
if len(candidate_shape) < n_show:
    candidate_shape = np.arange(X_shape.shape[0])

sample_indices_shape = rng.choice(candidate_shape, size=n_show, replace=False)
X_shape_show = X_shape[sample_indices_shape, :]

plt.figure(figsize=(8, 7))
im = plt.imshow(
    X_shape_show,
    aspect="auto",
    origin="upper",
    interpolation="nearest",
    vmin=-2,
    vmax=2
)
plt.colorbar(im, label="Standardized seasonal-cycle value")
plt.xlabel("Month")
plt.ylabel("Sampled grid index")
plt.title("50 sampled rows from X_shape(grid, month)")
plt.xticks(ticks=np.arange(12), labels=months)
plt.tight_layout()
save_current_figure("pre2_X_shape_matrix_heatmap.png")
plt.show()


# %%
# STEP 7：raw値と標準化後の形を比較する
# ============================================================
# Pre2 STEP 7
# raw値と標準化後の形を比較する
# ============================================================

# 同じ格子点を5個選んで、rawとstandardizedを比較する
n_line = 5
line_indices_shape = sample_indices_shape[:n_line]

plt.figure(figsize=(9, 5))
for i, idx in enumerate(line_indices_shape):
    plt.plot(
        months,
        X_valid2[idx, :],
        marker="o",
        label=f"raw sample {i}"
    )

plt.xlabel("Month")
plt.ylabel("log10(CHL)")
plt.title("Raw seasonal cycles")
plt.xticks(months)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
save_current_figure("pre2_raw_examples_amplitude_filtered_random.png")
plt.show()

plt.figure(figsize=(9, 5))
for i, idx in enumerate(line_indices_shape):
    plt.plot(
        months,
        X_shape[idx, :],
        marker="o",
        label=f"standardized sample {i}"
    )

plt.axhline(0, color="k", linewidth=0.8)
plt.xlabel("Month")
plt.ylabel("Standardized value")
plt.title("Standardized seasonal-cycle shapes")
plt.xticks(months)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
save_current_figure("pre2_standardized_examples_amplitude_filtered_random.png")
plt.show()


# %%
# STEP 8：1つの格子点で raw と standardized を上下に比較する
# ============================================================
# Pre2 STEP 8
# 1つの格子点で raw と standardized を上下に比較する
# ============================================================

# 例として1つの格子点を選ぶ
one_idx = line_indices_shape[0]
raw_one = X_valid2[one_idx, :]
shape_one = X_shape[one_idx, :]

one_lat = lat_valid2[one_idx]
one_lon = lon_valid2[one_idx]

fig, axes = plt.subplots(2, 1, figsize=(9, 7), sharex=True)

axes[0].plot(months, raw_one, marker="o")
axes[0].set_ylabel("raw log10(CHL)")
axes[0].set_title(f"One grid point: lat={one_lat:.1f}, lon={one_lon:.1f}")
axes[0].grid(True, alpha=0.3)

axes[1].plot(months, shape_one, marker="s", linestyle="--")
axes[1].axhline(0, color="k", linewidth=0.8)
axes[1].set_xlabel("Month")
axes[1].set_ylabel("standardized value")
axes[1].grid(True, alpha=0.3)

plt.xticks(months)
plt.suptitle("Raw vs standardized seasonal cycle for one grid point", y=0.98)
plt.tight_layout()
save_current_figure("pre2_raw_vs_standardized_one_sample_two_panels.png")
plt.show()


# %%
# STEP 9：次のK-means用に X と X_shape を保存する
# ============================================================
# Pre2 STEP 9
# 次のK-means用に X と X_shape を保存する
# ============================================================

out_npz = base_dir / "pre14_X_for_kmeans.npz"

np.savez(
    out_npz,
    X_valid2=X_valid2,
    X_shape=X_shape,
    lat_valid2=lat_valid2,
    lon_valid2=lon_valid2,
    months=months,
)

print("Saved:", out_npz)
print("次のK-means本編では、このX_shapeを使ってクラスタリングする。")

ds.close()
