# Source: multi_eng_zhtw_and_apigen / plot_loss_from_trainer_state.py
# (uploaded by AaronWu901225, commit 010a017: "Upload LoRA adapter folder")
# -*- coding: utf-8 -*-
"""
Usage:
python plot_loss_from_trainer_state.py --input trainer_state.json --outdir ./plots \
--checkpoint_steps 263,526,789,1052 [--no_epoch_marks]
功能:
- Curve: 黃橘色實線
- Grid: x,y 虛線
- Epoch markers: 藍色虛線 + EpochN 標籤(含最後一個 epoch)
- Checkpoints: 藍色小圓點(線性插值;超出範圍時使用端點值,並自動擴張 x 軸確保能看見)
"""
import json, argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# Colour palette shared by every plot: the metric curve defaults to
# yellow-orange; epoch lines, epoch labels and checkpoint dots are blue.
YELLOW_ORANGE, BLUE = "#d58f00", "#1f77b4"
def find_epoch_boundaries(log_items):
    """Locate the step at which each epoch ends, including the final epoch.

    Walks ``log_items`` in order and records ``(step, N)`` at the first
    logged step whose integer epoch ``N = int(epoch)`` differs from the
    previous entry's.  After the loop, the last logged step is also marked
    as the end of the final epoch, unless that epoch number is already
    recorded.  That is the case when the run finished exactly on an epoch
    boundary, e.g. a last entry with ``epoch == 4.0``.

    Args:
        log_items: iterable of dicts.  Entries lacking ``"step"`` or
            ``"epoch"`` are ignored.

    Returns:
        List of ``(step, epoch_number)`` tuples sorted by step.
    """
    # Tolerance used to decide whether the final epoch value has a real
    # fractional part (i.e. the run stopped partway through an epoch).
    frac_eps = 1e-6

    boundaries = []
    seen = set()
    prev_epoch_int = None
    last_step, last_epoch = None, None
    for it in log_items:
        step = it.get("step")
        ep = it.get("epoch")
        if step is None or ep is None:
            continue
        ep = float(ep)
        last_step, last_epoch = step, ep
        ep_int = int(ep)
        if prev_epoch_int is None:
            # First usable entry only establishes the starting epoch.
            prev_epoch_int = ep_int
            continue
        if ep_int != prev_epoch_int:
            if (step, ep_int) not in seen and ep_int >= 1:
                boundaries.append((step, ep_int))
                seen.add((step, ep_int))
            prev_epoch_int = ep_int

    # Mark the last logged step as the end of the final epoch.  A fractional
    # remainder means the run is partway into epoch int(last_epoch) + 1; an
    # integral value means exactly int(last_epoch) epochs were completed.
    if last_step is not None:
        completed = int(last_epoch)
        if last_epoch - completed > frac_eps:
            ep_final = completed + 1
        else:
            ep_final = completed
        ep_final = max(ep_final, 1)
        if ep_final not in {e for _, e in boundaries}:
            boundaries.append((last_step, ep_final))

    boundaries.sort(key=lambda b: b[0])
    return boundaries
def plot_series(x, y, xlabel, ylabel, title, outpath,
                epoch_marks=None, checkpoint_steps=None,
                color=YELLOW_ORANGE, linestyle='-'):
    """Draw one metric curve against training step and save it to ``outpath``.

    Args:
        x: step values of the series.  ``np.interp`` below expects them in
            ascending order.
        y: metric values, one per entry of ``x``.
        xlabel, ylabel, title: axis labels and figure title.
        outpath: destination passed to ``Figure.savefig``.
        epoch_marks: optional sequence of ``(step, epoch_number)`` pairs.
            Each is drawn as a dashed blue vertical line labelled
            ``Epoch{N}``.
        checkpoint_steps: optional sequence of steps marked with blue dots.
            The y value is linearly interpolated from the curve; steps
            outside the logged range take the first / last value.
        color, linestyle: style of the main curve.

    The x axis starts at 0 and is widened to include every checkpoint dot and
    epoch line, plus a small right margin.  The y axis starts at 0.  The
    figure is always closed, even if saving fails.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        ax = fig.add_subplot(111)
        ax.plot(x, y, color=color, linestyle=linestyle, linewidth=2)

        # Checkpoint dots: a single marker-only artist covering all steps.
        if checkpoint_steps:
            ck_y = np.interp(checkpoint_steps, x, y, left=y[0], right=y[-1])
            ax.plot(checkpoint_steps, ck_y, linestyle='none', marker='o',
                    color=BLUE, markersize=6)

        # x range: cover the data, the checkpoint dots and the epoch lines,
        # then pad on the right so a line sitting on the edge stays visible.
        right_candidates = [max(x)]
        if checkpoint_steps:
            right_candidates.append(max(checkpoint_steps))
        if epoch_marks:
            right_candidates.append(max(s for s, _ in epoch_marks))
        xmax_base = max(right_candidates)
        span = max(xmax_base, 1.0)          # left edge is fixed at 0
        right_pad = max(1.0, 0.02 * span)   # at least 1 step or 2% of width
        ax.set_xlim(left=0, right=xmax_base + right_pad)
        ax.set_ylim(bottom=0)

        # Dashed major grid on both axes.
        ax.grid(True, which='major', axis='both', linestyle='--',
                linewidth=0.8, alpha=0.6)

        if epoch_marks:
            # ylim was fixed by set_ylim above, so the label height is constant.
            label_y = ax.get_ylim()[1] * 0.98
            for step, ep in epoch_marks:
                ax.axvline(x=step, color=BLUE, linestyle='--', linewidth=1.2)
                ax.text(step, label_y, f'Epoch{ep}', rotation=90,
                        va='top', ha='right', fontsize=8, color=BLUE)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.spines['left'].set_linewidth(2)
        ax.spines['bottom'].set_linewidth(2)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        fig.savefig(outpath, bbox_inches="tight")
    finally:
        # Release the figure from pyplot's registry even if savefig raised.
        plt.close(fig)
def _parse_checkpoint_steps(text):
    """Split a ``--checkpoint_steps`` string into integer steps.

    Accepts ASCII commas, full-width commas (U+FF0C), ideographic commas
    (U+3001) and any whitespace as separators, e.g. ``"263, 526 789,1052"``.
    Each item goes through ``int(float(item))``, so ``"1e3"`` and ``"263.0"``
    are accepted and fractional parts are truncated.

    Returns:
        ``(steps, rejected)``: the parsed steps in input order, and the items
        that could not be converted.
    """
    normalized = (text.replace("\uff0c", ",")
                      .replace("\u3001", ",")
                      .replace(",", " "))
    steps, rejected = [], []
    for token in normalized.split():
        try:
            steps.append(int(float(token)))
        except (ValueError, OverflowError):  # e.g. "abc", "nan", "inf"
            rejected.append(token)
    return steps, rejected


def main():
    """Command-line entry point: plot curves from a trainer_state.json file.

    Reads the log list (``log_history``, falling back to ``logs``) and writes
    ``loss_curve.png``, ``eval_loss_curve.png`` and ``lr_curve.png`` into
    ``--outdir``.  Each plot is produced only if at least one log entry that
    has a ``step`` carries the matching key (``loss`` / ``eval_loss`` /
    ``learning_rate``).
    """
    import sys

    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to trainer_state.json")
    ap.add_argument("--outdir", default="./plots", help="Directory to save PNGs")
    ap.add_argument("--no_epoch_marks", action="store_true", help="Disable vertical epoch markers")
    ap.add_argument("--checkpoint_steps", default="", help="Comma-separated steps (e.g., 100,200,500)")
    args = ap.parse_args()

    src = Path(args.input)
    if not src.is_file():
        ap.error(f"input file not found: {src}")
    with open(src, "r", encoding="utf-8") as f:
        state = json.load(f)
    log = state.get("log_history", state.get("logs", []))

    # log key -> (steps, values); entries without a "step" are skipped.
    series = {key: ([], []) for key in ("loss", "eval_loss", "learning_rate")}
    for item in log:
        step = item.get("step")
        if step is None:
            continue
        for key, (xs, ys) in series.items():
            if key in item:
                xs.append(step)
                ys.append(item[key])

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    epoch_marks = None if args.no_epoch_marks else find_epoch_boundaries(log)

    checkpoint_steps, rejected = _parse_checkpoint_steps(args.checkpoint_steps)
    for token in rejected:
        print(f"Warning: ignoring invalid checkpoint step {token!r}", file=sys.stderr)

    # (log key, y-axis label, title, output file name)
    figures = (
        ("loss", "Training Loss", "Training Loss vs Step", "loss_curve.png"),
        ("eval_loss", "Eval Loss", "Eval Loss vs Step", "eval_loss_curve.png"),
        ("learning_rate", "Learning Rate", "Learning Rate vs Step", "lr_curve.png"),
    )
    plotted = 0
    for key, ylabel, title, filename in figures:
        xs, ys = series[key]
        if not xs:
            continue
        plot_series(xs, ys, "Step", ylabel, title, outdir / filename,
                    epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)
        plotted += 1

    if not plotted:
        print(f"Warning: no loss / eval_loss / learning_rate entries with a step "
              f"found in {src}; nothing plotted.", file=sys.stderr)
        return
    print(f"Saved plots to: {outdir.resolve()}")


if __name__ == "__main__":
    main()