from diffusers.models.attention_processor import FluxAttnProcessor2_0
from safetensors.torch import load_file
import re
import torch
from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor

device = "cuda"

def load_safetensors(path):
    """Safely loads tensors from a file and maps them to the CPU."""
    return load_file(path, device="cpu")

def get_lora_count_from_checkpoint(checkpoint):
    """
    Infers the number of LoRA modules stored in a checkpoint by inspecting its keys.
    Also prints a sample of keys for debugging.
    """
    lora_indices = set()
    # Regex to find '..._loras.X.' where X is a number.
    indexed_pattern = re.compile(r'._loras\.(\d+)\.')
    found_keys = []

    for key in checkpoint.keys():
        match = indexed_pattern.search(key)
        if match:
            lora_indices.add(int(match.group(1)))
            if len(found_keys) < 5 and key not in found_keys:
                found_keys.append(key)

    if lora_indices:
        lora_count = max(lora_indices) + 1
        print("INFO: Auto-detected indexed LoRA keys in checkpoint.")
        print(f"      Found {lora_count} LoRA module(s).")
        print("      Sample keys:", found_keys)
        return lora_count

    # Fallback for legacy, non-indexed checkpoints.
    legacy_found = False
    legacy_key_sample = ""
    for key in checkpoint.keys():
        if '.q_lora.' in key:
            legacy_found = True
            legacy_key_sample = key
            break

    if legacy_found:
        print("INFO: Auto-detected legacy (non-indexed) LoRA keys in checkpoint.")
        print("      Assuming 1 LoRA module.")
        print("      Sample key:", legacy_key_sample)
        return 1

    print("WARNING: No LoRA keys found in the checkpoint.")
    return 0

def get_lora_ranks(checkpoint, num_loras):
    """
    Determines the rank for each LoRA module from the checkpoint.
    It supports both indexed (e.g., 'loras.0') and legacy non-indexed formats.
    """
    ranks = {}
    
    # First, try to find ranks for all indexed LoRA modules.
    for i in range(num_loras):
        # Find a key that uniquely identifies the i-th LoRA's down projection.
        rank_pattern = re.compile(rf'._loras\.{i}\.down\.weight')
        for k, v in checkpoint.items():
            if rank_pattern.search(k):
                ranks[i] = v.shape[0]
                break
    
    # If not all ranks were found, there might be legacy keys or a mismatch.
    if len(ranks) != num_loras:
        # Fallback for single, non-indexed LoRA checkpoints.
        if num_loras == 1:
            for k, v in checkpoint.items():
                if ".q_lora.down.weight" in k:
                    return [v.shape[0]]

        # If still unresolved, use the rank of the very first LoRA found as a default for all.
        first_found_rank = next((v.shape[0] for k, v in checkpoint.items() if k.endswith(".down.weight")), None)
        
        if first_found_rank is None:
            raise ValueError("Could not determine any LoRA rank from the provided checkpoint.")

        # Return a list where missing ranks are filled with the first one found.
        return [ranks.get(i, first_found_rank) for i in range(num_loras)]

    # Return the list of ranks sorted by LoRA index.
    return [ranks[i] for i in range(num_loras)]


def load_checkpoint(local_path):
    """Loads a checkpoint from `local_path`, using safetensors when the extension matches."""
    if local_path is None:
        raise ValueError("`local_path` must point to a checkpoint file.")
    if local_path.endswith('.safetensors'):
        print(f"Loading .safetensors checkpoint from {local_path}")
        checkpoint = load_safetensors(local_path)
    else:
        print(f"Loading checkpoint from {local_path}")
        checkpoint = torch.load(local_path, map_location='cpu')
    return checkpoint


def prepare_lora_processors(checkpoint, lora_weights, transformer, cond_size, number=None):
    # Ensure processors match the transformer's device and dtype
    try:
        first_param = next(transformer.parameters())
        target_device = first_param.device
        target_dtype = first_param.dtype
    except StopIteration:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        target_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    if number is None:
        number = get_lora_count_from_checkpoint(checkpoint)
        if number == 0:
            return {}

    # Normalize `lora_weights` so there is exactly one weight per LoRA module,
    # whether `number` was passed explicitly or auto-detected above.
    if lora_weights and len(lora_weights) != number:
        print(f"WARNING: Provided `lora_weights` length ({len(lora_weights)}) differs from detected LoRA count ({number}).")
        final_weights = (lora_weights + [1.0] * number)[:number]
        print(f"         Adjusting weights to: {final_weights}")
        lora_weights = final_weights
    elif not lora_weights:
        print(f"INFO: No `lora_weights` provided. Defaulting to weights of 1.0 for all {number} LoRAs.")
        lora_weights = [1.0] * number
    
    ranks = get_lora_ranks(checkpoint, number)
    print("INFO: Determined ranks for LoRA modules:", ranks)
    
    cond_widths = cond_size if isinstance(cond_size, list) else [cond_size] * number
    cond_heights = cond_size if isinstance(cond_size, list) else [cond_size] * number
    
    lora_attn_procs = {}
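    # FLUX.1 transformers use 19 double-stream ("transformer_blocks") and
    # 38 single-stream ("single_transformer_blocks") attention blocks.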
    double_blocks_idx = list(range(19))
    single_blocks_idx = list(range(38))
    
    # Get all attention processor names from the transformer to iterate over
    for name in transformer.attn_processors.keys():
        match = re.search(r'\.(\d+)\.', name)
        if not match:
            continue
        layer_index = int(match.group(1))

        if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
            # Prefix match: a plain substring test would also catch
            # "single_transformer_blocks.{layer_index}." keys.
            lora_state_dicts = {
                key: value for key, value in checkpoint.items()
                if key.startswith(f"transformer_blocks.{layer_index}.")
            }

            lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
                dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, 
                device=target_device, dtype=target_dtype, cond_widths=cond_widths, cond_heights=cond_heights, n_loras=number
            )

            for n in range(number):
                lora_prefix_q = f"{name}.q_loras.{n}"
                lora_prefix_k = f"{name}.k_loras.{n}"
                lora_prefix_v = f"{name}.v_loras.{n}"
                lora_prefix_proj = f"{name}.proj_loras.{n}"
                
                lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.down.weight')
                lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.up.weight')
                lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.down.weight')
                lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.up.weight')
                lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.down.weight')
                lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.up.weight')
                lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_proj}.down.weight')
                lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_proj}.up.weight')
            # Move the populated processor to the transformer's device/dtype once.
            lora_attn_procs[name].to(device=target_device, dtype=target_dtype)
        
        elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
            lora_state_dicts = {
                key: value for key, value in checkpoint.items()
                if key.startswith(f"single_transformer_blocks.{layer_index}.")
            }
            
            lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
                dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, 
                device=target_device, dtype=target_dtype, cond_widths=cond_widths, cond_heights=cond_heights, n_loras=number
            )

            for n in range(number):
                lora_prefix_q = f"{name}.q_loras.{n}"
                lora_prefix_k = f"{name}.k_loras.{n}"
                lora_prefix_v = f"{name}.v_loras.{n}"
                
                lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.down.weight')
                lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.up.weight')
                lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.down.weight')
                lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.up.weight')
                lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.down.weight')
                lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.up.weight')
            # Move the populated processor to the transformer's device/dtype once.
            lora_attn_procs[name].to(device=target_device, dtype=target_dtype)
    return lora_attn_procs
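

# Example usage (a minimal sketch; the pipeline object, checkpoint path, weights,
# and cond_size below are illustrative placeholders, not values defined in this module):
#
#   from diffusers import FluxPipeline
#
#   pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
#   checkpoint = load_checkpoint("multi_lora.safetensors")
#   lora_attn_procs = prepare_lora_processors(
#       checkpoint, lora_weights=[1.0, 0.8], transformer=pipe.transformer, cond_size=512
#   )
#   pipe.transformer.set_attn_processor(lora_attn_procs)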