```python
import torch


class GeneralLoRALoader:
    def __init__(self, device="cpu", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype

    def get_name_dict(self, lora_state_dict):
        # Map each target module name in the base model to its pair of
        # (lora_B, lora_A) keys in the LoRA state dict.
        lora_name_dict = {}
        for key in lora_state_dict:
            if ".lora_B." not in key:
                continue
            keys = key.split(".")
            # Drop an optional adapter name that follows "lora_B" (e.g. "default").
            if len(keys) > keys.index("lora_B") + 2:
                keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            # Strip the "diffusion_model" prefix and the trailing "weight".
            if keys[0] == "diffusion_model":
                keys.pop(0)
            keys.pop(-1)
            target_name = ".".join(keys)
            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
        return lora_name_dict

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        # Merge the LoRA update alpha * (B @ A) into the weights of every
        # module whose name matches a parsed LoRA key.
        updated_num = 0
        lora_name_dict = self.get_name_dict(state_dict_lora)
        for name, module in model.named_modules():
            if name in lora_name_dict:
                weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
                weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
                if len(weight_up.shape) == 4:
                    # 1x1-convolution LoRA weights: squeeze the spatial dims,
                    # multiply, then restore the 4D shape.
                    weight_up = weight_up.squeeze(3).squeeze(2)
                    weight_down = weight_down.squeeze(3).squeeze(2)
                    weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                else:
                    weight_lora = alpha * torch.mm(weight_up, weight_down)
                state_dict = module.state_dict()
                state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
                module.load_state_dict(state_dict)
                updated_num += 1
        print(f"{updated_num} tensors are updated by LoRA.")
```
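For reference, here is a minimal, self-contained usage sketch. The toy module, key names, adapter name, and rank below are illustrative assumptions, not part of the loader itself; they only need to follow the `<prefix>.lora_A/lora_B.<adapter>.weight` key pattern that `get_name_dict` parses.

```python
import torch

# Toy one-layer model; its named_modules() entry "proj" is the LoRA target.
class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8, bias=False)

    def forward(self, x):
        return self.proj(x)

model = ToyBlock()
rank = 2  # illustrative LoRA rank

# Fabricated LoRA state dict; keys mimic the "diffusion_model.<module>.lora_A/B.default.weight" pattern.
lora_state_dict = {
    "diffusion_model.proj.lora_A.default.weight": torch.randn(rank, 8),
    "diffusion_model.proj.lora_B.default.weight": torch.randn(8, rank),
}

loader = GeneralLoRALoader(device="cpu", torch_dtype=torch.float32)
loader.load(model, lora_state_dict, alpha=1.0)  # prints "1 tensors are updated by LoRA."
```

In practice the state dict would come from a saved LoRA checkpoint (for example via `safetensors.torch.load_file`), and `alpha` scales how strongly the low-rank update is merged into the base weights.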