import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import xformers
    import xformers.ops

    xformers_available = True
except Exception:
    xformers_available = False


# Region Controller (unchanged)
class RegionControler:
    def __init__(self) -> None:
        self.prompt_image_conditioning = []


region_control = RegionControler()
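
# Usage sketch (assumption, not shown in this file): callers append one
# conditioning dict before a region-controlled generation, e.g.
#
#     region_control.prompt_image_conditioning.append({"region_mask": mask_2d})
#
# where `mask_2d` (an illustrative name) is a float tensor of shape (H, W).
# IPAttnProcessor.forward below reads it back via .get("region_mask", None).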


# Helper function for weight initialization
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)


# Base Attention Processor
class BaseAttnProcessor(nn.Module):
    def __init__(self):
        super().__init__()

    def _process_input(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb):
        """Handles preprocessing shared by AttnProcessor and IPAttnProcessor."""
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        height = width = None  # only defined for spatial (4D) inputs
        if input_ndim == 4:
            # Flatten (B, C, H, W) -> (B, H*W, C) so attention sees a sequence.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        # The mask must be sized against the key sequence, which comes from
        # encoder_hidden_states in cross-attention.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        return hidden_states, encoder_hidden_states, attention_mask, residual, batch_size, input_ndim, height, width
    def _apply_attention(self, attn, query, key, value, attention_mask):
        """Handles the actual attention operation using either xformers or standard PyTorch."""
        if xformers_available:
            # head_to_batch_dim yields (batch*heads, seq_len, head_dim), a 3D
            # layout that memory_efficient_attention accepts directly.
            return xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        return torch.bmm(attention_probs, value)
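

# Reference sketch (assumption, not part of the original file): the fallback
# path above is plain scaled dot-product attention. Written with bare tensors
# it looks like the helper below; `scale` stands in for attn.scale, and the
# baddbmm(beta=0, alpha=scale) + softmax sequence mirrors what diffusers'
# get_attention_scores computes.
def _reference_attention(query, key, value, scale):
    # query/key/value: (batch*heads, seq_len, head_dim)
    scores = torch.baddbmm(
        torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
        query,
        key.transpose(-1, -2),
        beta=0,
        alpha=scale,
    )
    return torch.bmm(scores.softmax(dim=-1), value)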


# Optimized AttnProcessor
class AttnProcessor(BaseAttnProcessor):
    def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        (hidden_states, encoder_hidden_states, attention_mask, residual,
         batch_size, input_ndim, height, width) = self._process_input(
            attn, hidden_states, encoder_hidden_states, attention_mask, temb
        )
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        query, key, value = map(attn.head_to_batch_dim, (query, key, value))
        hidden_states = self._apply_attention(attn, query, key, value, attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, -1, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        return hidden_states / attn.rescale_output_factor


# Optimized IPAttnProcessor
class IPAttnProcessor(BaseAttnProcessor):
    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        # Separate key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.apply(init_weights)

    def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        (hidden_states, encoder_hidden_states, attention_mask, residual,
         batch_size, input_ndim, height, width) = self._process_input(
            attn, hidden_states, encoder_hidden_states, attention_mask, temb
        )
        # Split the trailing num_tokens image-prompt embeddings off the text embeddings.
        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
        encoder_hidden_states, ip_hidden_states = (
            encoder_hidden_states[:, :end_pos, :],
            encoder_hidden_states[:, end_pos:, :],
        )
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        query, key, value = map(attn.head_to_batch_dim, (query, key, value))
        hidden_states = self._apply_attention(attn, query, key, value, attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # Image Prompt Attention: the same queries attend to the IP key/value projections.
        ip_key = attn.head_to_batch_dim(self.to_k_ip(ip_hidden_states))
        ip_value = attn.head_to_batch_dim(self.to_v_ip(ip_hidden_states))
        ip_hidden_states = self._apply_attention(attn, query, ip_key, ip_value, None)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
        # Region Control: optionally confine the image-prompt contribution to a spatial mask.
        if len(region_control.prompt_image_conditioning) == 1:
            region_mask = region_control.prompt_image_conditioning[0].get("region_mask", None)
            if region_mask is not None:
                # Downsample the (H, W) mask so its flattened length matches
                # the query sequence length of ip_hidden_states.
                h, w = region_mask.shape[:2]
                ratio = (h * w / ip_hidden_states.shape[1]) ** 0.5
                mask = F.interpolate(region_mask[None, None], scale_factor=1 / ratio, mode="nearest").reshape([1, -1, 1])
            else:
                mask = torch.ones_like(ip_hidden_states)
            ip_hidden_states = ip_hidden_states * mask
        hidden_states = hidden_states + self.scale * ip_hidden_states
        # Linear projection and dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, -1, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        return hidden_states / attn.rescale_output_factor
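

# --- Usage sketch (illustrative, not part of the original file) ---
# How processors like these are typically wired into a diffusers
# UNet2DConditionModel, following the common IP-Adapter convention: "attn1"
# layers are self-attention (plain AttnProcessor), "attn2" layers are
# cross-attention (IPAttnProcessor). The hidden-size lookup assumes the
# standard diffusers config fields; adjust for other architectures.
def apply_ip_adapter_processors(unet, scale=1.0, num_tokens=4):
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # Self-attention layers have no cross-attention dim.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            attn_procs[name] = IPAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=scale,
                num_tokens=num_tokens,
            )
    unet.set_attn_processor(attn_procs)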