|
|
|
|
|
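r"""Plot training curves from a trainer_state.json file.

Reads the ``log_history`` list (falling back to ``logs``) and writes up to three
PNGs into --outdir: loss_curve.png, eval_loss_curve.png and lr_curve.png (each
one only when the corresponding metric appears in the log).

Usage:

    python plot_loss_from_trainer_state.py --input trainer_state.json --outdir ./plots \
        --checkpoint_steps 263,526,789,1052

Features:

- Curve: solid yellow-orange line.
- Grid: dashed grid lines on both the x and y axes.
- Epoch markers: blue dashed vertical lines with an "EpochN" label (including
  the last epoch); disable them with --no_epoch_marks.
- Checkpoints: small blue dots placed by linear interpolation. Steps outside the
  logged range use the nearest endpoint value, and the x-axis is automatically
  extended so that they stay visible.
"""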
|
|
import argparse
import json
import math
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
|
|
|
|
|
# Hex colours shared by every plot.
#   YELLOW_ORANGE: the metric curve itself (default ``color`` of plot_series).
#   BLUE: epoch marker lines/labels and checkpoint dots (same hex as matplotlib's "tab:blue").
YELLOW_ORANGE, BLUE = "#d58f00", "#1f77b4"
|
|
|
|
|
def find_epoch_boundaries(log_items):
    """Locate the step at which each epoch ends, including the final one.

    Walks ``log_items`` in order, considering only entries that carry both a
    ``"step"`` and an ``"epoch"`` value. Whenever the integer part of the epoch
    changes to a value >= 1, ``(step, epoch_int)`` is recorded, i.e. the step at
    which ``epoch_int`` epochs have been completed. After the walk, the last
    (possibly partial) epoch ``ceil(last_epoch)`` is marked at the last step,
    unless a marker with that label already exists. For example, an exactly
    finished epoch 4.0 does not produce an extra "Epoch5".

    Args:
        log_items: Iterable of log dicts (e.g. ``log_history`` entries).

    Returns:
        List of ``(step, epoch_number)`` tuples sorted by step.
    """
    boundaries = []
    seen = set()      # (step, epoch_int) pairs already recorded
    labels = set()    # epoch numbers that already have a marker
    prev_epoch_int = None
    last_step = last_epoch = None

    for item in log_items:
        step = item.get("step")
        raw_epoch = item.get("epoch")
        if step is None or raw_epoch is None:
            continue
        # Round away float noise (e.g. 2.9999999999 / 3.0000000001) so integer
        # parts and the final ceil land on the intended epoch.
        epoch = round(float(raw_epoch), 6)
        last_step, last_epoch = step, epoch
        epoch_int = int(epoch)

        if prev_epoch_int is None:
            prev_epoch_int = epoch_int
            continue

        if epoch_int != prev_epoch_int:
            if epoch_int >= 1 and (step, epoch_int) not in seen:
                boundaries.append((step, epoch_int))
                seen.add((step, epoch_int))
                labels.add(epoch_int)
            prev_epoch_int = epoch_int

    if last_step is not None:
        # ceil: a partially finished last epoch (2.4) is labelled 3, while an
        # exactly finished one (3.0) stays 3. The original int()+1 turned 3.0 into 4.
        final_epoch = max(1, math.ceil(last_epoch))
        if final_epoch not in labels:
            boundaries.append((last_step, final_epoch))

    boundaries.sort(key=lambda b: b[0])
    return boundaries
|
|
|
|
|
def plot_series(x, y, xlabel, ylabel, title, outpath,
                epoch_marks=None, checkpoint_steps=None,
                color=YELLOW_ORANGE, linestyle='-'):
    """Plot a single metric against training step and save it to ``outpath``.

    Args:
        x: Step values. Must be non-empty (``max(x)`` is taken) and increasing,
            as ``np.interp`` requires for the checkpoint markers.
        y: Metric values aligned with ``x``.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        title: Figure title.
        outpath: Destination passed to ``Figure.savefig``.
        epoch_marks: Optional sequence of ``(step, epoch_number)`` pairs, each
            drawn as a blue dashed vertical line labelled ``Epoch<N>``.
        checkpoint_steps: Optional sequence of steps, each drawn as a blue dot
            at the linearly interpolated y value. Steps outside ``x`` use the
            first or last y value.
        color: Colour of the metric curve.
        linestyle: Line style of the metric curve.

    The x-axis starts at 0 and is widened on the right so every data point,
    checkpoint and epoch marker is visible. The y-axis starts at 0. The figure
    is closed even if drawing or saving fails.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        ax = fig.add_subplot(111)
        ax.plot(x, y, color=color, linestyle=linestyle, linewidth=2)

        # The right-hand x extent must cover the data and every marker drawn.
        x_right = [max(x)]

        ckpts = list(checkpoint_steps) if checkpoint_steps else []
        if ckpts:
            # Outside [x[0], x[-1]] np.interp returns the end values.
            ckpt_y = np.interp(ckpts, x, y, left=y[0], right=y[-1])
            ax.plot(ckpts, ckpt_y, linestyle='none', marker='o',
                    color=BLUE, markersize=6)
            x_right.append(max(ckpts))

        marks = list(epoch_marks) if epoch_marks else []
        if marks:
            x_right.append(max(step for step, _ in marks))

        xmin = 0
        xmax = max(x_right)
        span = max(xmax - xmin, 1.0)
        # Small right margin so markers at the far edge are not clipped.
        ax.set_xlim(left=xmin, right=xmax + max(1.0, 0.02 * span))
        ax.set_ylim(bottom=0)

        ax.grid(True, which='major', axis='both', linestyle='--',
                linewidth=0.8, alpha=0.6)

        if marks:
            # set_ylim above turned off y autoscaling, so the top limit is fixed.
            ymax = ax.get_ylim()[1]
            for step, ep in marks:
                ax.axvline(x=step, color=BLUE, linestyle='--', linewidth=1.2)
                ax.text(step, ymax * 0.98, f'Epoch{ep}', rotation=90,
                        va='top', ha='right', fontsize=8, color=BLUE)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.spines['left'].set_linewidth(2)
        ax.spines['bottom'].set_linewidth(2)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        fig.savefig(outpath, bbox_inches="tight")
    finally:
        plt.close(fig)
|
|
|
|
|
|
|
|
# (log key, y-axis label, title, output filename) for each plot, in drawing order.
_PLOT_SPECS = (
    ("loss", "Training Loss", "Training Loss vs Step", "loss_curve.png"),
    ("eval_loss", "Eval Loss", "Eval Loss vs Step", "eval_loss_curve.png"),
    ("learning_rate", "Learning Rate", "Learning Rate vs Step", "lr_curve.png"),
)


def _parse_checkpoint_steps(text):
    """Parse a comma-separated list of steps into a list of ints.

    Both ASCII ',' and the full-width comma U+FF0C are accepted as separators.
    Empty entries are ignored. Fractional values are truncated by ``int()``.
    Entries that are not finite numbers are skipped with a warning on stderr.
    """
    steps = []
    for token in text.replace("\uff0c", ",").split(","):
        token = token.strip()
        if not token:
            continue
        try:
            steps.append(int(float(token)))
        except (ValueError, OverflowError):
            print(f"Warning: ignoring invalid checkpoint step {token!r}", file=sys.stderr)
    return steps


def _collect_series(log):
    """Group log entries into ``{key: (steps, values)}`` for each plotted metric.

    Entries without a ``"step"`` are skipped. An entry contributes to every
    metric key from ``_PLOT_SPECS`` that it contains.
    """
    series = {key: ([], []) for key, *_ in _PLOT_SPECS}
    for item in log:
        step = item.get("step")
        if step is None:
            continue
        for key, (xs, ys) in series.items():
            if key in item:
                xs.append(step)
                ys.append(item[key])
    return series


def main():
    """Command-line entry point: read a trainer state JSON and write curve PNGs."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to trainer_state.json")
    ap.add_argument("--outdir", default="./plots", help="Directory to save PNGs")
    ap.add_argument("--no_epoch_marks", action="store_true", help="Disable vertical epoch markers")
    ap.add_argument("--checkpoint_steps", default="", help="Comma-separated steps (e.g., 100,200,500)")
    args = ap.parse_args()

    with open(Path(args.input), "r", encoding="utf-8") as f:
        state = json.load(f)

    # Log entries live under "log_history"; "logs" is accepted as a fallback.
    log = state.get("log_history", state.get("logs", []))
    series = _collect_series(log)

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    epoch_marks = None if args.no_epoch_marks else find_epoch_boundaries(log)
    checkpoint_steps = _parse_checkpoint_steps(args.checkpoint_steps)

    for key, ylabel, title, filename in _PLOT_SPECS:
        xs, ys = series[key]
        if xs:  # xs and ys always have the same length
            plot_series(xs, ys, "Step", ylabel, title, outdir / filename,
                        epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)

    print(f"Saved plots to: {outdir.resolve()}")


if __name__ == "__main__":
    main()
|
|
|