import os
import re
import json
import gc
import functools
import contextlib
from typing import Dict, Union, Optional, Type, Set

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    StateDictType,
    FullOptimStateDictConfig,
    FullStateDictConfig,
)
import torch.distributed.checkpoint as torch_dcp
import torch.distributed.checkpoint.state_dict
from torch.distributed.fsdp.api import (
    ShardingStrategy,
    BackwardPrefetch,
    MixedPrecision,
)
import accelerate
import safetensors.torch
import diffusers
import transformers
from huggingface_hub.serialization import split_torch_state_dict_into_shards
# from .ema_utils import EMAModel


def upcast_trainable_param_to_fp32_(fsdp_model):
    for m in FSDP.fsdp_modules(fsdp_model):
        if m._has_params:
            param = m._flat_param
            if (
                param.dtype != torch.float32
                and param.device != torch.device("meta")
                and param.requires_grad
            ):
                param.data = param.data.to(torch.float32)
                m._handle._orig_param_dtype = torch.float32
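
# Usage sketch (illustrative, not from the original code): after wrapping with a
# low-precision MixedPrecision config, upcast the trainable flat-parameters so the
# optimizer keeps fp32 master weights. `model` and `device` are assumptions.
#
#   fsdp_model = make_model_fsdp(model, param_dtype=torch.bfloat16, device=device)
#   upcast_trainable_param_to_fp32_(fsdp_model)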


def get_module_to_ignore_mixed_precision():
    try:
        from apex.normalization import FusedLayerNorm

        return [
            torch.nn.GroupNorm,
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
            FusedLayerNorm,
        ]
    except ImportError:
        return [
            torch.nn.GroupNorm,
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
        ]


def is_fsdp_model(model):
    return len(FSDP.fsdp_modules(model)) > 0


def size_based_auto_wrap_policy(
    module: torch.nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[torch.nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[torch.nn.Module]]] = None,
) -> bool:
    """
    A size-based auto wrap policy.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped.

        min_num_params (int): Customizable policy input that controls the size
            threshold over which a module is ready to be wrapped. This is in
            units of numel.
        force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep
            as leaves, i.e. their children will never be wrapped.
        exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be
            excluded in wrapping.

    Returns:
        Whether ``module`` should be wrapped.
    """
    force_leaf_modules = (
        size_based_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore[attr-defined]
        if force_leaf_modules is None
        else force_leaf_modules
    )
    exclude_wrap_modules = (
        size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore[attr-defined]
        if exclude_wrap_modules is None
        else exclude_wrap_modules
    )

    # Keep the argument `min_num_params` for BC for now, but it represents the
    # minimum non-wrapped *numel* before triggering a wrapping
    min_nonwrapped_numel = min_num_params
    is_large = nonwrapped_numel >= min_nonwrapped_numel
    STOP_FLAG_NAME = "__FSDP_STOP_WRAP_FLAG_CUSTOM_POLICY_size_based_auto_wrap_policy"
    if recurse:
        # MixedPrecision causes FSDP to always recurse, so flag the children of
        # force-leaf modules here to keep them from being wrapped later.
        if isinstance(module, tuple(force_leaf_modules)):
            for m in module.children():
                m.apply(lambda m: setattr(m, STOP_FLAG_NAME, True))
        return True
    else:
        if getattr(module, size_based_auto_wrap_policy.LEAF_ROOT_FLAG_NAME, False):
            return True
        elif getattr(module, STOP_FLAG_NAME, False):
            return False
        else:
            # If we are not recursing, determine if we should wrap.
            return is_large and not isinstance(module, tuple(exclude_wrap_modules))


# Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {torch.nn.ModuleList, torch.nn.ModuleDict}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {torch.nn.MultiheadAttention}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.LEAF_ROOT_FLAG_NAME = (
    "__FSDP_LEAF_ROOT_FLAG_CUSTOM_POLICY_size_based_auto_wrap_policy"
)


def mark_leaf_root_(module):
    setattr(
        module,
        size_based_auto_wrap_policy.LEAF_ROOT_FLAG_NAME,
        True,
    )
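
# Example (illustrative sketch; `model.final_layer` is a hypothetical attribute):
# flag a submodule so the size-based policy wraps it as its own FSDP unit even
# if it falls below the numel threshold.
#
#   mark_leaf_root_(model.final_layer)
#   fsdp_model = make_model_fsdp(model, param_dtype=torch.bfloat16, device=device)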


def make_model_fsdp(
    model,
    param_dtype,
    device,
    reduce_dtype=None,
    buffer_dtype=None,
    sync_module_states=True,
    process_group=None,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    module_classes_to_ignore_mixed_precision=None,
    ignored_states=None,
    ignored_modules=None,
    auto_wrap_policy=None,
    part_size=1e6,
    force_leaf_modules=None,
    exclude_wrap_modules=None,
    use_orig_params=False
):
    if module_classes_to_ignore_mixed_precision is None:
        module_classes_to_ignore_mixed_precision = (
            get_module_to_ignore_mixed_precision()
        )
    if auto_wrap_policy is not None:
        # Use the caller-provided policy as-is.
        pass
    elif sharding_strategy == ShardingStrategy.NO_SHARD:
        auto_wrap_policy = None
    else:
        auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=part_size,
            force_leaf_modules=force_leaf_modules,
            exclude_wrap_modules=exclude_wrap_modules,
        )

    model = FSDP(
        model,
        sharding_strategy=sharding_strategy,
        process_group=process_group,
        forward_prefetch=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        use_orig_params=use_orig_params,
        sync_module_states=sync_module_states,
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype or torch.float32,
            buffer_dtype=buffer_dtype or torch.float32,
            keep_low_precision_grads=False,
            cast_forward_inputs=False,
            cast_root_forward_inputs=True,
            _module_classes_to_ignore=module_classes_to_ignore_mixed_precision,
        ),
        auto_wrap_policy=auto_wrap_policy,
        ignored_states=ignored_states,
        ignored_modules=ignored_modules,
        device_id=device,
    )
    torch.cuda.empty_cache()
    gc.collect()
    return model
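
# Example (minimal sketch under assumed settings; the toy model below is not
# part of the original code): wrap a plain module with the default HYBRID_SHARD
# strategy, bf16 compute, and fp32 gradient reduction.
#
#   import torch.nn as nn
#   net = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
#   net = make_model_fsdp(
#       net,
#       param_dtype=torch.bfloat16,
#       device=torch.cuda.current_device(),
#       part_size=1e6,
#   )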


def save_fsdp_lora(
    model_to_save,                      # the FSDP-wrapped model
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    lora_regex: str = r"(?:lora)",  # adjust to your LoRA parameter naming convention
):
    """
    仅保存 LoRA 层的权重。适用于 FSDP 并与 safetensors 兼容。
    """
    # 1. 解包 FSDP,拿到裸模型
    unwrapped_model = accelerate.utils.extract_model_from_parallel(model_to_save)

    # 2. Create the output directory
    if is_main_process:
        os.makedirs(save_directory, exist_ok=True)

    # 3. Gather the full state_dict (offloaded to CPU)
    state_dict = torch_dcp.state_dict.get_model_state_dict(
        model_to_save,
        options=torch_dcp.state_dict.StateDictOptions(
            full_state_dict=True,
            cpu_offload=True,
            ignore_frozen_params=False,
        ),
    )

    # 4. Filter for the LoRA parameters
    lora_pattern = re.compile(lora_regex)
    lora_state_dict = {
        k: v for k, v in state_dict.items() if lora_pattern.search(k) is not None
    }

    if not lora_state_dict:
        raise ValueError(
            "No parameters matching the LoRA pattern were found. "
            "Check that lora_regex matches your parameter naming convention."
        )

    # 5. Save as a single *.safetensors file
    if is_main_process:
        weight_file = os.path.join(save_directory, "adapter_model.safetensors")
        safetensors.torch.save_file(
            lora_state_dict, weight_file, metadata={"format": "pt", "type": "lora"}
        )
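
# Usage sketch (`fsdp_model` and `rank` are assumptions; LoRA parameter names are
# assumed to contain "lora"): call on every rank, since gathering the full state
# dict is a collective; only the main process writes adapter_model.safetensors.
#
#   save_fsdp_lora(fsdp_model, "./lora_ckpt", is_main_process=(rank == 0))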


def load_fsdp_model_(model_to_load: FSDP, save_directory: Union[str, os.PathLike]):
    with FSDP.state_dict_type(
        model_to_load,
        state_dict_type=StateDictType.FULL_STATE_DICT,
        state_dict_config=FullStateDictConfig(
            rank0_only=False,
        ),
    ):
        # Load a full (non-sharded) copy from disk, then copy its full state
        # dict into the FSDP-wrapped model on every rank.
        _model = model_to_load.from_pretrained(save_directory)
        model_to_load.load_state_dict(_model.state_dict())


def save_fsdp_optimizer(
    models: Dict,
    optimizer_to_save: torch.optim.Optimizer,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
):
    _fsdp_state_dict_config = dict(
        state_dict_type=StateDictType.FULL_STATE_DICT,
        optim_state_dict_config=FullOptimStateDictConfig(
            offload_to_cpu=True,
            rank0_only=True,
        ),
    )
    mgrs = list()
    for m in models.values():
        if len(FSDP.fsdp_modules(m)) > 0:
            mgrs.append(FSDP.state_dict_type(m, **_fsdp_state_dict_config))

    with contextlib.ExitStack() as stack:
        for mgr in mgrs:
            stack.enter_context(mgr)
        optim_state_dict = FSDP.optim_state_dict(
            torch.nn.ModuleDict(models),
            optimizer_to_save,
        )
        if is_main_process:
            os.makedirs(save_directory, exist_ok=True)
            torch.save(
                optim_state_dict, os.path.join(save_directory, "optim_states.pth")
            )


def load_fsdp_optimizer_(
    models: Dict,
    optimizer_to_load: torch.optim.Optimizer,
    save_directory: Union[str, os.PathLike],
):
    _fsdp_state_dict_config = dict(
        state_dict_type=StateDictType.FULL_STATE_DICT,
        optim_state_dict_config=FullOptimStateDictConfig(
            rank0_only=False,
        ),
    )
    mgrs = list()
    for m in models.values():
        if len(FSDP.fsdp_modules(m)) > 0:
            mgrs.append(FSDP.state_dict_type(m, **_fsdp_state_dict_config))

    with contextlib.ExitStack() as stack:
        for mgr in mgrs:
            stack.enter_context(mgr)
        optimizer_path = os.path.join(save_directory, "optim_states.pth")
        assert os.path.isfile(optimizer_path)
        optim_state_dict = torch.load(optimizer_path)
        optim_state_dict = FSDP.optim_state_dict_to_load(
            torch.nn.ModuleDict(models),
            optimizer_to_load,
            optim_state_dict,
        )
        optimizer_to_load.load_state_dict(optim_state_dict)
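
# Usage sketch (names are assumptions, not from the original code): checkpoint
# and later restore optimizer state for a dict of FSDP-wrapped models. Call both
# functions on every rank, since FSDP.optim_state_dict is a collective and each
# rank must load its own shard of the optimizer state.
#
#   models = {"unet": fsdp_unet}
#   save_fsdp_optimizer(models, optimizer, "./ckpt", is_main_process=(rank == 0))
#   load_fsdp_optimizer_(models, optimizer, "./ckpt")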