import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
class LoRAModule(nn.Module):
    """
    LoRA module that replaces the forward method of an original Linear or Conv2d module.
    """

    def __init__(
        self,
        lora_name: str,
        org_module: nn.Module,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: Optional[float] = None,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
    ):
        """
        Args:
            lora_name (str): Name of the LoRA module.
            org_module (nn.Module): The original module to wrap.
            multiplier (float): Scaling factor for the LoRA output.
            lora_dim (int): The rank of the LoRA decomposition.
            alpha (float, optional): Scaling factor for the LoRA weights; the effective scale is alpha / lora_dim. Defaults to lora_dim.
            dropout (float, optional): Dropout probability applied to the down-projection output. Defaults to None.
            rank_dropout (float, optional): Dropout probability applied element-wise to the rank activations. Defaults to None.
            module_dropout (float, optional): Probability of skipping the LoRA branch entirely during training. Defaults to None.
        """
        super().__init__()
        self.lora_name = lora_name
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        # Determine layer type (Linear or Conv2d)
        is_conv2d = isinstance(org_module, nn.Conv2d)
        in_dim = org_module.in_channels if is_conv2d else org_module.in_features
        out_dim = org_module.out_channels if is_conv2d else org_module.out_features

        # Define LoRA layers: a low-rank down projection followed by an up projection
        if is_conv2d:
            self.lora_down = nn.Conv2d(
                in_dim, lora_dim, kernel_size=org_module.kernel_size,
                stride=org_module.stride, padding=org_module.padding, bias=False,
            )
            self.lora_up = nn.Conv2d(lora_dim, out_dim, kernel_size=1, stride=1, bias=False)
        else:
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)

        # Initialize weights: random down projection, zero-initialized up projection,
        # so the LoRA branch contributes nothing at the start of training
        nn.init.xavier_uniform_(self.lora_down.weight)
        nn.init.zeros_(self.lora_up.weight)

        # Alpha scaling: the effective scale is alpha / lora_dim; register alpha as a
        # buffer so it is saved alongside the LoRA weights
        alpha = lora_dim if alpha is None else alpha
        self.scale = alpha / lora_dim
        self.register_buffer("alpha", torch.tensor(alpha, dtype=torch.float32))

        # Store a reference to the original module and its forward method
        self.org_module = org_module
        self.org_forward = org_module.forward
    def apply_to(self):
        """Replace the forward method of the original module with this module's forward method."""
        self.org_module.forward = self.forward
        # Drop the reference so the wrapped module is no longer registered as a submodule
        del self.org_module
    def forward(self, x):
        """
        Forward pass for the LoRA-enhanced module.
        """
        # Module dropout: skip the LoRA branch entirely with some probability during training
        if self.module_dropout and self.training and torch.rand(1).item() < self.module_dropout:
            return self.org_forward(x)

        # Compute LoRA down projection
        lora_output = self.lora_down(x)

        # Apply dropout only during training
        if self.training:
            if self.dropout:
                lora_output = F.dropout(lora_output, p=self.dropout)
            if self.rank_dropout:
                # Zero out activations at random and rescale to preserve the expected magnitude
                dropout_mask = (torch.rand_like(lora_output) > self.rank_dropout).to(lora_output.dtype)
                lora_output = lora_output * dropout_mask / (1.0 - self.rank_dropout)

        # Compute LoRA up projection
        lora_output = self.lora_up(lora_output)

        # Combine with the original output, scaled by multiplier and alpha / lora_dim
        return self.org_forward(x) + lora_output * self.multiplier * self.scale
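

# Minimal usage sketch (illustrative, not part of the original file): wraps a plain
# nn.Linear with LoRAModule and swaps in the LoRA-enhanced forward via apply_to().
# The name "lora_demo" and the hyperparameters below are arbitrary example values.
if __name__ == "__main__":
    base_layer = nn.Linear(16, 32)
    lora = LoRAModule(
        lora_name="lora_demo",
        org_module=base_layer,
        multiplier=1.0,
        lora_dim=4,
        alpha=4.0,
        dropout=0.1,
    )
    lora.apply_to()  # base_layer.forward now routes through LoRAModule.forward

    x = torch.randn(2, 16)
    y = base_layer(x)  # original output plus the (initially zero) LoRA contribution
    print(y.shape)  # torch.Size([2, 32])

    # Only the low-rank matrices are trainable additions; the base layer's weights are untouched.
    print([name for name, _ in lora.named_parameters()])  # ['lora_down.weight', 'lora_up.weight']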