File size: 6,039 Bytes
010a017 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# -*- coding: utf-8 -*-
"""
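Usage:
    python plot_loss_from_trainer_state.py --input trainer_state.json --outdir ./plots \
        --checkpoint_steps 263,526,789,1052
Features:
- Curve: solid yellow-orange line
- Grid: dashed lines on both the x and y axes
- Epoch markers: dashed blue vertical lines with EpochN labels (including the final epoch)
- Checkpoints: small blue dots (linearly interpolated; endpoint values are used outside the data range, and the x-axis is extended automatically so the dots stay visible)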
"""
import json, argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# Default colour of the plotted curve (hex RGB).
YELLOW_ORANGE = "#d58f00"
# Colour used for epoch marker lines, their labels, and checkpoint dots (hex RGB).
BLUE = "#1f77b4"
def find_epoch_boundaries(log_items):
    """Locate the step at which each epoch ends, including the last one.

    Args:
        log_items: Iterable of log-entry dicts. Entries that lack a
            ``"step"`` or ``"epoch"`` value are ignored.

    Returns:
        A list of ``(step, epoch_number)`` tuples sorted by step. A tuple is
        emitted at the first entry whose integer epoch differs from the
        previous entry's (labelled with the new integer epoch, only when it
        is >= 1), plus one tuple for the final entry so the last epoch is
        marked as well.
    """
    boundaries = []
    seen = set()
    prev_epoch_int = None
    last_step = last_epoch = None

    for item in log_items:
        step = item.get("step")
        epoch = item.get("epoch")
        if step is None or epoch is None:
            continue
        # Go through float() so the loop and the final-marker logic below
        # interpret the epoch value in exactly the same way.
        epoch = float(epoch)
        last_step, last_epoch = step, epoch
        epoch_int = int(epoch)
        if prev_epoch_int is None:
            # The first usable entry only establishes the starting epoch.
            prev_epoch_int = epoch_int
            continue
        if epoch_int == prev_epoch_int:
            continue
        prev_epoch_int = epoch_int
        mark = (step, epoch_int)
        if epoch_int >= 1 and mark not in seen:
            boundaries.append(mark)
            seen.add(mark)

    # Also mark the final epoch. Round up only when the last logged epoch is
    # fractional: if the log ends exactly on an integer epoch (e.g. 2.0), that
    # step is the end of epoch 2 and is normally already recorded above, so
    # unconditionally adding 1 would create a spurious extra marker.
    if last_step is not None:
        whole = int(last_epoch)
        final_epoch = whole if last_epoch == whole else whole + 1
        final_mark = (last_step, final_epoch)
        if final_epoch >= 1 and final_mark not in seen:
            boundaries.append(final_mark)

    boundaries.sort(key=lambda mark: mark[0])
    return boundaries
def plot_series(x, y, xlabel, ylabel, title, outpath,
                epoch_marks=None, checkpoint_steps=None,
                color=YELLOW_ORANGE, linestyle='-'):
    """Draw one series as a line chart and save it to ``outpath``.

    Args:
        x: Sequence of x values (must be non-empty).
        y: Sequence of y values, parallel to ``x``.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        title: Title of the figure.
        outpath: Destination passed to ``Figure.savefig``.
        epoch_marks: Optional iterable of ``(step, epoch_number)`` pairs;
            each gets a dashed vertical line and an ``EpochN`` label.
        checkpoint_steps: Optional iterable of x positions to highlight
            with a dot on the curve.
        color: Colour of the curve.
        linestyle: Line style of the curve.
    """
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    ax.plot(x, y, color=color, linestyle=linestyle, linewidth=2)

    # Checkpoint dots: y is linearly interpolated from the series; outside
    # the data range the first/last y value is used.
    marked_steps = []
    if checkpoint_steps:
        for ckpt in checkpoint_steps:
            ckpt_y = np.interp(ckpt, x, y, left=y[0], right=y[-1])
            ax.plot(ckpt, ckpt_y, marker='o', color=BLUE, markersize=6)
            marked_steps.append(ckpt)

    # The x range has to cover the data, every checkpoint dot and every
    # epoch marker.
    left_edge = 0
    right_candidates = [max(x)]
    if marked_steps:
        right_candidates.append(max(marked_steps))
    if epoch_marks:
        boundary_steps = [step for (step, _) in epoch_marks]
        if boundary_steps:
            right_candidates.append(max(boundary_steps))
    right_edge = max(right_candidates) if right_candidates else x[-1]

    # Pad the right side (at least 1 step, or 2% of the width) so that a
    # marker sitting exactly on the edge remains visible.
    width = max(right_edge - left_edge, 1.0)
    padding = max(1.0, 0.02 * width)
    ax.set_xlim(left=left_edge, right=right_edge + padding)

    # The y axis always starts at 0.
    ax.set_ylim(bottom=0)

    # Dashed grid on both axes.
    ax.grid(True, which='major', axis='both', linestyle='--', linewidth=0.8, alpha=0.6)

    # Epoch markers: dashed vertical line plus a rotated label near the top.
    if epoch_marks:
        for mark_step, epoch_no in epoch_marks:
            ax.axvline(x=mark_step, color=BLUE, linestyle='--', linewidth=1.2)
            top = ax.get_ylim()[1]
            ax.text(mark_step, top * 0.98, f'Epoch{epoch_no}', rotation=90,
                    va='top', ha='right', fontsize=8, color=BLUE)

    # Labels and axis styling go last, after the limits have been set.
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    for side in ('left', 'bottom'):
        ax.spines[side].set_linewidth(2)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)
def _parse_checkpoint_steps(text):
    """Turn a comma-separated list of steps into a list of ints.

    Whitespace around tokens and empty tokens are ignored, and the
    full-width comma (U+FF0C) is accepted as a separator. Each token is
    parsed with ``int(float(token))``, so ``"1e3"`` and ``"263.0"`` are
    valid. Tokens that cannot be parsed are skipped with a warning.
    """
    parsed = []
    # "\uff0c" is the full-width comma commonly produced by CJK input methods.
    for token in text.replace("\uff0c", ",").split(","):
        token = token.strip()
        if not token:
            continue
        try:
            parsed.append(int(float(token)))
        except (ValueError, OverflowError):
            # Best effort: one bad token should not abort the whole run, but
            # the user should learn that it was dropped.
            print(f"Warning: ignoring invalid checkpoint step {token!r}")
    return parsed


def main():
    """Command-line entry point: read a trainer-state JSON file and save plots.

    Reads the log list stored under ``"log_history"`` (falling back to
    ``"logs"``), collects the ``loss``, ``eval_loss`` and ``learning_rate``
    values by step, and writes ``loss_curve.png``, ``eval_loss_curve.png``
    and ``lr_curve.png`` into the output directory for every series that has
    data.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to trainer_state.json")
    ap.add_argument("--outdir", default="./plots", help="Directory to save PNGs")
    ap.add_argument("--no_epoch_marks", action="store_true", help="Disable vertical epoch markers")
    ap.add_argument("--checkpoint_steps", default="", help="Comma-separated steps (e.g., 100,200,500)")
    args = ap.parse_args()

    src = Path(args.input)
    with open(src, "r", encoding="utf-8") as f:
        state = json.load(f)
    log = state.get("log_history", state.get("logs", []))

    steps, train_losses = [], []
    eval_steps, eval_losses = [], []
    lr_steps, lrs = [], []
    for item in log:
        step = item.get("step")
        if step is None:
            continue
        # A single entry may carry several of these metrics at once.
        if "loss" in item:
            steps.append(step)
            train_losses.append(item["loss"])
        if "eval_loss" in item:
            eval_steps.append(step)
            eval_losses.append(item["eval_loss"])
        if "learning_rate" in item:
            lr_steps.append(step)
            lrs.append(item["learning_rate"])

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    epoch_marks = None if args.no_epoch_marks else find_epoch_boundaries(log)
    checkpoint_steps = _parse_checkpoint_steps(args.checkpoint_steps)

    if steps and train_losses:
        plot_series(steps, train_losses, "Step", "Training Loss", "Training Loss vs Step",
                    outdir / "loss_curve.png", epoch_marks=epoch_marks,
                    checkpoint_steps=checkpoint_steps)
    if eval_steps and eval_losses:
        plot_series(eval_steps, eval_losses, "Step", "Eval Loss", "Eval Loss vs Step",
                    outdir / "eval_loss_curve.png", epoch_marks=epoch_marks,
                    checkpoint_steps=checkpoint_steps)
    if lr_steps and lrs:
        plot_series(lr_steps, lrs, "Step", "Learning Rate", "Learning Rate vs Step",
                    outdir / "lr_curve.png", epoch_marks=epoch_marks,
                    checkpoint_steps=checkpoint_steps)
    print(f"Saved plots to: {outdir.resolve()}")


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|