# -*- coding: utf-8 -*-
r"""Plot training curves from the log history stored in a trainer_state.json file.

Usage:
    python plot_loss_from_trainer_state.py --input trainer_state.json --outdir ./plots \
        --checkpoint_steps 263,526,789,1052

Features:
    - Curve: yellow-orange solid line.
    - Grid: dashed grid lines on both the x and y axes.
    - Epoch markers: blue dashed vertical lines labelled ``EpochN``
      (including the final epoch).
    - Checkpoints: small blue dots (linearly interpolated; outside the logged
      range the endpoint value is used, and the x axis is widened so the dots
      stay visible).
"""
import argparse
import json
import math
import re
import sys
from pathlib import Path

import numpy as np

YELLOW_ORANGE = "#d58f00"
BLUE = "#1f77b4"

# Separators accepted by --checkpoint_steps: ASCII comma, full-width comma
# (U+FF0C, as typed by CJK input methods) and any run of whitespace.
_CHECKPOINT_SEP = re.compile(r"[\s,\uff0c]+")


def find_epoch_boundaries(log_items):
    """Return ``(step, epoch_number)`` markers for every epoch boundary.

    A marker labelled N is placed at the first logged step whose ``epoch``
    value has integer part N (i.e. where the integer part changes to N).
    The final, possibly partial, epoch is also marked at the last logged step
    and labelled ``ceil(last_epoch)``, unless that epoch number is already
    marked. This avoids a duplicate "N+1" label when training ends exactly on
    an epoch boundary.

    Args:
        log_items: iterable of dicts; entries lacking ``step`` or ``epoch``
            are ignored.

    Returns:
        List of ``(step, epoch_number)`` tuples sorted by step; empty when no
        entry carries both keys.
    """
    boundaries = []
    marked_epochs = set()  # epoch numbers that already have a marker
    prev_epoch_int = None
    last_step = last_epoch = None

    for item in log_items:
        step = item.get("step")
        epoch = item.get("epoch")
        if step is None or epoch is None:
            continue
        epoch = float(epoch)
        last_step, last_epoch = step, epoch
        epoch_int = int(epoch)

        if prev_epoch_int is None:
            # The first entry only establishes the starting epoch.
            prev_epoch_int = epoch_int
            continue
        if epoch_int != prev_epoch_int:
            if epoch_int >= 1 and epoch_int not in marked_epochs:
                boundaries.append((step, epoch_int))
                marked_epochs.add(epoch_int)
            prev_epoch_int = epoch_int

    # Mark the final epoch too. ceil() labels a partial epoch (e.g. 2.4 -> 3)
    # and leaves an exact boundary (e.g. 2.0 -> 2) alone, since that one was
    # already recorded by the loop above.
    if last_step is not None:
        final_epoch = math.ceil(last_epoch)
        if final_epoch >= 1 and final_epoch not in marked_epochs:
            boundaries.append((last_step, final_epoch))

    boundaries.sort(key=lambda marker: marker[0])
    return boundaries


def plot_series(x, y, xlabel, ylabel, title, outpath, epoch_marks=None,
                checkpoint_steps=None, color=YELLOW_ORANGE, linestyle='-'):
    """Draw one curve and save it as an image.

    Args:
        x, y: non-empty sequences of step values and metric values.
            ``np.interp`` assumes ``x`` is increasing.
        xlabel, ylabel, title: axis labels and figure title.
        outpath: destination file passed to ``Figure.savefig``.
        epoch_marks: optional ``(step, epoch_number)`` pairs drawn as blue
            dashed vertical lines labelled ``Epoch<N>``.
        checkpoint_steps: optional steps drawn as blue dots on the curve,
            linearly interpolated (clamped to the end values outside ``x``).
        color, linestyle: style of the main curve.
    """
    # Imported lazily: pyplot is only needed for drawing, so the log-parsing
    # helpers in this module stay importable without matplotlib.
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 6))
    try:
        ax = fig.add_subplot(111)
        ax.plot(x, y, color=color, linestyle=linestyle, linewidth=2)

        # Checkpoint dots: linear interpolation, endpoint values out of range.
        if checkpoint_steps:
            checkpoint_y = np.interp(checkpoint_steps, x, y, left=y[0], right=y[-1])
            ax.plot(checkpoint_steps, checkpoint_y, linestyle='none',
                    marker='o', color=BLUE, markersize=6)

        # The x range must also cover checkpoints and epoch markers that lie
        # beyond the last logged point, plus right-hand padding so a marker
        # sitting exactly on the edge is not clipped.
        x_candidates = [max(x)]
        if checkpoint_steps:
            x_candidates.append(max(checkpoint_steps))
        if epoch_marks:
            x_candidates.append(max(step for step, _ in epoch_marks))
        xmin = 0
        xmax_base = max(x_candidates)
        span = max(xmax_base - xmin, 1.0)
        right_pad = max(1.0, 0.02 * span)  # at least 1 step or 2% of the width
        ax.set_xlim(left=xmin, right=xmax_base + right_pad)
        ax.set_ylim(bottom=0)  # the y axis always starts at 0

        ax.grid(True, which='major', axis='both', linestyle='--',
                linewidth=0.8, alpha=0.6)

        if epoch_marks:
            ymax = ax.get_ylim()[1]
            for step, epoch in epoch_marks:
                ax.axvline(x=step, color=BLUE, linestyle='--', linewidth=1.2)
                ax.text(step, ymax * 0.98, f'Epoch{epoch}', rotation=90,
                        va='top', ha='right', fontsize=8, color=BLUE)

        # Labels and spine styling are applied last, after the limits are set.
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.spines['left'].set_linewidth(2)
        ax.spines['bottom'].set_linewidth(2)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        fig.savefig(outpath, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def extract_series(log_items, key):
    """Return ``(steps, values)`` for every log entry that has ``step`` and a non-None ``key``."""
    steps, values = [], []
    for item in log_items:
        step = item.get("step")
        value = item.get(key)
        if step is None or value is None:
            continue
        steps.append(step)
        values.append(value)
    return steps, values


def parse_checkpoint_steps(text):
    """Parse a list of checkpoint steps into ints.

    Tokens may be separated by ASCII commas, full-width commas or whitespace.
    Float-looking tokens such as ``"263.0"`` are truncated to int. Invalid
    tokens are skipped with a warning on stderr instead of being silently
    dropped.
    """
    steps = []
    for token in _CHECKPOINT_SEP.split(text or ""):
        if not token:
            continue
        try:
            steps.append(int(float(token)))
        except (ValueError, OverflowError):  # e.g. "abc", "nan", "inf"
            print(f"Warning: ignoring invalid checkpoint step {token!r}",
                  file=sys.stderr)
    return steps


# (log key, y-axis label, title, output file name) for each plot produced.
_PLOT_SPECS = (
    ("loss", "Training Loss", "Training Loss vs Step", "loss_curve.png"),
    ("eval_loss", "Eval Loss", "Eval Loss vs Step", "eval_loss_curve.png"),
    ("learning_rate", "Learning Rate", "Learning Rate vs Step", "lr_curve.png"),
)


def main():
    """CLI entry point: read the log history and write one PNG per available metric."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to trainer_state.json")
    ap.add_argument("--outdir", default="./plots", help="Directory to save PNGs")
    ap.add_argument("--no_epoch_marks", action="store_true",
                    help="Disable vertical epoch markers")
    ap.add_argument("--checkpoint_steps", default="",
                    help="Comma-separated steps (e.g., 100,200,500)")
    args = ap.parse_args()

    src = Path(args.input)
    with open(src, "r", encoding="utf-8") as f:
        state = json.load(f)
    # The log list is read from "log_history", falling back to "logs".
    log = state.get("log_history", state.get("logs", []))

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    epoch_marks = None if args.no_epoch_marks else find_epoch_boundaries(log)
    checkpoint_steps = parse_checkpoint_steps(args.checkpoint_steps)

    plotted = 0
    for key, ylabel, title, filename in _PLOT_SPECS:
        steps, values = extract_series(log, key)
        if not steps:
            continue
        plot_series(steps, values, "Step", ylabel, title, outdir / filename,
                    epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)
        plotted += 1

    if not plotted:
        print("Warning: no loss/eval_loss/learning_rate entries found in the log",
              file=sys.stderr)
    print(f"Saved plots to: {outdir.resolve()}")


if __name__ == "__main__":
    main()