import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import xformers
    import xformers.ops
    xformers_available = True
except Exception:
    xformers_available = False


# Global region controller: holds per-prompt conditioning (e.g. an optional
# "region_mask") that IPAttnProcessor reads to spatially restrict the
# image-prompt attention.
class RegionControler:
    def __init__(self) -> None:
        self.prompt_image_conditioning = []

region_control = RegionControler()
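
# A minimal usage sketch (hypothetical values, not from the original module): callers
# populate the list with one dict per generation before denoising; IPAttnProcessor
# below only reads the optional "region_mask" entry, assumed here to be a 2D float tensor.
#
#   region_control.prompt_image_conditioning = [
#       {"region_mask": torch.ones(64, 64)},
#   ]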


# Helper function for weight initialization
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)


# Base Attention Processor
class BaseAttnProcessor(nn.Module):
    def __init__(self):
        super().__init__()

    def _process_input(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb):
        """Shared preprocessing for AttnProcessor and IPAttnProcessor.

        Returns the (possibly reshaped) hidden states, the normalized encoder hidden
        states, the prepared attention mask, and the bookkeeping values needed to
        restore the original shape afterwards.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        height = width = None  # only defined for 4D (B, C, H, W) inputs

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The key/value sequence length comes from encoder_hidden_states when given
        # (cross-attention), otherwise from hidden_states (self-attention).
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        return hidden_states, encoder_hidden_states, attention_mask, residual, batch_size, input_ndim, height, width

    def _apply_attention(self, attn, query, key, value, attention_mask):
        """Runs the attention op with xformers when available, otherwise the standard PyTorch path."""
        if xformers_available:
            if attention_mask is not None:
                # xformers does not broadcast the singleton query dimension of the
                # prepared mask, so expand it to (batch * heads, query_len, key_len).
                attention_mask = attention_mask.expand(-1, query.shape[1], -1)
            return xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        return torch.bmm(attention_probs, value)


# Optimized AttnProcessor
class AttnProcessor(BaseAttnProcessor):
    def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        hidden_states, encoder_hidden_states, attention_mask, residual, batch_size, input_ndim, height, width = \
            self._process_input(attn, hidden_states, encoder_hidden_states, attention_mask, temb)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query, key, value = map(attn.head_to_batch_dim, (query, key, value))
        hidden_states = self._apply_attention(attn, query, key, value, attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Linear projection and dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, -1, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states / attn.rescale_output_factor


# Optimized IPAttnProcessor
class IPAttnProcessor(BaseAttnProcessor):
    """Attention processor for image-prompt (IP-Adapter style) cross-attention.

    Args:
        hidden_size: inner dimension of the attention layer.
        cross_attention_dim: dimension of the encoder hidden states; defaults to hidden_size.
        scale: weight of the image-prompt attention added to the text attention output.
        num_tokens: number of image-prompt tokens appended at the end of encoder_hidden_states.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # Separate key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.apply(init_weights)

    def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        hidden_states, encoder_hidden_states, attention_mask, residual, batch_size, input_ndim, height, width = \
            self._process_input(attn, hidden_states, encoder_hidden_states, attention_mask, temb)

        # Split the text tokens from the trailing `num_tokens` image-prompt tokens.
        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
        encoder_hidden_states, ip_hidden_states = (
            encoder_hidden_states[:, :end_pos, :],
            encoder_hidden_states[:, end_pos:, :],
        )

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query, key, value = map(attn.head_to_batch_dim, (query, key, value))
        hidden_states = self._apply_attention(attn, query, key, value, attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Image Prompt Attention
        ip_key = attn.head_to_batch_dim(self.to_k_ip(ip_hidden_states))
        ip_value = attn.head_to_batch_dim(self.to_v_ip(ip_hidden_states))

        ip_hidden_states = self._apply_attention(attn, query, ip_key, ip_value, None)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        # Region control: optionally restrict the image-prompt contribution to a spatial mask.
        if len(region_control.prompt_image_conditioning) == 1:
            region_mask = region_control.prompt_image_conditioning[0].get("region_mask", None)
            if region_mask is not None:
                # Assumes region_mask is a 2D spatial mask; resize it so that, once
                # flattened, it yields one value per query token (the per-axis scale
                # is the square root of the token-count ratio).
                h, w = region_mask.shape[:2]
                ratio = (h * w / ip_hidden_states.shape[1]) ** 0.5
                mask = F.interpolate(region_mask[None, None], scale_factor=1 / ratio, mode="nearest").reshape([1, -1, 1])
            else:
                mask = torch.ones_like(ip_hidden_states)
            ip_hidden_states = ip_hidden_states * mask

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # Linear projection and dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, -1, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states / attn.rescale_output_factor
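

# ---------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the original module): one common way
# to wire these processors into a diffusers `UNet2DConditionModel` is to give
# self-attention layers the plain AttnProcessor and cross-attention layers the
# IPAttnProcessor. The helper name `install_ip_adapter_processors` and the layer
# size lookup below follow the usual IP-Adapter setup pattern and are assumptions
# about the surrounding project, not code shipped with this file.
def install_ip_adapter_processors(unet, scale=1.0, num_tokens=4):
    """Register AttnProcessor / IPAttnProcessor on every attention layer of a diffusers UNet."""
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" layers are self-attention (no cross-attention dim); "attn2" are cross-attention.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            attn_procs[name] = IPAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=scale,
                num_tokens=num_tokens,
            )

    unet.set_attn_processor(attn_procs)
    return unet
#
# Example:
#   from diffusers import UNet2DConditionModel
#   unet = UNet2DConditionModel.from_pretrained("<base model repo>", subfolder="unet")
#   install_ip_adapter_processors(unet, scale=0.8)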