File size: 6,039 Bytes
010a017
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# -*- coding: utf-8 -*-
"""
Usage:
  python plot_loss_from_trainer_state.py --input trainer_state.json --outdir ./plots \
         --checkpoint_steps 263,526,789,1052

功能:
- Curve: 黃橘色實線
- Grid: x,y 虛線
- Epoch markers: 藍色虛線 + EpochN 標籤(含最後一個 epoch)
- Checkpoints: 藍色小圓點(線性插值;超出範圍時使用端點值,並自動擴張 x 軸確保能看見)
"""
import json, argparse
import math
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Plot palette: YELLOW_ORANGE is plot_series' default curve colour, BLUE is
# used for checkpoint dots and epoch marker lines/labels.
YELLOW_ORANGE, BLUE = "#d58f00", "#1f77b4"

def find_epoch_boundaries(log_items):
    """Locate the steps at which vertical epoch markers should be drawn.

    A marker ``(step, N)`` is emitted at the first logged step whose integer
    epoch ``int(epoch)`` differs from that of the previous logged entry
    (N is the new integer epoch; only N >= 1 is kept). One final marker is
    then added at the last logged step for the last epoch reached, unless
    that epoch number is already marked.

    Args:
        log_items: iterable of log dicts (e.g. ``log_history`` entries).
            Entries missing ``"step"`` or ``"epoch"`` are skipped.

    Returns:
        List of ``(step, N)`` tuples sorted by step.
    """
    boundaries = []
    seen = set()
    prev_epoch_int = None
    last_step, last_epoch = None, None
    for it in log_items:
        step = it.get("step")
        ep = it.get("epoch")
        if step is None or ep is None:
            continue
        # Convert once so the loop and the final-marker logic agree.
        ep = float(ep)
        last_step, last_epoch = step, ep
        ep_int = int(ep)
        if prev_epoch_int is None:
            # First usable entry only establishes the starting epoch.
            prev_epoch_int = ep_int
            continue
        if ep_int != prev_epoch_int:
            if (step, ep_int) not in seen and ep_int >= 1:
                boundaries.append((step, ep_int))
                seen.add((step, ep_int))
            prev_epoch_int = ep_int

    # Final marker for the last epoch. Use ceil(), not int() + 1:
    # - A run ending exactly at epoch 3.0 has *finished* epoch 3, and the loop
    #   above has normally marked it already. int() + 1 would add a spurious
    #   "Epoch4" at the same step.
    # - A run ending at 2.7 is inside epoch 3, and ceil(2.7) == 3 as before.
    # round() absorbs float noise such as 3.0000000001.
    if last_step is not None:
        ep_final = math.ceil(round(last_epoch, 6))
        already_marked = any(ep == ep_final for _, ep in boundaries)
        if ep_final >= 1 and not already_marked:
            boundaries.append((last_step, ep_final))
    boundaries.sort(key=lambda b: b[0])
    return boundaries

def plot_series(x, y, xlabel, ylabel, title, outpath,
                epoch_marks=None, checkpoint_steps=None,
                color=YELLOW_ORANGE, linestyle='-'):
    """Draw one metric curve with optional markers and save it to ``outpath``.

    - The curve is drawn with ``color`` / ``linestyle`` at linewidth 2.
    - Each checkpoint step becomes a blue dot at the ``np.interp``-interpolated
      y value; outside the data range the first/last y value is used.
    - Each ``(step, epoch)`` pair in ``epoch_marks`` becomes a dashed blue
      vertical line labelled ``Epoch<epoch>``.
    - The x axis starts at 0 and is widened so the data, every checkpoint dot
      and every epoch line fit, plus a small right-hand pad. The y axis starts
      at 0.

    Args:
        x, y: step values and metric values. ``x`` is used as ``np.interp``
            sample points, so it is assumed to be increasing.
        xlabel, ylabel, title: axis labels and plot title.
        outpath: destination passed to ``Figure.savefig``.
        epoch_marks: optional iterable of ``(step, epoch_number)`` pairs.
        checkpoint_steps: optional iterable of steps to mark with dots.
        color, linestyle: style of the metric curve.
    """
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    ax.plot(x, y, color=color, linestyle=linestyle, linewidth=2)

    # Checkpoint dots (clamped to the end values outside the data range).
    marked_steps = list(checkpoint_steps) if checkpoint_steps else []
    for ckpt in marked_steps:
        dot_y = np.interp(ckpt, x, y, left=y[0], right=y[-1])
        ax.plot(ckpt, dot_y, marker='o', color=BLUE, markersize=6)

    # Right edge = farthest of: the data, the checkpoint dots, the epoch lines.
    right_edge = max(x)
    if marked_steps:
        right_edge = max(right_edge, max(marked_steps))
    line_steps = [s for (s, _) in epoch_marks] if epoch_marks else []
    if line_steps:
        right_edge = max(right_edge, max(line_steps))

    # Pad the right side by at least 1 step or 2% of the span, so a line that
    # sits exactly on the edge stays visible.
    left_edge = 0
    width = max(right_edge - left_edge, 1.0)
    ax.set_xlim(left=left_edge, right=right_edge + max(1.0, 0.02 * width))

    # Must come after plotting: the top limit is taken from the plotted data.
    ax.set_ylim(bottom=0)

    ax.grid(True, which='major', axis='both', linestyle='--', linewidth=0.8, alpha=0.6)

    if epoch_marks:
        # ylim is fixed by set_ylim above, so read its top once.
        top = ax.get_ylim()[1]
        for line_step, epoch_no in epoch_marks:
            ax.axvline(x=line_step, color=BLUE, linestyle='--', linewidth=1.2)
            ax.text(line_step, top * 0.98, f'Epoch{epoch_no}', rotation=90,
                    va='top', ha='right', fontsize=8, color=BLUE)

    # Labels and spines go last, after the axis limits are settled.
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    for side in ('left', 'bottom'):
        ax.spines[side].set_linewidth(2)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)


def _parse_checkpoint_steps(text):
    """Parse the --checkpoint_steps string into a list of int steps.

    Tokens may be separated by ASCII commas or full-width commas (U+FF0C).
    Surrounding blanks and empty tokens are ignored. Each token is converted
    with ``int(float(token))``, so "500", "500.0" and "5e2" all give 500.
    Tokens that cannot be converted (e.g. "abc", "nan", "inf") are skipped
    with a warning on stderr instead of aborting the run.
    """
    steps = []
    # Normalise full-width commas to ASCII before splitting. The escape is
    # used so the intent survives editors/tools that rewrite punctuation.
    for token in text.replace("\uff0c", ",").split(","):
        token = token.strip()
        if not token:
            continue
        try:
            steps.append(int(float(token)))
        except (ValueError, OverflowError):
            # ValueError: not a number / NaN; OverflowError: +-inf.
            print(f"Warning: ignoring invalid checkpoint step {token!r}",
                  file=sys.stderr)
    return steps


def main():
    """Command-line entry point: read a trainer_state.json and save PNG plots.

    One plot is written to --outdir for each series that has data:
    "loss" -> loss_curve.png, "eval_loss" -> eval_loss_curve.png,
    "learning_rate" -> lr_curve.png. Every plot gets the same epoch markers
    (unless --no_epoch_marks is given) and the same checkpoint dots.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to trainer_state.json")
    ap.add_argument("--outdir", default="./plots", help="Directory to save PNGs")
    ap.add_argument("--no_epoch_marks", action="store_true", help="Disable vertical epoch markers")
    ap.add_argument("--checkpoint_steps", default="", help="Comma-separated steps (e.g., 100,200,500)")
    args = ap.parse_args()

    src = Path(args.input)
    with open(src, "r", encoding="utf-8") as f:
        state = json.load(f)

    # Entries come from "log_history", falling back to a "logs" key.
    log = state.get("log_history", state.get("logs", []))

    steps, train_losses = [], []
    eval_steps, eval_losses = [], []
    lr_steps, lrs = [], []

    for item in log:
        step = item.get("step")
        if step is None:
            continue
        # One entry may carry several metrics; each feeds its own series.
        if "loss" in item:
            steps.append(step)
            train_losses.append(item["loss"])
        if "eval_loss" in item:
            eval_steps.append(step)
            eval_losses.append(item["eval_loss"])
        if "learning_rate" in item:
            lr_steps.append(step)
            lrs.append(item["learning_rate"])

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    epoch_marks = None if args.no_epoch_marks else find_epoch_boundaries(log)
    checkpoint_steps = _parse_checkpoint_steps(args.checkpoint_steps)

    plots = [
        (steps, train_losses, "Training Loss", "Training Loss vs Step", "loss_curve.png"),
        (eval_steps, eval_losses, "Eval Loss", "Eval Loss vs Step", "eval_loss_curve.png"),
        (lr_steps, lrs, "Learning Rate", "Learning Rate vs Step", "lr_curve.png"),
    ]
    for xs, ys, ylabel, title, filename in plots:
        if xs and ys:
            plot_series(xs, ys, "Step", ylabel, title, outdir / filename,
                        epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)

    print(f"Saved plots to: {outdir.resolve()}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()