Johannes Kolbe committed
Commit · ff2b8e3
1 Parent(s): 3b72cdb
add original sefa files back in
- SessionState.py +129 -0
- interface.py +128 -0
- models/__init__.py +114 -0
- models/pggan_discriminator.py +402 -0
- models/pggan_generator.py +338 -0
- models/stylegan2_discriminator.py +468 -0
- models/stylegan2_generator.py +996 -0
- models/stylegan_discriminator.py +530 -0
- models/stylegan_generator.py +869 -0
- models/sync_op.py +18 -0
- sefa.py +145 -0
- utils.py +509 -0
SessionState.py
ADDED
@@ -0,0 +1,129 @@
"""Adds pre-session state to StreamLit.

This file is borrowed from
https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
"""

# pylint: disable=protected-access

try:
    import streamlit.ReportThread as ReportThread
    from streamlit.server.Server import Server
except ModuleNotFoundError:
    # Streamlit >= 0.65.0
    import streamlit.report_thread as ReportThread
    from streamlit.server.server import Server


class SessionState(object):
    """Hack to add per-session state to Streamlit.

    Usage
    -----

    >>> import SessionState
    >>>
    >>> session_state = SessionState.get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be the
    result:
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'

    """

    def __init__(self, **kwargs):
        """A new SessionState object.

        Parameters
        ----------
        **kwargs : any
            Default values for the session state.

        Example
        -------
        >>> session_state = SessionState(user_name='', favorite_color='black')
        >>> session_state.user_name = 'Mary'
        ''
        >>> session_state.favorite_color
        'black'

        """
        for key, val in kwargs.items():
            setattr(self, key, val)


def get(**kwargs):
    """Gets a SessionState object for the current session.

    Creates a new object if necessary.

    Parameters
    ----------
    **kwargs : any
        Default values you want to add to the session state, if we're creating a
        new one.

    Example
    -------
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be the
    result:
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'

    """
    # Hack to get the session object from Streamlit.

    ctx = ReportThread.get_report_ctx()

    this_session = None

    current_server = Server.get_current()
    if hasattr(current_server, '_session_infos'):
        # Streamlit < 0.56
        session_infos = Server.get_current()._session_infos.values()
    else:
        session_infos = Server.get_current()._session_info_by_id.values()

    for session_info in session_infos:
        s = session_info.session
        if (
                # Streamlit < 0.54.0
                (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
                or
                # Streamlit >= 0.54.0
                (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
                or
                # Streamlit >= 0.65.2
                (not hasattr(s, '_main_dg') and
                 s._uploaded_file_mgr == ctx.uploaded_file_mgr)
        ):
            this_session = s

    if this_session is None:
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object. "
            'Are you doing something fancy with threads?')

    # Got the session object! Now let's attach some state into it.

    if not hasattr(this_session, '_custom_session_state'):
        this_session._custom_session_state = SessionState(**kwargs)

    return this_session._custom_session_state

# pylint: enable=protected-access
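Note (not part of the commit): a minimal sketch of how this helper is typically called from a Streamlit script; the `counter` attribute below is purely illustrative.

# Hypothetical usage sketch: persist a per-session counter across reruns.
import streamlit as st
import SessionState

state = SessionState.get(counter=0)  # created once per browser session
if st.button('Increment'):
    state.counter += 1
st.write(f'Counter value: {state.counter}')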
interface.py
ADDED
@@ -0,0 +1,128 @@
# python 3.7
"""Demo."""

import numpy as np
import torch
import streamlit as st
import SessionState

from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(model_name):
    """Gets model by name."""
    return load_generator(model_name)


@st.cache(allow_output_mutation=True, show_spinner=False)
def factorize_model(model, layer_idx):
    """Factorizes semantics from target layers of the given model."""
    return factorize_weight(model, layer_idx)


def sample(model, gan_type, num=1):
    """Samples latent codes."""
    codes = torch.randn(num, model.z_space_dim).cuda()
    if gan_type == 'pggan':
        codes = model.layer0.pixel_norm(codes)
    elif gan_type == 'stylegan':
        codes = model.mapping(codes)['w']
        codes = model.truncation(codes,
                                 trunc_psi=0.7,
                                 trunc_layers=8)
    elif gan_type == 'stylegan2':
        codes = model.mapping(codes)['w']
        codes = model.truncation(codes,
                                 trunc_psi=0.5,
                                 trunc_layers=18)
    codes = codes.detach().cpu().numpy()
    return codes


@st.cache(allow_output_mutation=True, show_spinner=False)
def synthesize(model, gan_type, code):
    """Synthesizes an image with the given code."""
    if gan_type == 'pggan':
        image = model(to_tensor(code))['image']
    elif gan_type in ['stylegan', 'stylegan2']:
        image = model.synthesis(to_tensor(code))['image']
    image = postprocess(image)[0]
    return image


def main():
    """Main function (loop for StreamLit)."""
    st.title('Closed-Form Factorization of Latent Semantics in GANs')
    st.sidebar.title('Options')
    reset = st.sidebar.button('Reset')

    model_name = st.sidebar.selectbox(
        'Model to Interpret',
        ['stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',
         'pggan_celebahq1024'])

    model = get_model(model_name)
    gan_type = parse_gan_type(model)
    layer_idx = st.sidebar.selectbox(
        'Layers to Interpret',
        ['all', '0-1', '2-5', '6-13'])
    layers, boundaries, eigen_values = factorize_model(model, layer_idx)

    num_semantics = st.sidebar.number_input(
        'Number of semantics', value=10, min_value=0, max_value=None, step=1)
    steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
    if gan_type == 'pggan':
        max_step = 5.0
    elif gan_type == 'stylegan':
        max_step = 2.0
    elif gan_type == 'stylegan2':
        max_step = 15.0
    for sem_idx in steps:
        eigen_value = eigen_values[sem_idx]
        steps[sem_idx] = st.sidebar.slider(
            f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
            value=0.0,
            min_value=-max_step,
            max_value=max_step,
            step=0.04 * max_step if not reset else 0.0)

    image_placeholder = st.empty()
    button_placeholder = st.empty()

    try:
        base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
    except FileNotFoundError:
        base_codes = sample(model, gan_type)

    state = SessionState.get(model_name=model_name,
                             code_idx=0,
                             codes=base_codes[0:1])
    if state.model_name != model_name:
        state.model_name = model_name
        state.code_idx = 0
        state.codes = base_codes[0:1]

    if button_placeholder.button('Random', key=0):
        state.code_idx += 1
        if state.code_idx < base_codes.shape[0]:
            state.codes = base_codes[state.code_idx][np.newaxis]
        else:
            state.codes = sample(model, gan_type)

    code = state.codes.copy()
    for sem_idx, step in steps.items():
        if gan_type == 'pggan':
            code += boundaries[sem_idx:sem_idx + 1] * step
        elif gan_type in ['stylegan', 'stylegan2']:
            code[:, layers, :] += boundaries[sem_idx:sem_idx + 1] * step
    image = synthesize(model, gan_type, code)
    image_placeholder.image(image / 255.0)


if __name__ == '__main__':
    main()
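Note (not part of the commit): the app above is meant to be launched with `streamlit run interface.py`. As a rough offline sketch of the same edit-and-synthesize loop for a PGGAN model, assuming a CUDA device and the checkpoints expected by `load_generator`:

# Hypothetical offline sketch mirroring main() for the 'pggan' branch (GPU assumed).
import torch
from models import parse_gan_type
from utils import load_generator, factorize_weight, to_tensor, postprocess

model = load_generator('pggan_celebahq1024')
gan_type = parse_gan_type(model)                     # -> 'pggan'
layers, boundaries, eigen_values = factorize_weight(model, 'all')

code = torch.randn(1, model.z_space_dim).cuda()
code = model.layer0.pixel_norm(code).detach().cpu().numpy()
code += boundaries[0:1] * 3.0                        # push along the top semantic
image = postprocess(model(to_tensor(code))['image'])[0]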
models/__init__.py
ADDED
@@ -0,0 +1,114 @@
# python3.7
"""Collects all available models together."""

from .model_zoo import MODEL_ZOO
from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator

__all__ = [
    'MODEL_ZOO', 'PGGANGenerator', 'PGGANDiscriminator', 'StyleGANGenerator',
    'StyleGANDiscriminator', 'StyleGAN2Generator', 'StyleGAN2Discriminator',
    'build_generator', 'build_discriminator', 'build_model'
]

_GAN_TYPES_ALLOWED = ['pggan', 'stylegan', 'stylegan2']
_MODULES_ALLOWED = ['generator', 'discriminator']


def build_generator(gan_type, resolution, **kwargs):
    """Builds generator by GAN type.

    Args:
        gan_type: GAN type to which the generator belongs.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the generator.

    Raises:
        ValueError: If the `gan_type` is not supported.
        NotImplementedError: If the `gan_type` is not implemented.
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    if gan_type == 'pggan':
        return PGGANGenerator(resolution, **kwargs)
    if gan_type == 'stylegan':
        return StyleGANGenerator(resolution, **kwargs)
    if gan_type == 'stylegan2':
        return StyleGAN2Generator(resolution, **kwargs)
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')


def build_discriminator(gan_type, resolution, **kwargs):
    """Builds discriminator by GAN type.

    Args:
        gan_type: GAN type to which the discriminator belongs.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the discriminator.

    Raises:
        ValueError: If the `gan_type` is not supported.
        NotImplementedError: If the `gan_type` is not implemented.
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    if gan_type == 'pggan':
        return PGGANDiscriminator(resolution, **kwargs)
    if gan_type == 'stylegan':
        return StyleGANDiscriminator(resolution, **kwargs)
    if gan_type == 'stylegan2':
        return StyleGAN2Discriminator(resolution, **kwargs)
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')


def build_model(gan_type, module, resolution, **kwargs):
    """Builds a GAN module (generator/discriminator/etc).

    Args:
        gan_type: GAN type to which the model belongs.
        module: GAN module to build, such as generator or discriminator.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the module.

    Raises:
        ValueError: If the `module` is not supported.
        NotImplementedError: If the `module` is not implemented.
    """
    if module not in _MODULES_ALLOWED:
        raise ValueError(f'Invalid module: `{module}`!\n'
                         f'Modules allowed: {_MODULES_ALLOWED}.')

    if module == 'generator':
        return build_generator(gan_type, resolution, **kwargs)
    if module == 'discriminator':
        return build_discriminator(gan_type, resolution, **kwargs)
    raise NotImplementedError(f'Unsupported module `{module}`!')


def parse_gan_type(module):
    """Parses GAN type of a given module.

    Args:
        module: The module to parse GAN type from.

    Returns:
        A string, indicating the GAN type.

    Raises:
        ValueError: If the GAN type is unknown.
    """
    if isinstance(module, (PGGANGenerator, PGGANDiscriminator)):
        return 'pggan'
    if isinstance(module, (StyleGANGenerator, StyleGANDiscriminator)):
        return 'stylegan'
    if isinstance(module, (StyleGAN2Generator, StyleGAN2Discriminator)):
        return 'stylegan2'
    raise ValueError(f'Unable to parse GAN type from type `{type(module)}`!')
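Note (not part of the commit): a small sketch of the factory functions defined above.

# Hypothetical usage of build_model() and parse_gan_type().
from models import build_model, parse_gan_type

generator = build_model('stylegan', 'generator', resolution=256)
discriminator = build_model('stylegan', 'discriminator', resolution=256)
assert parse_gan_type(generator) == 'stylegan'
assert parse_gan_type(discriminator) == 'stylegan'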
models/pggan_discriminator.py
ADDED
@@ -0,0 +1,402 @@
# python3.7
"""Contains the implementation of discriminator described in PGGAN.

Paper: https://arxiv.org/pdf/1710.10196.pdf

Official TensorFlow implementation:
https://github.com/tkarras/progressive_growing_of_gans
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PGGANDiscriminator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Initial resolution.
_INIT_RES = 4

# Default gain factor for weight scaling.
_WSCALE_GAIN = np.sqrt(2.0)


class PGGANDiscriminator(nn.Module):
    """Defines the discriminator network in PGGAN.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the network:

    (1) resolution: The resolution of the input image.
    (2) image_channels: Number of channels of the input image. (default: 3)
    (3) label_size: Size of the additional label for conditional generation.
        (default: 0)
    (4) fused_scale: Whether to fuse `conv2d` and `downsample` together,
        resulting in `conv2d` with strides. (default: False)
    (5) use_wscale: Whether to use weight scaling. (default: True)
    (6) minibatch_std_group_size: Group size for the minibatch standard
        deviation layer. 0 means disable. (default: 16)
    (7) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (8) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    """

    def __init__(self,
                 resolution,
                 image_channels=3,
                 label_size=0,
                 fused_scale=False,
                 use_wscale=True,
                 minibatch_std_group_size=16,
                 fmaps_base=16 << 10,
                 fmaps_max=512):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.init_res = _INIT_RES
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.image_channels = image_channels
        self.label_size = label_size
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.minibatch_std_group_size = minibatch_std_group_size
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max

        # Level of detail (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            block_idx = self.final_res_log2 - res_log2

            # Input convolution layer for each resolution.
            self.add_module(
                f'input{block_idx}',
                ConvBlock(in_channels=self.image_channels,
                          out_channels=self.get_nf(res),
                          kernel_size=1,
                          padding=0,
                          use_wscale=self.use_wscale))
            self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
                f'FromRGB_lod{block_idx}/weight')
            self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
                f'FromRGB_lod{block_idx}/bias')

            # Convolution block for each resolution (except the last one).
            if res != self.init_res:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res),
                              use_wscale=self.use_wscale))
                tf_layer0_name = 'Conv0'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res // 2),
                              downsample=True,
                              fused_scale=self.fused_scale,
                              use_wscale=self.use_wscale))
                tf_layer1_name = 'Conv1_down' if self.fused_scale else 'Conv1'

            # Convolution block for last resolution.
            else:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(
                        in_channels=self.get_nf(res),
                        out_channels=self.get_nf(res),
                        use_wscale=self.use_wscale,
                        minibatch_std_group_size=self.minibatch_std_group_size))
                tf_layer0_name = 'Conv'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    DenseBlock(in_channels=self.get_nf(res) * res * res,
                               out_channels=self.get_nf(res // 2),
                               use_wscale=self.use_wscale))
                tf_layer1_name = 'Dense0'

            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer0_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer0_name}/bias')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer1_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer1_name}/bias')

        # Final dense block.
        self.add_module(
            f'layer{2 * block_idx + 2}',
            DenseBlock(in_channels=self.get_nf(res // 2),
                       out_channels=1 + self.label_size,
                       use_wscale=self.use_wscale,
                       wscale_gain=1.0,
                       activation_type='linear'))
        self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
            f'{res}x{res}/Dense1/weight')
        self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
            f'{res}x{res}/Dense1/bias')

        self.downsample = DownsamplingLayer()

    def get_nf(self, res):
        """Gets number of feature maps according to current resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, image, lod=None, **_unused_kwargs):
        expected_shape = (self.image_channels, self.resolution, self.resolution)
        if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals to {self.image_channels}, '
                             f'`height`, `width` equal to {self.resolution}!\n'
                             f'But `{image.shape}` is received!')

        lod = self.lod.cpu().tolist() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-detail (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        lod = self.lod.cpu().tolist()
        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            block_idx = current_lod = self.final_res_log2 - res_log2
            if current_lod <= lod < current_lod + 1:
                x = self.__getattr__(f'input{block_idx}')(image)
            elif current_lod - 1 < lod < current_lod:
                alpha = lod - np.floor(lod)
                x = (self.__getattr__(f'input{block_idx}')(image) * alpha +
                     x * (1 - alpha))
            if lod < current_lod + 1:
                x = self.__getattr__(f'layer{2 * block_idx}')(x)
                x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
            if lod > current_lod:
                image = self.downsample(image)
        x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)
        return x


class MiniBatchSTDLayer(nn.Module):
    """Implements the minibatch standard deviation layer."""

    def __init__(self, group_size=16, epsilon=1e-8):
        super().__init__()
        self.group_size = group_size
        self.epsilon = epsilon

    def forward(self, x):
        if self.group_size <= 1:
            return x
        group_size = min(self.group_size, x.shape[0])  # [NCHW]
        y = x.view(group_size, -1, x.shape[1], x.shape[2], x.shape[3])  # [GMCHW]
        y = y - torch.mean(y, dim=0, keepdim=True)  # [GMCHW]
        y = torch.mean(y ** 2, dim=0)  # [MCHW]
        y = torch.sqrt(y + self.epsilon)  # [MCHW]
        y = torch.mean(y, dim=[1, 2, 3], keepdim=True)  # [M111]
        y = y.repeat(group_size, 1, x.shape[2], x.shape[3])  # [N1HW]
        return torch.cat([x, y], dim=1)


class DownsamplingLayer(nn.Module):
    """Implements the downsampling layer.

    Basically, this layer can be used to downsample feature maps with average
    pooling.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.avg_pool2d(x,
                            kernel_size=self.scale_factor,
                            stride=self.scale_factor,
                            padding=0)


class ConvBlock(nn.Module):
    """Implements the convolutional block.

    Basically, this block executes minibatch standard deviation layer (if
    needed), convolutional layer, activation layer, and downsampling layer (
    if needed) in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 add_bias=True,
                 downsample=False,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 activation_type='lrelu',
                 minibatch_std_group_size=0):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels. (default: 3)
            stride: Stride parameter for convolution operation. (default: 1)
            padding: Padding parameter for convolution operation. (default: 1)
            add_bias: Whether to add bias onto the convolutional result.
                (default: True)
            downsample: Whether to downsample the result after convolution.
                (default: False)
            fused_scale: Whether to fuse `conv2d` and `downsample` together,
                resulting in `conv2d` with strides. (default: False)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)
            minibatch_std_group_size: Group size for the minibatch standard
                deviation layer. 0 means disable. (default: 0)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()

        if minibatch_std_group_size > 1:
            in_channels = in_channels + 1
            self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size)
        else:
            self.mbstd = nn.Identity()

        if downsample and not fused_scale:
            self.downsample = DownsamplingLayer()
        else:
            self.downsample = nn.Identity()

        if downsample and fused_scale:
            self.use_stride = True
            self.stride = 2
            self.padding = 1
        else:
            self.use_stride = False
            self.stride = stride
            self.padding = padding

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        x = self.mbstd(x)
        weight = self.weight * self.wscale
        if self.use_stride:
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
        x = F.conv2d(x,
                     weight=weight,
                     bias=self.bias,
                     stride=self.stride,
                     padding=self.padding)
        x = self.activate(x)
        x = self.downsample(x)
        return x


class DenseBlock(nn.Module):
    """Implements the dense block.

    Basically, this block executes fully-connected layer, and activation layer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 add_bias=True,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
                (default: True)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()
        weight_shape = (out_channels, in_channels)
        wscale = wscale_gain / np.sqrt(in_channels)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        if x.ndim != 2:
            x = x.view(x.shape[0], -1)
        x = F.linear(x, weight=self.weight * self.wscale, bias=self.bias)
        x = self.activate(x)
        return x
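Note (not part of the commit): a quick shape check for the discriminator above, with randomly initialized weights.

# Hypothetical sanity check: one forward pass through PGGANDiscriminator.
import torch
from models.pggan_discriminator import PGGANDiscriminator

D = PGGANDiscriminator(resolution=128)
fake_images = torch.randn(4, 3, 128, 128)  # RGB images in [-1, 1]
scores = D(fake_images)                    # shape [4, 1] when label_size=0
print(scores.shape)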
models/pggan_generator.py
ADDED
@@ -0,0 +1,338 @@
# python3.7
"""Contains the implementation of generator described in PGGAN.

Paper: https://arxiv.org/pdf/1710.10196.pdf

Official TensorFlow implementation:
https://github.com/tkarras/progressive_growing_of_gans
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PGGANGenerator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Initial resolution.
_INIT_RES = 4

# Default gain factor for weight scaling.
_WSCALE_GAIN = np.sqrt(2.0)


class PGGANGenerator(nn.Module):
    """Defines the generator network in PGGAN.

    NOTE: The synthesized images are with `RGB` channel order and pixel range
    [-1, 1].

    Settings for the network:

    (1) resolution: The resolution of the output image.
    (2) z_space_dim: The dimension of the latent space, Z. (default: 512)
    (3) image_channels: Number of channels of the output image. (default: 3)
    (4) final_tanh: Whether to use `tanh` to control the final pixel range.
        (default: False)
    (5) label_size: Size of the additional label for conditional generation.
        (default: 0)
    (6) fused_scale: Whether to fuse `upsample` and `conv2d` together,
        resulting in `conv2d_transpose`. (default: False)
    (7) use_wscale: Whether to use weight scaling. (default: True)
    (8) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (9) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    """

    def __init__(self,
                 resolution,
                 z_space_dim=512,
                 image_channels=3,
                 final_tanh=False,
                 label_size=0,
                 fused_scale=False,
                 use_wscale=True,
                 fmaps_base=16 << 10,
                 fmaps_max=512):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.init_res = _INIT_RES
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.z_space_dim = z_space_dim
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.label_size = label_size
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max

        # Number of convolutional layers.
        self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2

        # Level of detail (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            res = 2 ** res_log2
            block_idx = res_log2 - self.init_res_log2

            # First convolution layer for each resolution.
            if res == self.init_res:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.z_space_dim + self.label_size,
                              out_channels=self.get_nf(res),
                              kernel_size=self.init_res,
                              padding=self.init_res - 1,
                              use_wscale=self.use_wscale))
                tf_layer_name = 'Dense'
            else:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.get_nf(res // 2),
                              out_channels=self.get_nf(res),
                              upsample=True,
                              fused_scale=self.fused_scale,
                              use_wscale=self.use_wscale))
                tf_layer_name = 'Conv0_up' if self.fused_scale else 'Conv0'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Second convolution layer for each resolution.
            self.add_module(
                f'layer{2 * block_idx + 1}',
                ConvBlock(in_channels=self.get_nf(res),
                          out_channels=self.get_nf(res),
                          use_wscale=self.use_wscale))
            tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Output convolution layer for each resolution.
            self.add_module(
                f'output{block_idx}',
                ConvBlock(in_channels=self.get_nf(res),
                          out_channels=self.image_channels,
                          kernel_size=1,
                          padding=0,
                          use_wscale=self.use_wscale,
                          wscale_gain=1.0,
                          activation_type='linear'))
            self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
            self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')

        self.upsample = UpsamplingLayer()
        self.final_activate = nn.Tanh() if self.final_tanh else nn.Identity()

    def get_nf(self, res):
        """Gets number of feature maps according to current resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, z, label=None, lod=None, **_unused_kwargs):
        if z.ndim != 2 or z.shape[1] != self.z_space_dim:
            raise ValueError(f'Input latent code should be with shape '
                             f'[batch_size, latent_dim], where '
                             f'`latent_dim` equals to {self.z_space_dim}!\n'
                             f'But `{z.shape}` is received!')
        z = self.layer0.pixel_norm(z)
        if self.label_size:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with size {self.label_size}) as input, '
                                 f'but no label is received!')
            if label.ndim != 2 or label.shape != (z.shape[0], self.label_size):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_size], where '
                                 f'`batch_size` equals to that of '
                                 f'latent codes ({z.shape[0]}) and '
                                 f'`label_size` equals to {self.label_size}!\n'
                                 f'But `{label.shape}` is received!')
            z = torch.cat((z, label), dim=1)

        lod = self.lod.cpu().tolist() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-detail (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        x = z.view(z.shape[0], self.z_space_dim + self.label_size, 1, 1)
        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            current_lod = self.final_res_log2 - res_log2
            if lod < current_lod + 1:
                block_idx = res_log2 - self.init_res_log2
                x = self.__getattr__(f'layer{2 * block_idx}')(x)
                x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
            if current_lod - 1 < lod <= current_lod:
                image = self.__getattr__(f'output{block_idx}')(x)
            elif current_lod < lod < current_lod + 1:
                alpha = np.ceil(lod) - lod
                image = (self.__getattr__(f'output{block_idx}')(x) * alpha +
                         self.upsample(image) * (1 - alpha))
            elif lod >= current_lod + 1:
                image = self.upsample(image)
        image = self.final_activate(image)

        results = {
            'z': z,
            'label': label,
            'image': image,
        }
        return results


class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer."""

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.eps = epsilon

    def forward(self, x):
        norm = torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)
        return x / norm


class UpsamplingLayer(nn.Module):
    """Implements the upsampling layer.

    Basically, this layer can be used to upsample feature maps with nearest
    neighbor interpolation.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')


class ConvBlock(nn.Module):
    """Implements the convolutional block.

    Basically, this block executes pixel-wise normalization layer, upsampling
    layer (if needed), convolutional layer, and activation layer in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 add_bias=True,
                 upsample=False,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels. (default: 3)
            stride: Stride parameter for convolution operation. (default: 1)
            padding: Padding parameter for convolution operation. (default: 1)
            add_bias: Whether to add bias onto the convolutional result.
                (default: True)
            upsample: Whether to upsample the input tensor before convolution.
                (default: False)
            fused_scale: Whether to fuse `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. (default: False)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()

        self.pixel_norm = PixelNormLayer()

        if upsample and not fused_scale:
            self.upsample = UpsamplingLayer()
        else:
            self.upsample = nn.Identity()

        if upsample and fused_scale:
            self.use_conv2d_transpose = True
            weight_shape = (in_channels, out_channels, kernel_size, kernel_size)
            self.stride = 2
            self.padding = 1
        else:
            self.use_conv2d_transpose = False
            weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
            self.stride = stride
            self.padding = padding

        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        x = self.pixel_norm(x)
        x = self.upsample(x)
        weight = self.weight * self.wscale
        if self.use_conv2d_transpose:
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
            x = F.conv_transpose2d(x,
                                   weight=weight,
                                   bias=self.bias,
                                   stride=self.stride,
                                   padding=self.padding)
        else:
            x = F.conv2d(x,
                         weight=weight,
                         bias=self.bias,
                         stride=self.stride,
                         padding=self.padding)
        x = self.activate(x)
        return x
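Note (not part of the commit): the matching shape check for the generator above, with randomly initialized weights.

# Hypothetical sanity check: one forward pass through PGGANGenerator.
import torch
from models.pggan_generator import PGGANGenerator

G = PGGANGenerator(resolution=128)
z = torch.randn(4, G.z_space_dim)
with torch.no_grad():
    outputs = G(z)
print(outputs['image'].shape)  # expected: torch.Size([4, 3, 128, 128])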
models/stylegan2_discriminator.py
ADDED
@@ -0,0 +1,468 @@
# python3.7
"""Contains the implementation of discriminator described in StyleGAN2.

Compared to that of StyleGAN, the discriminator in StyleGAN2 mainly adds skip
connections, increases model size and disables progressive growth. This script
ONLY supports config F in the original paper.

Paper: https://arxiv.org/pdf/1912.04958.pdf

Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['StyleGAN2Discriminator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Initial resolution.
_INIT_RES = 4

# Architectures allowed.
_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']

# Default gain factor for weight scaling.
_WSCALE_GAIN = 1.0


class StyleGAN2Discriminator(nn.Module):
    """Defines the discriminator network in StyleGAN2.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the network:

    (1) resolution: The resolution of the input image.
    (2) image_channels: Number of channels of the input image. (default: 3)
    (3) label_size: Size of the additional label for conditional generation.
        (default: 0)
    (4) architecture: Type of architecture. Support `origin`, `skip`, and
        `resnet`. (default: `resnet`)
    (5) use_wscale: Whether to use weight scaling. (default: True)
    (6) minibatch_std_group_size: Group size for the minibatch standard
        deviation layer. 0 means disable. (default: 4)
    (7) minibatch_std_channels: Number of new channels after the minibatch
        standard deviation layer. (default: 1)
    (8) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 32 << 10)
    (9) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    """

    def __init__(self,
                 resolution,
                 image_channels=3,
                 label_size=0,
                 architecture='resnet',
                 use_wscale=True,
                 minibatch_std_group_size=4,
                 minibatch_std_channels=1,
                 fmaps_base=32 << 10,
                 fmaps_max=512):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported, or `architecture`
                is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        if architecture not in _ARCHITECTURES_ALLOWED:
            raise ValueError(f'Invalid architecture: `{architecture}`!\n'
                             f'Architectures allowed: '
                             f'{_ARCHITECTURES_ALLOWED}.')

        self.init_res = _INIT_RES
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.image_channels = image_channels
        self.label_size = label_size
        self.architecture = architecture
        self.use_wscale = use_wscale
        self.minibatch_std_group_size = minibatch_std_group_size
        self.minibatch_std_channels = minibatch_std_channels
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max

        self.pth_to_tf_var_mapping = {}
        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            block_idx = self.final_res_log2 - res_log2

            # Input convolution layer for each resolution (if needed).
            if res_log2 == self.final_res_log2 or self.architecture == 'skip':
                self.add_module(
                    f'input{block_idx}',
                    ConvBlock(in_channels=self.image_channels,
                              out_channels=self.get_nf(res),
                              kernel_size=1,
                              use_wscale=self.use_wscale))
                self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
                    f'{res}x{res}/FromRGB/weight')
                self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
                    f'{res}x{res}/FromRGB/bias')

            # Convolution block for each resolution (except the last one).
            if res != self.init_res:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res),
                              use_wscale=self.use_wscale))
                tf_layer0_name = 'Conv0'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res // 2),
                              scale_factor=2,
                              use_wscale=self.use_wscale))
                tf_layer1_name = 'Conv1_down'
tf_layer1_name = 'Conv1_down'
|
| 130 |
+
|
| 131 |
+
if self.architecture == 'resnet':
|
| 132 |
+
layer_name = f'skip_layer{block_idx}'
|
| 133 |
+
self.add_module(
|
| 134 |
+
layer_name,
|
| 135 |
+
ConvBlock(in_channels=self.get_nf(res),
|
| 136 |
+
out_channels=self.get_nf(res // 2),
|
| 137 |
+
kernel_size=1,
|
| 138 |
+
add_bias=False,
|
| 139 |
+
scale_factor=2,
|
| 140 |
+
use_wscale=self.use_wscale,
|
| 141 |
+
activation_type='linear'))
|
| 142 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
| 143 |
+
f'{res}x{res}/Skip/weight')
|
| 144 |
+
|
| 145 |
+
# Convolution block for last resolution.
|
| 146 |
+
else:
|
| 147 |
+
self.add_module(
|
| 148 |
+
f'layer{2 * block_idx}',
|
| 149 |
+
ConvBlock(in_channels=self.get_nf(res),
|
| 150 |
+
out_channels=self.get_nf(res),
|
| 151 |
+
use_wscale=self.use_wscale,
|
| 152 |
+
minibatch_std_group_size=minibatch_std_group_size,
|
| 153 |
+
minibatch_std_channels=minibatch_std_channels))
|
| 154 |
+
tf_layer0_name = 'Conv'
|
| 155 |
+
self.add_module(
|
| 156 |
+
f'layer{2 * block_idx + 1}',
|
| 157 |
+
DenseBlock(in_channels=self.get_nf(res) * res * res,
|
| 158 |
+
out_channels=self.get_nf(res // 2),
|
| 159 |
+
use_wscale=self.use_wscale))
|
| 160 |
+
tf_layer1_name = 'Dense0'
|
| 161 |
+
|
| 162 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
|
| 163 |
+
f'{res}x{res}/{tf_layer0_name}/weight')
|
| 164 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
|
| 165 |
+
f'{res}x{res}/{tf_layer0_name}/bias')
|
| 166 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
|
| 167 |
+
f'{res}x{res}/{tf_layer1_name}/weight')
|
| 168 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
|
| 169 |
+
f'{res}x{res}/{tf_layer1_name}/bias')
|
| 170 |
+
|
| 171 |
+
# Final dense block.
|
| 172 |
+
self.add_module(
|
| 173 |
+
f'layer{2 * block_idx + 2}',
|
| 174 |
+
DenseBlock(in_channels=self.get_nf(res // 2),
|
| 175 |
+
out_channels=max(self.label_size, 1),
|
| 176 |
+
use_wscale=self.use_wscale,
|
| 177 |
+
activation_type='linear'))
|
| 178 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
|
| 179 |
+
f'Output/weight')
|
| 180 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
|
| 181 |
+
f'Output/bias')
|
| 182 |
+
|
| 183 |
+
if self.architecture == 'skip':
|
| 184 |
+
self.downsample = DownsamplingLayer()
|
| 185 |
+
|
| 186 |
+
def get_nf(self, res):
|
| 187 |
+
"""Gets number of feature maps according to current resolution."""
|
| 188 |
+
return min(self.fmaps_base // res, self.fmaps_max)
|
| 189 |
+
|
| 190 |
+
def forward(self, image, label=None, **_unused_kwargs):
|
| 191 |
+
expected_shape = (self.image_channels, self.resolution, self.resolution)
|
| 192 |
+
if image.ndim != 4 or image.shape[1:] != expected_shape:
|
| 193 |
+
raise ValueError(f'The input tensor should be with shape '
|
| 194 |
+
f'[batch_size, channel, height, width], where '
|
| 195 |
+
f'`channel` equals to {self.image_channels}, '
|
| 196 |
+
f'`height`, `width` equal to {self.resolution}!\n'
|
| 197 |
+
f'But `{image.shape}` is received!')
|
| 198 |
+
if self.label_size:
|
| 199 |
+
if label is None:
|
| 200 |
+
raise ValueError(f'Model requires an additional label '
|
| 201 |
+
f'(with size {self.label_size}) as inputs, '
|
| 202 |
+
f'but no label is received!')
|
| 203 |
+
batch_size = image.shape[0]
|
| 204 |
+
if label.ndim != 2 or label.shape != (batch_size, self.label_size):
|
| 205 |
+
raise ValueError(f'Input label should be with shape '
|
| 206 |
+
f'[batch_size, label_size], where '
|
| 207 |
+
f'`batch_size` equals to that of '
|
| 208 |
+
f'images ({image.shape[0]}) and '
|
| 209 |
+
f'`label_size` equals to {self.label_size}!\n'
|
| 210 |
+
f'But `{label.shape}` is received!')
|
| 211 |
+
|
| 212 |
+
x = self.input0(image)
|
| 213 |
+
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
|
| 214 |
+
block_idx = self.final_res_log2 - res_log2
|
| 215 |
+
if self.architecture == 'skip' and block_idx > 0:
|
| 216 |
+
image = self.downsample(image)
|
| 217 |
+
x = x + self.__getattr__(f'input{block_idx}')(image)
|
| 218 |
+
if self.architecture == 'resnet' and res_log2 != self.init_res_log2:
|
| 219 |
+
residual = self.__getattr__(f'skip_layer{block_idx}')(x)
|
| 220 |
+
x = self.__getattr__(f'layer{2 * block_idx}')(x)
|
| 221 |
+
x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
|
| 222 |
+
if self.architecture == 'resnet' and res_log2 != self.init_res_log2:
|
| 223 |
+
x = (x + residual) / np.sqrt(2.0)
|
| 224 |
+
x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)
|
| 225 |
+
|
| 226 |
+
if self.label_size:
|
| 227 |
+
x = torch.sum(x * label, dim=1, keepdim=True)
|
| 228 |
+
return x
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class MiniBatchSTDLayer(nn.Module):
|
| 232 |
+
"""Implements the minibatch standard deviation layer."""
|
| 233 |
+
|
| 234 |
+
def __init__(self, group_size=4, new_channels=1, epsilon=1e-8):
|
| 235 |
+
super().__init__()
|
| 236 |
+
self.group_size = group_size
|
| 237 |
+
self.new_channels = new_channels
|
| 238 |
+
self.epsilon = epsilon
|
| 239 |
+
|
| 240 |
+
def forward(self, x):
|
| 241 |
+
if self.group_size <= 1:
|
| 242 |
+
return x
|
| 243 |
+
ng = min(self.group_size, x.shape[0])
|
| 244 |
+
nc = self.new_channels
|
| 245 |
+
temp_c = x.shape[1] // nc # [NCHW]
|
| 246 |
+
y = x.view(ng, -1, nc, temp_c, x.shape[2], x.shape[3]) # [GMncHW]
|
| 247 |
+
y = y - torch.mean(y, dim=0, keepdim=True) # [GMncHW]
|
| 248 |
+
y = torch.mean(y ** 2, dim=0) # [MncHW]
|
| 249 |
+
y = torch.sqrt(y + self.epsilon) # [MncHW]
|
| 250 |
+
y = torch.mean(y, dim=[2, 3, 4], keepdim=True) # [Mn111]
|
| 251 |
+
y = torch.mean(y, dim=2) # [Mn11]
|
| 252 |
+
y = y.repeat(ng, 1, x.shape[2], x.shape[3]) # [NnHW]
|
| 253 |
+
return torch.cat([x, y], dim=1)
|
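The layer above appends one summary-statistics channel per group, which is how the discriminator senses sample diversity within a minibatch. A minimal shape sketch (not part of the committed file; it assumes the MiniBatchSTDLayer class defined above is in scope):

import torch

layer = MiniBatchSTDLayer(group_size=4, new_channels=1)
feat = torch.randn(8, 512, 4, 4)   # [N, C, H, W] feature maps at the 4x4 stage
out = layer(feat)
print(out.shape)                   # torch.Size([8, 513, 4, 4]) -- one std channel appended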
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class DownsamplingLayer(nn.Module):
|
| 257 |
+
"""Implements the downsampling layer.
|
| 258 |
+
|
| 259 |
+
This layer can also be used as filtering by setting `scale_factor` as 1.
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
def __init__(self, scale_factor=2, kernel=(1, 3, 3, 1), extra_padding=0):
|
| 263 |
+
super().__init__()
|
| 264 |
+
assert scale_factor >= 1
|
| 265 |
+
self.scale_factor = scale_factor
|
| 266 |
+
|
| 267 |
+
if extra_padding != 0:
|
| 268 |
+
assert scale_factor == 1
|
| 269 |
+
|
| 270 |
+
if kernel is None:
|
| 271 |
+
kernel = np.ones((scale_factor), dtype=np.float32)
|
| 272 |
+
else:
|
| 273 |
+
kernel = np.array(kernel, dtype=np.float32)
|
| 274 |
+
assert kernel.ndim == 1
|
| 275 |
+
kernel = np.outer(kernel, kernel)
|
| 276 |
+
kernel = kernel / np.sum(kernel)
|
| 277 |
+
assert kernel.ndim == 2
|
| 278 |
+
assert kernel.shape[0] == kernel.shape[1]
|
| 279 |
+
kernel = kernel[np.newaxis, np.newaxis]
|
| 280 |
+
self.register_buffer('kernel', torch.from_numpy(kernel))
|
| 281 |
+
self.kernel = self.kernel.flip(0, 1)
|
| 282 |
+
padding = kernel.shape[2] - scale_factor + extra_padding
|
| 283 |
+
self.padding = ((padding + 1) // 2, padding // 2,
|
| 284 |
+
(padding + 1) // 2, padding // 2)
|
| 285 |
+
|
| 286 |
+
def forward(self, x):
|
| 287 |
+
assert x.ndim == 4
|
| 288 |
+
channels = x.shape[1]
|
| 289 |
+
x = x.view(-1, 1, x.shape[2], x.shape[3])
|
| 290 |
+
x = F.pad(x, self.padding, mode='constant', value=0)
|
| 291 |
+
x = F.conv2d(x, self.kernel, stride=self.scale_factor)
|
| 292 |
+
x = x.view(-1, channels, x.shape[2], x.shape[3])
|
| 293 |
+
return x
|
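A shape sketch for the downsampling layer above (not part of the committed file; DownsamplingLayer refers to the class defined right above). With the default (1, 3, 3, 1) kernel it blurs and then strides by 2:

import torch

down = DownsamplingLayer(scale_factor=2)
x = torch.randn(4, 64, 32, 32)
y = down(x)
print(y.shape)                     # torch.Size([4, 64, 16, 16])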
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class ConvBlock(nn.Module):
|
| 297 |
+
"""Implements the convolutional block.
|
| 298 |
+
|
| 299 |
+
Basically, this block executes minibatch standard deviation layer (if
|
| 300 |
+
needed), filtering layer (if needed), convolutional layer, and activation
|
| 301 |
+
layer in sequence.
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
def __init__(self,
|
| 305 |
+
in_channels,
|
| 306 |
+
out_channels,
|
| 307 |
+
kernel_size=3,
|
| 308 |
+
add_bias=True,
|
| 309 |
+
scale_factor=1,
|
| 310 |
+
filtering_kernel=(1, 3, 3, 1),
|
| 311 |
+
use_wscale=True,
|
| 312 |
+
wscale_gain=_WSCALE_GAIN,
|
| 313 |
+
lr_mul=1.0,
|
| 314 |
+
activation_type='lrelu',
|
| 315 |
+
minibatch_std_group_size=0,
|
| 316 |
+
minibatch_std_channels=1):
|
| 317 |
+
"""Initializes with block settings.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
in_channels: Number of channels of the input tensor.
|
| 321 |
+
out_channels: Number of channels of the output tensor.
|
| 322 |
+
kernel_size: Size of the convolutional kernels. (default: 3)
|
| 323 |
+
add_bias: Whether to add bias onto the convolutional result.
|
| 324 |
+
(default: True)
|
| 325 |
+
scale_factor: Scale factor for downsampling. `1` means skip
|
| 326 |
+
downsampling. (default: 1)
|
| 327 |
+
filtering_kernel: Kernel used for filtering before downsampling.
|
| 328 |
+
(default: (1, 3, 3, 1))
|
| 329 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
| 330 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
| 331 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
| 332 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
| 333 |
+
(default: `lrelu`)
|
| 334 |
+
minibatch_std_group_size: Group size for the minibatch standard
|
| 335 |
+
deviation layer. 0 means disable. (default: 0)
|
| 336 |
+
minibatch_std_channels: Number of new channels after the minibatch
|
| 337 |
+
standard deviation layer. (default: 1)
|
| 338 |
+
|
| 339 |
+
Raises:
|
| 340 |
+
NotImplementedError: If the `activation_type` is not supported.
|
| 341 |
+
"""
|
| 342 |
+
super().__init__()
|
| 343 |
+
|
| 344 |
+
if minibatch_std_group_size > 1:
|
| 345 |
+
in_channels = in_channels + minibatch_std_channels
|
| 346 |
+
self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size,
|
| 347 |
+
new_channels=minibatch_std_channels)
|
| 348 |
+
else:
|
| 349 |
+
self.mbstd = nn.Identity()
|
| 350 |
+
|
| 351 |
+
if scale_factor > 1:
|
| 352 |
+
extra_padding = kernel_size - scale_factor
|
| 353 |
+
self.filter = DownsamplingLayer(scale_factor=1,
|
| 354 |
+
kernel=filtering_kernel,
|
| 355 |
+
extra_padding=extra_padding)
|
| 356 |
+
self.stride = scale_factor
|
| 357 |
+
self.padding = 0 # Padding is done in `DownsamplingLayer`.
|
| 358 |
+
else:
|
| 359 |
+
self.filter = nn.Identity()
|
| 360 |
+
assert kernel_size % 2 == 1
|
| 361 |
+
self.stride = 1
|
| 362 |
+
self.padding = kernel_size // 2
|
| 363 |
+
|
| 364 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
| 365 |
+
fan_in = kernel_size * kernel_size * in_channels
|
| 366 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
| 367 |
+
if use_wscale:
|
| 368 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
| 369 |
+
self.wscale = wscale * lr_mul
|
| 370 |
+
else:
|
| 371 |
+
self.weight = nn.Parameter(
|
| 372 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
| 373 |
+
self.wscale = lr_mul
|
| 374 |
+
|
| 375 |
+
if add_bias:
|
| 376 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
| 377 |
+
else:
|
| 378 |
+
self.bias = None
|
| 379 |
+
self.bscale = lr_mul
|
| 380 |
+
|
| 381 |
+
if activation_type == 'linear':
|
| 382 |
+
self.activate = nn.Identity()
|
| 383 |
+
self.activate_scale = 1.0
|
| 384 |
+
elif activation_type == 'lrelu':
|
| 385 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 386 |
+
self.activate_scale = np.sqrt(2.0)
|
| 387 |
+
else:
|
| 388 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
| 389 |
+
f'`{activation_type}`!')
|
| 390 |
+
|
| 391 |
+
def forward(self, x):
|
| 392 |
+
x = self.mbstd(x)
|
| 393 |
+
x = self.filter(x)
|
| 394 |
+
weight = self.weight * self.wscale
|
| 395 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
| 396 |
+
x = F.conv2d(x,
|
| 397 |
+
weight=weight,
|
| 398 |
+
bias=bias,
|
| 399 |
+
stride=self.stride,
|
| 400 |
+
padding=self.padding)
|
| 401 |
+
x = self.activate(x) * self.activate_scale
|
| 402 |
+
return x
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class DenseBlock(nn.Module):
|
| 406 |
+
"""Implements the dense block.
|
| 407 |
+
|
| 408 |
+
Basically, this block executes fully-connected layer and activation layer.
|
| 409 |
+
"""
|
| 410 |
+
|
| 411 |
+
def __init__(self,
|
| 412 |
+
in_channels,
|
| 413 |
+
out_channels,
|
| 414 |
+
add_bias=True,
|
| 415 |
+
use_wscale=True,
|
| 416 |
+
wscale_gain=_WSCALE_GAIN,
|
| 417 |
+
lr_mul=1.0,
|
| 418 |
+
activation_type='lrelu'):
|
| 419 |
+
"""Initializes with block settings.
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
in_channels: Number of channels of the input tensor.
|
| 423 |
+
out_channels: Number of channels of the output tensor.
|
| 424 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
| 425 |
+
(default: True)
|
| 426 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
| 427 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
| 428 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
| 429 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
| 430 |
+
(default: `lrelu`)
|
| 431 |
+
|
| 432 |
+
Raises:
|
| 433 |
+
NotImplementedError: If the `activation_type` is not supported.
|
| 434 |
+
"""
|
| 435 |
+
super().__init__()
|
| 436 |
+
weight_shape = (out_channels, in_channels)
|
| 437 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
| 438 |
+
if use_wscale:
|
| 439 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
| 440 |
+
self.wscale = wscale * lr_mul
|
| 441 |
+
else:
|
| 442 |
+
self.weight = nn.Parameter(
|
| 443 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
| 444 |
+
self.wscale = lr_mul
|
| 445 |
+
|
| 446 |
+
if add_bias:
|
| 447 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
| 448 |
+
else:
|
| 449 |
+
self.bias = None
|
| 450 |
+
self.bscale = lr_mul
|
| 451 |
+
|
| 452 |
+
if activation_type == 'linear':
|
| 453 |
+
self.activate = nn.Identity()
|
| 454 |
+
self.activate_scale = 1.0
|
| 455 |
+
elif activation_type == 'lrelu':
|
| 456 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 457 |
+
self.activate_scale = np.sqrt(2.0)
|
| 458 |
+
else:
|
| 459 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
| 460 |
+
f'`{activation_type}`!')
|
| 461 |
+
|
| 462 |
+
def forward(self, x):
|
| 463 |
+
if x.ndim != 2:
|
| 464 |
+
x = x.view(x.shape[0], -1)
|
| 465 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
| 466 |
+
x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
|
| 467 |
+
x = self.activate(x) * self.activate_scale
|
| 468 |
+
return x
|
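A minimal usage sketch for the discriminator added above (not part of the commit itself; the import path assumes the file lands at models/stylegan2_discriminator.py as listed in this commit):

import torch
from models.stylegan2_discriminator import StyleGAN2Discriminator

D = StyleGAN2Discriminator(resolution=256)
images = torch.randn(2, 3, 256, 256)   # stand-in batch; real inputs are RGB images in [-1, 1]
scores = D(images)                      # shape [2, 1], one realness score per image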
models/stylegan2_generator.py
ADDED
|
@@ -0,0 +1,996 @@
| 1 |
+
# python3.7
|
| 2 |
+
"""Contains the implementation of generator described in StyleGAN2.
|
| 3 |
+
|
| 4 |
+
Compared to that of StyleGAN, the generator in StyleGAN2 mainly introduces style
|
| 5 |
+
demodulation, adds skip connections, increases model size, and disables
|
| 6 |
+
progressive growth. This script ONLY supports config F in the original paper.
|
| 7 |
+
|
| 8 |
+
Paper: https://arxiv.org/pdf/1912.04958.pdf
|
| 9 |
+
|
| 10 |
+
Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
from .sync_op import all_gather
|
| 20 |
+
|
| 21 |
+
__all__ = ['StyleGAN2Generator']
|
| 22 |
+
|
| 23 |
+
# Resolutions allowed.
|
| 24 |
+
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
|
| 25 |
+
|
| 26 |
+
# Initial resolution.
|
| 27 |
+
_INIT_RES = 4
|
| 28 |
+
|
| 29 |
+
# Architectures allowed.
|
| 30 |
+
_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
|
| 31 |
+
|
| 32 |
+
# Default gain factor for weight scaling.
|
| 33 |
+
_WSCALE_GAIN = 1.0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class StyleGAN2Generator(nn.Module):
|
| 37 |
+
"""Defines the generator network in StyleGAN2.
|
| 38 |
+
|
| 39 |
+
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
| 40 |
+
[-1, 1].
|
| 41 |
+
|
| 42 |
+
Settings for the mapping network:
|
| 43 |
+
|
| 44 |
+
(1) z_space_dim: Dimension of the input latent space, Z. (default: 512)
|
| 45 |
+
(2) w_space_dim: Dimension of the output latent space, W. (default: 512)
|
| 46 |
+
(3) label_size: Size of the additional label for conditional generation.
|
| 47 |
+
(default: 0)
|
| 48 |
+
(4) mapping_layers: Number of layers of the mapping network. (default: 8)
|
| 49 |
+
(5) mapping_fmaps: Number of hidden channels of the mapping network.
|
| 50 |
+
(default: 512)
|
| 51 |
+
(6) mapping_lr_mul: Learning rate multiplier for the mapping network.
|
| 52 |
+
(default: 0.01)
|
| 53 |
+
(7) repeat_w: Repeat w-code for different layers.
|
| 54 |
+
|
| 55 |
+
Settings for the synthesis network:
|
| 56 |
+
|
| 57 |
+
(1) resolution: The resolution of the output image.
|
| 58 |
+
(2) image_channels: Number of channels of the output image. (default: 3)
|
| 59 |
+
(3) final_tanh: Whether to use `tanh` to control the final pixel range.
|
| 60 |
+
(default: False)
|
| 61 |
+
(4) const_input: Whether to use a constant in the first convolutional layer.
|
| 62 |
+
(default: True)
|
| 63 |
+
(5) architecture: Type of architecture. Support `origin`, `skip`, and
|
| 64 |
+
`resnet`. (default: `resnet`)
|
| 65 |
+
(6) fused_modulate: Whether to fuse `style_modulate` and `conv2d` together.
|
| 66 |
+
(default: True)
|
| 67 |
+
(7) demodulate: Whether to perform style demodulation. (default: True)
|
| 68 |
+
(8) use_wscale: Whether to use weight scaling. (default: True)
|
| 69 |
+
(9) fmaps_base: Factor to control number of feature maps for each layer.
|
| 70 |
+
(default: 32 << 10)
|
| 71 |
+
(10) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(self,
|
| 75 |
+
resolution,
|
| 76 |
+
z_space_dim=512,
|
| 77 |
+
w_space_dim=512,
|
| 78 |
+
label_size=0,
|
| 79 |
+
mapping_layers=8,
|
| 80 |
+
mapping_fmaps=512,
|
| 81 |
+
mapping_lr_mul=0.01,
|
| 82 |
+
repeat_w=True,
|
| 83 |
+
image_channels=3,
|
| 84 |
+
final_tanh=False,
|
| 85 |
+
const_input=True,
|
| 86 |
+
architecture='skip',
|
| 87 |
+
fused_modulate=True,
|
| 88 |
+
demodulate=True,
|
| 89 |
+
use_wscale=True,
|
| 90 |
+
fmaps_base=32 << 10,
|
| 91 |
+
fmaps_max=512):
|
| 92 |
+
"""Initializes with basic settings.
|
| 93 |
+
|
| 94 |
+
Raises:
|
| 95 |
+
ValueError: If the `resolution` is not supported, or `architecture`
|
| 96 |
+
is not supported.
|
| 97 |
+
"""
|
| 98 |
+
super().__init__()
|
| 99 |
+
|
| 100 |
+
if resolution not in _RESOLUTIONS_ALLOWED:
|
| 101 |
+
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
|
| 102 |
+
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
|
| 103 |
+
if architecture not in _ARCHITECTURES_ALLOWED:
|
| 104 |
+
raise ValueError(f'Invalid architecture: `{architecture}`!\n'
|
| 105 |
+
f'Architectures allowed: '
|
| 106 |
+
f'{_ARCHITECTURES_ALLOWED}.')
|
| 107 |
+
|
| 108 |
+
self.init_res = _INIT_RES
|
| 109 |
+
self.resolution = resolution
|
| 110 |
+
self.z_space_dim = z_space_dim
|
| 111 |
+
self.w_space_dim = w_space_dim
|
| 112 |
+
self.label_size = label_size
|
| 113 |
+
self.mapping_layers = mapping_layers
|
| 114 |
+
self.mapping_fmaps = mapping_fmaps
|
| 115 |
+
self.mapping_lr_mul = mapping_lr_mul
|
| 116 |
+
self.repeat_w = repeat_w
|
| 117 |
+
self.image_channels = image_channels
|
| 118 |
+
self.final_tanh = final_tanh
|
| 119 |
+
self.const_input = const_input
|
| 120 |
+
self.architecture = architecture
|
| 121 |
+
self.fused_modulate = fused_modulate
|
| 122 |
+
self.demodulate = demodulate
|
| 123 |
+
self.use_wscale = use_wscale
|
| 124 |
+
self.fmaps_base = fmaps_base
|
| 125 |
+
self.fmaps_max = fmaps_max
|
| 126 |
+
|
| 127 |
+
self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
|
| 128 |
+
|
| 129 |
+
if self.repeat_w:
|
| 130 |
+
self.mapping_space_dim = self.w_space_dim
|
| 131 |
+
else:
|
| 132 |
+
self.mapping_space_dim = self.w_space_dim * self.num_layers
|
| 133 |
+
self.mapping = MappingModule(input_space_dim=self.z_space_dim,
|
| 134 |
+
hidden_space_dim=self.mapping_fmaps,
|
| 135 |
+
final_space_dim=self.mapping_space_dim,
|
| 136 |
+
label_size=self.label_size,
|
| 137 |
+
num_layers=self.mapping_layers,
|
| 138 |
+
use_wscale=self.use_wscale,
|
| 139 |
+
lr_mul=self.mapping_lr_mul)
|
| 140 |
+
|
| 141 |
+
self.truncation = TruncationModule(w_space_dim=self.w_space_dim,
|
| 142 |
+
num_layers=self.num_layers,
|
| 143 |
+
repeat_w=self.repeat_w)
|
| 144 |
+
|
| 145 |
+
self.synthesis = SynthesisModule(resolution=self.resolution,
|
| 146 |
+
init_resolution=self.init_res,
|
| 147 |
+
w_space_dim=self.w_space_dim,
|
| 148 |
+
image_channels=self.image_channels,
|
| 149 |
+
final_tanh=self.final_tanh,
|
| 150 |
+
const_input=self.const_input,
|
| 151 |
+
architecture=self.architecture,
|
| 152 |
+
fused_modulate=self.fused_modulate,
|
| 153 |
+
demodulate=self.demodulate,
|
| 154 |
+
use_wscale=self.use_wscale,
|
| 155 |
+
fmaps_base=self.fmaps_base,
|
| 156 |
+
fmaps_max=self.fmaps_max)
|
| 157 |
+
|
| 158 |
+
self.pth_to_tf_var_mapping = {}
|
| 159 |
+
for key, val in self.mapping.pth_to_tf_var_mapping.items():
|
| 160 |
+
self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
|
| 161 |
+
for key, val in self.truncation.pth_to_tf_var_mapping.items():
|
| 162 |
+
self.pth_to_tf_var_mapping[f'truncation.{key}'] = val
|
| 163 |
+
for key, val in self.synthesis.pth_to_tf_var_mapping.items():
|
| 164 |
+
self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val
|
| 165 |
+
|
| 166 |
+
def forward(self,
|
| 167 |
+
z,
|
| 168 |
+
label=None,
|
| 169 |
+
w_moving_decay=0.995,
|
| 170 |
+
style_mixing_prob=0.9,
|
| 171 |
+
trunc_psi=None,
|
| 172 |
+
trunc_layers=None,
|
| 173 |
+
randomize_noise=False,
|
| 174 |
+
**_unused_kwargs):
|
| 175 |
+
mapping_results = self.mapping(z, label)
|
| 176 |
+
w = mapping_results['w']
|
| 177 |
+
|
| 178 |
+
if self.training and w_moving_decay < 1:
|
| 179 |
+
batch_w_avg = all_gather(w).mean(dim=0)
|
| 180 |
+
self.truncation.w_avg.copy_(
|
| 181 |
+
self.truncation.w_avg * w_moving_decay +
|
| 182 |
+
batch_w_avg * (1 - w_moving_decay))
|
| 183 |
+
|
| 184 |
+
if self.training and style_mixing_prob > 0:
|
| 185 |
+
new_z = torch.randn_like(z)
|
| 186 |
+
new_w = self.mapping(new_z, label)['w']
|
| 187 |
+
if np.random.uniform() < style_mixing_prob:
|
| 188 |
+
mixing_cutoff = np.random.randint(1, self.num_layers)
|
| 189 |
+
w = self.truncation(w)
|
| 190 |
+
new_w = self.truncation(new_w)
|
| 191 |
+
w[:, :mixing_cutoff] = new_w[:, :mixing_cutoff]
|
| 192 |
+
|
| 193 |
+
wp = self.truncation(w, trunc_psi, trunc_layers)
|
| 194 |
+
synthesis_results = self.synthesis(wp, randomize_noise)
|
| 195 |
+
|
| 196 |
+
return {**mapping_results, **synthesis_results}
|
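A minimal usage sketch for the generator class above (not part of the committed file; the import path assumes the file lands at models/stylegan2_generator.py as listed in this commit):

import torch
from models.stylegan2_generator import StyleGAN2Generator

G = StyleGAN2Generator(resolution=256)
G.eval()                                   # disables style mixing and the w_avg running update
z = torch.randn(2, 512)                    # latent codes in Z space
with torch.no_grad():
    outputs = G(z, trunc_psi=0.7, trunc_layers=8)
images = outputs['image']                  # [2, 3, 256, 256], roughly in [-1, 1]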
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class MappingModule(nn.Module):
|
| 200 |
+
"""Implements the latent space mapping module.
|
| 201 |
+
|
| 202 |
+
Basically, this module executes several dense layers in sequence.
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
def __init__(self,
|
| 206 |
+
input_space_dim=512,
|
| 207 |
+
hidden_space_dim=512,
|
| 208 |
+
final_space_dim=512,
|
| 209 |
+
label_size=0,
|
| 210 |
+
num_layers=8,
|
| 211 |
+
normalize_input=True,
|
| 212 |
+
use_wscale=True,
|
| 213 |
+
lr_mul=0.01):
|
| 214 |
+
super().__init__()
|
| 215 |
+
|
| 216 |
+
self.input_space_dim = input_space_dim
|
| 217 |
+
self.hidden_space_dim = hidden_space_dim
|
| 218 |
+
self.final_space_dim = final_space_dim
|
| 219 |
+
self.label_size = label_size
|
| 220 |
+
self.num_layers = num_layers
|
| 221 |
+
self.normalize_input = normalize_input
|
| 222 |
+
self.use_wscale = use_wscale
|
| 223 |
+
self.lr_mul = lr_mul
|
| 224 |
+
|
| 225 |
+
self.norm = PixelNormLayer() if self.normalize_input else nn.Identity()
|
| 226 |
+
|
| 227 |
+
self.pth_to_tf_var_mapping = {}
|
| 228 |
+
for i in range(num_layers):
|
| 229 |
+
dim_mul = 2 if label_size else 1
|
| 230 |
+
in_channels = (input_space_dim * dim_mul if i == 0 else
|
| 231 |
+
hidden_space_dim)
|
| 232 |
+
out_channels = (final_space_dim if i == (num_layers - 1) else
|
| 233 |
+
hidden_space_dim)
|
| 234 |
+
self.add_module(f'dense{i}',
|
| 235 |
+
DenseBlock(in_channels=in_channels,
|
| 236 |
+
out_channels=out_channels,
|
| 237 |
+
use_wscale=self.use_wscale,
|
| 238 |
+
lr_mul=self.lr_mul))
|
| 239 |
+
self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
|
| 240 |
+
self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
|
| 241 |
+
if label_size:
|
| 242 |
+
self.label_weight = nn.Parameter(
|
| 243 |
+
torch.randn(label_size, input_space_dim))
|
| 244 |
+
self.pth_to_tf_var_mapping[f'label_weight'] = f'LabelConcat/weight'
|
| 245 |
+
|
| 246 |
+
def forward(self, z, label=None):
|
| 247 |
+
if z.ndim != 2 or z.shape[1] != self.input_space_dim:
|
| 248 |
+
raise ValueError(f'Input latent code should be with shape '
|
| 249 |
+
f'[batch_size, input_dim], where '
|
| 250 |
+
f'`input_dim` equals to {self.input_space_dim}!\n'
|
| 251 |
+
f'But `{z.shape}` is received!')
|
| 252 |
+
if self.label_size:
|
| 253 |
+
if label is None:
|
| 254 |
+
raise ValueError(f'Model requires an additional label '
|
| 255 |
+
f'(with size {self.label_size}) as input, '
|
| 256 |
+
f'but no label is received!')
|
| 257 |
+
if label.ndim != 2 or label.shape != (z.shape[0], self.label_size):
|
| 258 |
+
raise ValueError(f'Input label should be with shape '
|
| 259 |
+
f'[batch_size, label_size], where '
|
| 260 |
+
f'`batch_size` equals to that of '
|
| 261 |
+
f'latent codes ({z.shape[0]}) and '
|
| 262 |
+
f'`label_size` equals to {self.label_size}!\n'
|
| 263 |
+
f'But `{label.shape}` is received!')
|
| 264 |
+
embedding = torch.matmul(label, self.label_weight)
|
| 265 |
+
z = torch.cat((z, embedding), dim=1)
|
| 266 |
+
|
| 267 |
+
z = self.norm(z)
|
| 268 |
+
w = z
|
| 269 |
+
for i in range(self.num_layers):
|
| 270 |
+
w = self.__getattr__(f'dense{i}')(w)
|
| 271 |
+
results = {
|
| 272 |
+
'z': z,
|
| 273 |
+
'label': label,
|
| 274 |
+
'w': w,
|
| 275 |
+
}
|
| 276 |
+
if self.label_size:
|
| 277 |
+
results['embedding'] = embedding
|
| 278 |
+
return results
|
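The mapping module can also be exercised on its own; a short sketch (not part of the commit, using the class defined above with its default settings):

import torch

mapping = MappingModule(input_space_dim=512, final_space_dim=512, num_layers=8)
z = torch.randn(4, 512)
w = mapping(z)['w']                        # [4, 512] codes in the disentangled W space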
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class TruncationModule(nn.Module):
|
| 282 |
+
"""Implements the truncation module.
|
| 283 |
+
|
| 284 |
+
Truncation is executed as follows:
|
| 285 |
+
|
| 286 |
+
For layers in range [0, truncation_layers), the truncated w-code is computed
|
| 287 |
+
as
|
| 288 |
+
|
| 289 |
+
w_new = w_avg + (w - w_avg) * truncation_psi
|
| 290 |
+
|
| 291 |
+
To disable truncation, please set
|
| 292 |
+
(1) truncation_psi = 1.0 (None) OR
|
| 293 |
+
(2) truncation_layers = 0 (None)
|
| 294 |
+
|
| 295 |
+
NOTE: The returned tensor is layer-wise style codes.
|
| 296 |
+
"""
|
| 297 |
+
|
| 298 |
+
def __init__(self, w_space_dim, num_layers, repeat_w=True):
|
| 299 |
+
super().__init__()
|
| 300 |
+
|
| 301 |
+
self.num_layers = num_layers
|
| 302 |
+
self.w_space_dim = w_space_dim
|
| 303 |
+
self.repeat_w = repeat_w
|
| 304 |
+
|
| 305 |
+
if self.repeat_w:
|
| 306 |
+
self.register_buffer('w_avg', torch.zeros(w_space_dim))
|
| 307 |
+
else:
|
| 308 |
+
self.register_buffer('w_avg', torch.zeros(num_layers * w_space_dim))
|
| 309 |
+
self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
|
| 310 |
+
|
| 311 |
+
def forward(self, w, trunc_psi=None, trunc_layers=None):
|
| 312 |
+
if w.ndim == 2:
|
| 313 |
+
if self.repeat_w and w.shape[1] == self.w_space_dim:
|
| 314 |
+
w = w.view(-1, 1, self.w_space_dim)
|
| 315 |
+
wp = w.repeat(1, self.num_layers, 1)
|
| 316 |
+
else:
|
| 317 |
+
assert w.shape[1] == self.w_space_dim * self.num_layers
|
| 318 |
+
wp = w.view(-1, self.num_layers, self.w_space_dim)
|
| 319 |
+
else:
|
| 320 |
+
wp = w
|
| 321 |
+
assert wp.ndim == 3
|
| 322 |
+
assert wp.shape[1:] == (self.num_layers, self.w_space_dim)
|
| 323 |
+
|
| 324 |
+
trunc_psi = 1.0 if trunc_psi is None else trunc_psi
|
| 325 |
+
trunc_layers = 0 if trunc_layers is None else trunc_layers
|
| 326 |
+
if trunc_psi < 1.0 and trunc_layers > 0:
|
| 327 |
+
layer_idx = np.arange(self.num_layers).reshape(1, -1, 1)
|
| 328 |
+
coefs = np.ones_like(layer_idx, dtype=np.float32)
|
| 329 |
+
coefs[layer_idx < trunc_layers] *= trunc_psi
|
| 330 |
+
coefs = torch.from_numpy(coefs).to(wp)
|
| 331 |
+
w_avg = self.w_avg.view(1, -1, self.w_space_dim)
|
| 332 |
+
wp = w_avg + (wp - w_avg) * coefs
|
| 333 |
+
return wp
|
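A sketch of the truncation rule documented above (not part of the commit): layers below trunc_layers are pulled toward w_avg by the factor trunc_psi, later layers are left untouched.

import torch

w_avg = torch.zeros(512)                   # running average kept by the module
w = torch.randn(3, 14, 512)                # [batch, num_layers, w_space_dim]
trunc_psi, trunc_layers = 0.5, 8
coefs = torch.ones(1, 14, 1)
coefs[:, :trunc_layers] = trunc_psi
wp = w_avg + (w - w_avg) * coefs           # early layers interpolated toward the average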
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class SynthesisModule(nn.Module):
|
| 337 |
+
"""Implements the image synthesis module.
|
| 338 |
+
|
| 339 |
+
Basically, this module executes several convolutional layers in sequence.
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
def __init__(self,
|
| 343 |
+
resolution=1024,
|
| 344 |
+
init_resolution=4,
|
| 345 |
+
w_space_dim=512,
|
| 346 |
+
image_channels=3,
|
| 347 |
+
final_tanh=False,
|
| 348 |
+
const_input=True,
|
| 349 |
+
architecture='skip',
|
| 350 |
+
fused_modulate=True,
|
| 351 |
+
demodulate=True,
|
| 352 |
+
use_wscale=True,
|
| 353 |
+
fmaps_base=32 << 10,
|
| 354 |
+
fmaps_max=512):
|
| 355 |
+
super().__init__()
|
| 356 |
+
|
| 357 |
+
self.init_res = init_resolution
|
| 358 |
+
self.init_res_log2 = int(np.log2(self.init_res))
|
| 359 |
+
self.resolution = resolution
|
| 360 |
+
self.final_res_log2 = int(np.log2(self.resolution))
|
| 361 |
+
self.w_space_dim = w_space_dim
|
| 362 |
+
self.image_channels = image_channels
|
| 363 |
+
self.final_tanh = final_tanh
|
| 364 |
+
self.const_input = const_input
|
| 365 |
+
self.architecture = architecture
|
| 366 |
+
self.fused_modulate = fused_modulate
|
| 367 |
+
self.demodulate = demodulate
|
| 368 |
+
self.use_wscale = use_wscale
|
| 369 |
+
self.fmaps_base = fmaps_base
|
| 370 |
+
self.fmaps_max = fmaps_max
|
| 371 |
+
|
| 372 |
+
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
|
| 373 |
+
|
| 374 |
+
self.pth_to_tf_var_mapping = {}
|
| 375 |
+
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
|
| 376 |
+
res = 2 ** res_log2
|
| 377 |
+
block_idx = res_log2 - self.init_res_log2
|
| 378 |
+
|
| 379 |
+
# First convolution layer for each resolution.
|
| 380 |
+
if res == self.init_res:
|
| 381 |
+
if self.const_input:
|
| 382 |
+
self.add_module(f'early_layer',
|
| 383 |
+
InputBlock(init_resolution=self.init_res,
|
| 384 |
+
channels=self.get_nf(res)))
|
| 385 |
+
self.pth_to_tf_var_mapping[f'early_layer.const'] = (
|
| 386 |
+
f'{res}x{res}/Const/const')
|
| 387 |
+
else:
|
| 388 |
+
self.add_module(f'early_layer',
|
| 389 |
+
DenseBlock(in_channels=self.w_space_dim,
|
| 390 |
+
out_channels=self.get_nf(res),
|
| 391 |
+
use_wscale=self.use_wscale))
|
| 392 |
+
self.pth_to_tf_var_mapping[f'early_layer.weight'] = (
|
| 393 |
+
f'{res}x{res}/Dense/weight')
|
| 394 |
+
self.pth_to_tf_var_mapping[f'early_layer.bias'] = (
|
| 395 |
+
f'{res}x{res}/Dense/bias')
|
| 396 |
+
else:
|
| 397 |
+
layer_name = f'layer{2 * block_idx - 1}'
|
| 398 |
+
self.add_module(
|
| 399 |
+
layer_name,
|
| 400 |
+
ModulateConvBlock(in_channels=self.get_nf(res // 2),
|
| 401 |
+
out_channels=self.get_nf(res),
|
| 402 |
+
resolution=res,
|
| 403 |
+
w_space_dim=self.w_space_dim,
|
| 404 |
+
scale_factor=2,
|
| 405 |
+
fused_modulate=self.fused_modulate,
|
| 406 |
+
demodulate=self.demodulate,
|
| 407 |
+
use_wscale=self.use_wscale))
|
| 408 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
| 409 |
+
f'{res}x{res}/Conv0_up/weight')
|
| 410 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
| 411 |
+
f'{res}x{res}/Conv0_up/bias')
|
| 412 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
| 413 |
+
f'{res}x{res}/Conv0_up/mod_weight')
|
| 414 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
| 415 |
+
f'{res}x{res}/Conv0_up/mod_bias')
|
| 416 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
|
| 417 |
+
f'{res}x{res}/Conv0_up/noise_strength')
|
| 418 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
|
| 419 |
+
f'noise{2 * block_idx - 1}')
|
| 420 |
+
|
| 421 |
+
if self.architecture == 'resnet':
|
| 422 |
+
layer_name = f'layer{2 * block_idx - 1}'
|
| 423 |
+
self.add_module(
|
| 424 |
+
layer_name,
|
| 425 |
+
ConvBlock(in_channels=self.get_nf(res // 2),
|
| 426 |
+
out_channels=self.get_nf(res),
|
| 427 |
+
kernel_size=1,
|
| 428 |
+
add_bias=False,
|
| 429 |
+
scale_factor=2,
|
| 430 |
+
use_wscale=self.use_wscale,
|
| 431 |
+
activation_type='linear'))
|
| 432 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
| 433 |
+
f'{res}x{res}/Skip/weight')
|
| 434 |
+
|
| 435 |
+
# Second convolution layer for each resolution.
|
| 436 |
+
layer_name = f'layer{2 * block_idx}'
|
| 437 |
+
self.add_module(
|
| 438 |
+
layer_name,
|
| 439 |
+
ModulateConvBlock(in_channels=self.get_nf(res),
|
| 440 |
+
out_channels=self.get_nf(res),
|
| 441 |
+
resolution=res,
|
| 442 |
+
w_space_dim=self.w_space_dim,
|
| 443 |
+
fused_modulate=self.fused_modulate,
|
| 444 |
+
demodulate=self.demodulate,
|
| 445 |
+
use_wscale=self.use_wscale))
|
| 446 |
+
tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
|
| 447 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
| 448 |
+
f'{res}x{res}/{tf_layer_name}/weight')
|
| 449 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
| 450 |
+
f'{res}x{res}/{tf_layer_name}/bias')
|
| 451 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
| 452 |
+
f'{res}x{res}/{tf_layer_name}/mod_weight')
|
| 453 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
| 454 |
+
f'{res}x{res}/{tf_layer_name}/mod_bias')
|
| 455 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
|
| 456 |
+
f'{res}x{res}/{tf_layer_name}/noise_strength')
|
| 457 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
|
| 458 |
+
f'noise{2 * block_idx}')
|
| 459 |
+
|
| 460 |
+
# Output convolution layer for each resolution (if needed).
|
| 461 |
+
if res_log2 == self.final_res_log2 or self.architecture == 'skip':
|
| 462 |
+
layer_name = f'output{block_idx}'
|
| 463 |
+
self.add_module(
|
| 464 |
+
layer_name,
|
| 465 |
+
ModulateConvBlock(in_channels=self.get_nf(res),
|
| 466 |
+
out_channels=image_channels,
|
| 467 |
+
resolution=res,
|
| 468 |
+
w_space_dim=self.w_space_dim,
|
| 469 |
+
kernel_size=1,
|
| 470 |
+
fused_modulate=self.fused_modulate,
|
| 471 |
+
demodulate=False,
|
| 472 |
+
use_wscale=self.use_wscale,
|
| 473 |
+
add_noise=False,
|
| 474 |
+
activation_type='linear'))
|
| 475 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
| 476 |
+
f'{res}x{res}/ToRGB/weight')
|
| 477 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
| 478 |
+
f'{res}x{res}/ToRGB/bias')
|
| 479 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
| 480 |
+
f'{res}x{res}/ToRGB/mod_weight')
|
| 481 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
| 482 |
+
f'{res}x{res}/ToRGB/mod_bias')
|
| 483 |
+
|
| 484 |
+
if self.architecture == 'skip':
|
| 485 |
+
self.upsample = UpsamplingLayer()
|
| 486 |
+
self.final_activate = nn.Tanh() if final_tanh else nn.Identity()
|
| 487 |
+
|
| 488 |
+
def get_nf(self, res):
|
| 489 |
+
"""Gets number of feature maps according to current resolution."""
|
| 490 |
+
return min(self.fmaps_base // res, self.fmaps_max)
|
| 491 |
+
|
| 492 |
+
def forward(self, wp, randomize_noise=False):
|
| 493 |
+
if wp.ndim != 3 or wp.shape[1:] != (self.num_layers, self.w_space_dim):
|
| 494 |
+
raise ValueError(f'Input tensor should be with shape '
|
| 495 |
+
f'[batch_size, num_layers, w_space_dim], where '
|
| 496 |
+
f'`num_layers` equals to {self.num_layers}, and '
|
| 497 |
+
f'`w_space_dim` equals to {self.w_space_dim}!\n'
|
| 498 |
+
f'But `{wp.shape}` is received!')
|
| 499 |
+
|
| 500 |
+
results = {'wp': wp}
|
| 501 |
+
x = self.early_layer(wp[:, 0])
|
| 502 |
+
if self.architecture == 'origin':
|
| 503 |
+
for layer_idx in range(self.num_layers - 1):
|
| 504 |
+
x, style = self.__getattr__(f'layer{layer_idx}')(
|
| 505 |
+
x, wp[:, layer_idx], randomize_noise)
|
| 506 |
+
results[f'style{layer_idx:02d}'] = style
|
| 507 |
+
image, style = self.__getattr__(f'output{layer_idx // 2}')(
|
| 508 |
+
x, wp[:, layer_idx + 1])
|
| 509 |
+
results[f'output_style{layer_idx // 2}'] = style
|
| 510 |
+
elif self.architecture == 'skip':
|
| 511 |
+
for layer_idx in range(self.num_layers - 1):
|
| 512 |
+
x, style = self.__getattr__(f'layer{layer_idx}')(
|
| 513 |
+
x, wp[:, layer_idx], randomize_noise)
|
| 514 |
+
results[f'style{layer_idx:02d}'] = style
|
| 515 |
+
if layer_idx % 2 == 0:
|
| 516 |
+
temp, style = self.__getattr__(f'output{layer_idx // 2}')(
|
| 517 |
+
x, wp[:, layer_idx + 1])
|
| 518 |
+
results[f'output_style{layer_idx // 2}'] = style
|
| 519 |
+
if layer_idx == 0:
|
| 520 |
+
image = temp
|
| 521 |
+
else:
|
| 522 |
+
image = temp + self.upsample(image)
|
| 523 |
+
elif self.architecture == 'resnet':
|
| 524 |
+
x, style = self.layer0(x)
|
| 525 |
+
results[f'style00'] = style
|
| 526 |
+
for layer_idx in range(1, self.num_layers - 1, 2):
|
| 527 |
+
residual = self.__getattr__(f'skip_layer{layer_idx // 2}')(x)
|
| 528 |
+
x, style = self.__getattr__(f'layer{layer_idx}')(
|
| 529 |
+
x, wp[:, layer_idx], randomize_noise)
|
| 530 |
+
results[f'style{layer_idx:02d}'] = style
|
| 531 |
+
x, style = self.__getattr__(f'layer{layer_idx + 1}')(
|
| 532 |
+
x, wp[:, layer_idx + 1], randomize_noise)
|
| 533 |
+
results[f'style{layer_idx + 1:02d}'] = style
|
| 534 |
+
x = (x + residual) / np.sqrt(2.0)
|
| 535 |
+
image, style = self.__getattr__(f'output{layer_idx // 2 + 1}')(
|
| 536 |
+
x, wp[:, layer_idx + 2])
|
| 537 |
+
results[f'output_style{layer_idx // 2}'] = style
|
| 538 |
+
results['image'] = self.final_activate(image)
|
| 539 |
+
return results
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class PixelNormLayer(nn.Module):
|
| 543 |
+
"""Implements pixel-wise feature vector normalization layer."""
|
| 544 |
+
|
| 545 |
+
def __init__(self, dim=1, epsilon=1e-8):
|
| 546 |
+
super().__init__()
|
| 547 |
+
self.dim = dim
|
| 548 |
+
self.eps = epsilon
|
| 549 |
+
|
| 550 |
+
def forward(self, x):
|
| 551 |
+
norm = torch.sqrt(
|
| 552 |
+
torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps)
|
| 553 |
+
return x / norm
|
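A sketch of the pixel-wise normalization above (not part of the commit): each sample is rescaled to unit mean square along the chosen dimension.

import torch

norm = PixelNormLayer(dim=1)
z = torch.randn(4, 512)
z_n = norm(z)
print(z_n.pow(2).mean(dim=1))              # ~1.0 for every sample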
| 554 |
+
|
| 555 |
+
|
| 556 |
+
class UpsamplingLayer(nn.Module):
|
| 557 |
+
"""Implements the upsampling layer.
|
| 558 |
+
|
| 559 |
+
This layer can also be used as filtering by setting `scale_factor` as 1.
|
| 560 |
+
"""
|
| 561 |
+
|
| 562 |
+
def __init__(self,
|
| 563 |
+
scale_factor=2,
|
| 564 |
+
kernel=(1, 3, 3, 1),
|
| 565 |
+
extra_padding=0,
|
| 566 |
+
kernel_gain=None):
|
| 567 |
+
super().__init__()
|
| 568 |
+
assert scale_factor >= 1
|
| 569 |
+
self.scale_factor = scale_factor
|
| 570 |
+
|
| 571 |
+
if extra_padding != 0:
|
| 572 |
+
assert scale_factor == 1
|
| 573 |
+
|
| 574 |
+
if kernel is None:
|
| 575 |
+
kernel = np.ones((scale_factor), dtype=np.float32)
|
| 576 |
+
else:
|
| 577 |
+
kernel = np.array(kernel, dtype=np.float32)
|
| 578 |
+
assert kernel.ndim == 1
|
| 579 |
+
kernel = np.outer(kernel, kernel)
|
| 580 |
+
kernel = kernel / np.sum(kernel)
|
| 581 |
+
if kernel_gain is None:
|
| 582 |
+
kernel = kernel * (scale_factor ** 2)
|
| 583 |
+
else:
|
| 584 |
+
assert kernel_gain > 0
|
| 585 |
+
kernel = kernel * (kernel_gain ** 2)
|
| 586 |
+
assert kernel.ndim == 2
|
| 587 |
+
assert kernel.shape[0] == kernel.shape[1]
|
| 588 |
+
kernel = kernel[np.newaxis, np.newaxis]
|
| 589 |
+
self.register_buffer('kernel', torch.from_numpy(kernel))
|
| 590 |
+
self.kernel = self.kernel.flip(0, 1)
|
| 591 |
+
|
| 592 |
+
self.upsample_padding = (0, scale_factor - 1, # Width padding.
|
| 593 |
+
0, 0, # Width.
|
| 594 |
+
0, scale_factor - 1, # Height padding.
|
| 595 |
+
0, 0, # Height.
|
| 596 |
+
0, 0, # Channel.
|
| 597 |
+
0, 0) # Batch size.
|
| 598 |
+
|
| 599 |
+
padding = kernel.shape[2] - scale_factor + extra_padding
|
| 600 |
+
self.padding = ((padding + 1) // 2 + scale_factor - 1, padding // 2,
|
| 601 |
+
(padding + 1) // 2 + scale_factor - 1, padding // 2)
|
| 602 |
+
|
| 603 |
+
def forward(self, x):
|
| 604 |
+
assert x.ndim == 4
|
| 605 |
+
channels = x.shape[1]
|
| 606 |
+
if self.scale_factor > 1:
|
| 607 |
+
x = x.view(-1, channels, x.shape[2], 1, x.shape[3], 1)
|
| 608 |
+
x = F.pad(x, self.upsample_padding, mode='constant', value=0)
|
| 609 |
+
x = x.view(-1, channels, x.shape[2] * self.scale_factor,
|
| 610 |
+
x.shape[4] * self.scale_factor)
|
| 611 |
+
x = x.view(-1, 1, x.shape[2], x.shape[3])
|
| 612 |
+
x = F.pad(x, self.padding, mode='constant', value=0)
|
| 613 |
+
x = F.conv2d(x, self.kernel, stride=1)
|
| 614 |
+
x = x.view(-1, channels, x.shape[2], x.shape[3])
|
| 615 |
+
return x
|
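A shape sketch for the upsampling layer above (not part of the commit; UpsamplingLayer is the class defined right above). With the default settings it zero-interleaves the pixels, filters with the (1, 3, 3, 1) kernel scaled by the square of the factor, and doubles the spatial size:

import torch

up = UpsamplingLayer(scale_factor=2)
x = torch.randn(4, 64, 16, 16)
y = up(x)
print(y.shape)                             # torch.Size([4, 64, 32, 32])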
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class InputBlock(nn.Module):
|
| 619 |
+
"""Implements the input block.
|
| 620 |
+
|
| 621 |
+
Basically, this block starts from a const input, which is with shape
|
| 622 |
+
`(channels, init_resolution, init_resolution)`.
|
| 623 |
+
"""
|
| 624 |
+
|
| 625 |
+
def __init__(self, init_resolution, channels):
|
| 626 |
+
super().__init__()
|
| 627 |
+
self.const = nn.Parameter(
|
| 628 |
+
torch.randn(1, channels, init_resolution, init_resolution))
|
| 629 |
+
|
| 630 |
+
def forward(self, w):
|
| 631 |
+
x = self.const.repeat(w.shape[0], 1, 1, 1)
|
| 632 |
+
return x
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class ConvBlock(nn.Module):
|
| 636 |
+
"""Implements the convolutional block (no style modulation).
|
| 637 |
+
|
| 638 |
+
Basically, this block executes, convolutional layer, filtering layer (if
|
| 639 |
+
needed), and activation layer in sequence.
|
| 640 |
+
|
| 641 |
+
NOTE: This block is particularly used for skip-connection branch in the
|
| 642 |
+
`resnet` structure.
|
| 643 |
+
"""
|
| 644 |
+
|
| 645 |
+
def __init__(self,
|
| 646 |
+
in_channels,
|
| 647 |
+
out_channels,
|
| 648 |
+
kernel_size=3,
|
| 649 |
+
add_bias=True,
|
| 650 |
+
scale_factor=1,
|
| 651 |
+
filtering_kernel=(1, 3, 3, 1),
|
| 652 |
+
use_wscale=True,
|
| 653 |
+
wscale_gain=_WSCALE_GAIN,
|
| 654 |
+
lr_mul=1.0,
|
| 655 |
+
activation_type='lrelu'):
|
| 656 |
+
"""Initializes with block settings.
|
| 657 |
+
|
| 658 |
+
Args:
|
| 659 |
+
in_channels: Number of channels of the input tensor.
|
| 660 |
+
out_channels: Number of channels of the output tensor.
|
| 661 |
+
kernel_size: Size of the convolutional kernels. (default: 3)
|
| 662 |
+
add_bias: Whether to add bias onto the convolutional result.
|
| 663 |
+
(default: True)
|
| 664 |
+
scale_factor: Scale factor for upsampling. `1` means skip
|
| 665 |
+
upsampling. (default: 1)
|
| 666 |
+
filtering_kernel: Kernel used for filtering after upsampling.
|
| 667 |
+
(default: (1, 3, 3, 1))
|
| 668 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
| 669 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
| 670 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
| 671 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
| 672 |
+
(default: `lrelu`)
|
| 673 |
+
|
| 674 |
+
Raises:
|
| 675 |
+
NotImplementedError: If the `activation_type` is not supported.
|
| 676 |
+
"""
|
| 677 |
+
super().__init__()
|
| 678 |
+
|
| 679 |
+
if scale_factor > 1:
|
| 680 |
+
self.use_conv2d_transpose = True
|
| 681 |
+
extra_padding = scale_factor - kernel_size
|
| 682 |
+
self.filter = UpsamplingLayer(scale_factor=1,
|
| 683 |
+
kernel=filtering_kernel,
|
| 684 |
+
extra_padding=extra_padding,
|
| 685 |
+
kernel_gain=scale_factor)
|
| 686 |
+
self.stride = scale_factor
|
| 687 |
+
self.padding = 0 # Padding is done in `UpsamplingLayer`.
|
| 688 |
+
else:
|
| 689 |
+
self.use_conv2d_transpose = False
|
| 690 |
+
assert kernel_size % 2 == 1
|
| 691 |
+
self.stride = 1
|
| 692 |
+
self.padding = kernel_size // 2
|
| 693 |
+
|
| 694 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
| 695 |
+
fan_in = kernel_size * kernel_size * in_channels
|
| 696 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
| 697 |
+
if use_wscale:
|
| 698 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
| 699 |
+
self.wscale = wscale * lr_mul
|
| 700 |
+
else:
|
| 701 |
+
self.weight = nn.Parameter(
|
| 702 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
| 703 |
+
self.wscale = lr_mul
|
| 704 |
+
|
| 705 |
+
if add_bias:
|
| 706 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
| 707 |
+
else:
|
| 708 |
+
self.bias = None
|
| 709 |
+
self.bscale = lr_mul
|
| 710 |
+
|
| 711 |
+
if activation_type == 'linear':
|
| 712 |
+
self.activate = nn.Identity()
|
| 713 |
+
self.activate_scale = 1.0
|
| 714 |
+
elif activation_type == 'lrelu':
|
| 715 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 716 |
+
self.activate_scale = np.sqrt(2.0)
|
| 717 |
+
else:
|
| 718 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
| 719 |
+
f'`{activation_type}`!')
|
| 720 |
+
|
| 721 |
+
def forward(self, x):
|
| 722 |
+
weight = self.weight * self.wscale
|
| 723 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
| 724 |
+
if self.use_conv2d_transpose:
|
| 725 |
+
weight = weight.permute(1, 0, 2, 3).flip(2, 3)
|
| 726 |
+
x = F.conv_transpose2d(x,
|
| 727 |
+
weight=weight,
|
| 728 |
+
bias=bias,
|
| 729 |
+
stride=self.scale_factor,
|
| 730 |
+
padding=self.padding)
|
| 731 |
+
x = self.filter(x)
|
| 732 |
+
else:
|
| 733 |
+
x = F.conv2d(x,
|
| 734 |
+
weight=weight,
|
| 735 |
+
bias=bias,
|
| 736 |
+
stride=self.stride,
|
| 737 |
+
padding=self.padding)
|
| 738 |
+
x = self.activate(x) * self.activate_scale
|
| 739 |
+
return x
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
class ModulateConvBlock(nn.Module):
|
| 743 |
+
"""Implements the convolutional block with style modulation."""
|
| 744 |
+
|
| 745 |
+
def __init__(self,
|
| 746 |
+
in_channels,
|
| 747 |
+
out_channels,
|
| 748 |
+
resolution,
|
| 749 |
+
w_space_dim,
|
| 750 |
+
kernel_size=3,
|
| 751 |
+
add_bias=True,
|
| 752 |
+
scale_factor=1,
|
| 753 |
+
filtering_kernel=(1, 3, 3, 1),
|
| 754 |
+
fused_modulate=True,
|
| 755 |
+
demodulate=True,
|
| 756 |
+
use_wscale=True,
|
| 757 |
+
wscale_gain=_WSCALE_GAIN,
|
| 758 |
+
lr_mul=1.0,
|
| 759 |
+
add_noise=True,
|
| 760 |
+
activation_type='lrelu',
|
| 761 |
+
epsilon=1e-8):
|
| 762 |
+
"""Initializes with block settings.
|
| 763 |
+
|
| 764 |
+
Args:
|
| 765 |
+
in_channels: Number of channels of the input tensor.
|
| 766 |
+
out_channels: Number of channels of the output tensor.
|
| 767 |
+
resolution: Resolution of the output tensor.
|
| 768 |
+
w_space_dim: Dimension of W space for style modulation.
|
| 769 |
+
kernel_size: Size of the convolutional kernels. (default: 3)
|
| 770 |
+
add_bias: Whether to add bias onto the convolutional result.
|
| 771 |
+
(default: True)
|
| 772 |
+
scale_factor: Scale factor for upsampling. `1` means skip
|
| 773 |
+
upsampling. (default: 1)
|
| 774 |
+
filtering_kernel: Kernel used for filtering after upsampling.
|
| 775 |
+
(default: (1, 3, 3, 1))
|
| 776 |
+
fused_modulate: Whether to fuse `style_modulate` and `conv2d`
|
| 777 |
+
together. (default: True)
|
| 778 |
+
demodulate: Whether to perform style demodulation. (default: True)
|
| 779 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
| 780 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
| 781 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
| 782 |
+
add_noise: Whether to add noise onto the output tensor. (default:
|
| 783 |
+
True)
|
| 784 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
| 785 |
+
(default: `lrelu`)
|
| 786 |
+
epsilon: Small number to avoid `divide by zero`. (default: 1e-8)
|
| 787 |
+
|
| 788 |
+
Raises:
|
| 789 |
+
NotImplementedError: If the `activation_type` is not supported.
|
| 790 |
+
"""
|
| 791 |
+
super().__init__()
|
| 792 |
+
|
| 793 |
+
self.res = resolution
|
| 794 |
+
self.in_c = in_channels
|
| 795 |
+
self.out_c = out_channels
|
| 796 |
+
self.ksize = kernel_size
|
| 797 |
+
self.eps = epsilon
|
| 798 |
+
|
| 799 |
+
if scale_factor > 1:
|
| 800 |
+
self.use_conv2d_transpose = True
|
| 801 |
+
extra_padding = scale_factor - kernel_size
|
| 802 |
+
self.filter = UpsamplingLayer(scale_factor=1,
|
| 803 |
+
kernel=filtering_kernel,
|
| 804 |
+
extra_padding=extra_padding,
|
| 805 |
+
kernel_gain=scale_factor)
|
| 806 |
+
self.stride = scale_factor
|
| 807 |
+
self.padding = 0 # Padding is done in `UpsamplingLayer`.
|
| 808 |
+
else:
|
| 809 |
+
self.use_conv2d_transpose = False
|
| 810 |
+
assert kernel_size % 2 == 1
|
| 811 |
+
self.stride = 1
|
| 812 |
+
self.padding = kernel_size // 2
|
| 813 |
+
|
| 814 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
| 815 |
+
fan_in = kernel_size * kernel_size * in_channels
|
| 816 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
| 817 |
+
if use_wscale:
|
| 818 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
| 819 |
+
self.wscale = wscale * lr_mul
|
| 820 |
+
else:
|
| 821 |
+
self.weight = nn.Parameter(
|
| 822 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
| 823 |
+
self.wscale = lr_mul
|
| 824 |
+
|
| 825 |
+
self.style = DenseBlock(in_channels=w_space_dim,
|
| 826 |
+
out_channels=in_channels,
|
| 827 |
+
additional_bias=1.0,
|
| 828 |
+
use_wscale=use_wscale,
|
| 829 |
+
activation_type='linear')
|
| 830 |
+
|
| 831 |
+
self.fused_modulate = fused_modulate
|
| 832 |
+
self.demodulate = demodulate
|
| 833 |
+
|
| 834 |
+
if add_bias:
|
| 835 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
| 836 |
+
else:
|
| 837 |
+
self.bias = None
|
| 838 |
+
self.bscale = lr_mul
|
| 839 |
+
|
| 840 |
+
if activation_type == 'linear':
|
| 841 |
+
self.activate = nn.Identity()
|
| 842 |
+
self.activate_scale = 1.0
|
| 843 |
+
elif activation_type == 'lrelu':
|
| 844 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 845 |
+
self.activate_scale = np.sqrt(2.0)
|
| 846 |
+
else:
|
| 847 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
| 848 |
+
f'`{activation_type}`!')
|
| 849 |
+
|
| 850 |
+
self.add_noise = add_noise
|
| 851 |
+
if self.add_noise:
|
| 852 |
+
self.register_buffer('noise', torch.randn(1, 1, self.res, self.res))
|
| 853 |
+
self.noise_strength = nn.Parameter(torch.zeros(()))
|
| 854 |
+
|
| 855 |
+
def forward(self, x, w, randomize_noise=False):
|
| 856 |
+
batch = x.shape[0]
|
| 857 |
+
|
| 858 |
+
weight = self.weight * self.wscale
|
| 859 |
+
weight = weight.permute(2, 3, 1, 0)
|
| 860 |
+
|
| 861 |
+
# Style modulation.
|
| 862 |
+
style = self.style(w)
|
| 863 |
+
_weight = weight.view(1, self.ksize, self.ksize, self.in_c, self.out_c)
|
| 864 |
+
_weight = _weight * style.view(batch, 1, 1, self.in_c, 1)
|
| 865 |
+
|
| 866 |
+
# Style demodulation.
|
| 867 |
+
if self.demodulate:
|
| 868 |
+
_weight_norm = torch.sqrt(
|
| 869 |
+
torch.sum(_weight ** 2, dim=[1, 2, 3]) + self.eps)
|
| 870 |
+
_weight = _weight / _weight_norm.view(batch, 1, 1, 1, self.out_c)
|
| 871 |
+
|
| 872 |
+
if self.fused_modulate:
|
| 873 |
+
x = x.view(1, batch * self.in_c, x.shape[2], x.shape[3])
|
| 874 |
+
weight = _weight.permute(1, 2, 3, 0, 4).reshape(
|
| 875 |
+
self.ksize, self.ksize, self.in_c, batch * self.out_c)
|
| 876 |
+
else:
|
| 877 |
+
x = x * style.view(batch, self.in_c, 1, 1)
|
| 878 |
+
|
| 879 |
+
if self.use_conv2d_transpose:
|
| 880 |
+
weight = weight.flip(0, 1)
|
| 881 |
+
if self.fused_modulate:
|
| 882 |
+
weight = weight.view(
|
| 883 |
+
self.ksize, self.ksize, self.in_c, batch, self.out_c)
|
| 884 |
+
weight = weight.permute(0, 1, 4, 3, 2)
|
| 885 |
+
weight = weight.reshape(
|
| 886 |
+
self.ksize, self.ksize, self.out_c, batch * self.in_c)
|
| 887 |
+
weight = weight.permute(3, 2, 0, 1)
|
| 888 |
+
else:
|
| 889 |
+
weight = weight.permute(2, 3, 0, 1)
|
| 890 |
+
x = F.conv_transpose2d(x,
|
| 891 |
+
weight=weight,
|
| 892 |
+
bias=None,
|
| 893 |
+
stride=self.stride,
|
| 894 |
+
padding=self.padding,
|
| 895 |
+
groups=(batch if self.fused_modulate else 1))
|
| 896 |
+
x = self.filter(x)
|
| 897 |
+
else:
|
| 898 |
+
weight = weight.permute(3, 2, 0, 1)
|
| 899 |
+
x = F.conv2d(x,
|
| 900 |
+
weight=weight,
|
| 901 |
+
bias=None,
|
| 902 |
+
stride=self.stride,
|
| 903 |
+
padding=self.padding,
|
| 904 |
+
groups=(batch if self.fused_modulate else 1))
|
| 905 |
+
|
| 906 |
+
if self.fused_modulate:
|
| 907 |
+
x = x.view(batch, self.out_c, self.res, self.res)
|
| 908 |
+
elif self.demodulate:
|
| 909 |
+
x = x / _weight_norm.view(batch, self.out_c, 1, 1)
|
| 910 |
+
|
| 911 |
+
if self.add_noise:
|
| 912 |
+
if randomize_noise:
|
| 913 |
+
noise = torch.randn(x.shape[0], 1, self.res, self.res).to(x)
|
| 914 |
+
else:
|
| 915 |
+
noise = self.noise
|
| 916 |
+
x = x + noise * self.noise_strength.view(1, 1, 1, 1)
|
| 917 |
+
|
| 918 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
| 919 |
+
if bias is not None:
|
| 920 |
+
x = x + bias.view(1, -1, 1, 1)
|
| 921 |
+
x = self.activate(x) * self.activate_scale
|
| 922 |
+
return x, style
|
| 923 |
+
|
| 924 |
+
|
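A minimal sketch of the modulate/demodulate arithmetic that the block above folds into its convolution; the shapes and names here are illustrative only:

import torch

weight = torch.randn(8, 3, 3, 3)                        # (out, in, kh, kw)
style = torch.randn(2, 3)                               # per-sample scale per input channel
eps = 1e-8

# Modulate: scale every input channel of the kernel by the style.
w = weight.unsqueeze(0) * style.view(2, 1, 3, 1, 1)     # (batch, out, in, kh, kw)

# Demodulate: renormalize each output filter to unit L2 norm per sample.
norm = torch.rsqrt(w.pow(2).sum(dim=[2, 3, 4]) + eps)   # (batch, out)
w = w * norm.view(2, 8, 1, 1, 1)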
| 925 |
+
class DenseBlock(nn.Module):
|
| 926 |
+
"""Implements the dense block.
|
| 927 |
+
|
| 928 |
+
Basically, this block executes a fully-connected layer and an activation layer.
|
| 929 |
+
|
| 930 |
+
NOTE: This layer supports adding an additional bias beyond the trainable
|
| 931 |
+
bias parameter. This is specifically used for the mapping from the w code to
|
| 932 |
+
the style code.
|
| 933 |
+
"""
|
| 934 |
+
|
| 935 |
+
def __init__(self,
|
| 936 |
+
in_channels,
|
| 937 |
+
out_channels,
|
| 938 |
+
add_bias=True,
|
| 939 |
+
additional_bias=0,
|
| 940 |
+
use_wscale=True,
|
| 941 |
+
wscale_gain=_WSCALE_GAIN,
|
| 942 |
+
lr_mul=1.0,
|
| 943 |
+
activation_type='lrelu'):
|
| 944 |
+
"""Initializes with block settings.
|
| 945 |
+
|
| 946 |
+
Args:
|
| 947 |
+
in_channels: Number of channels of the input tensor.
|
| 948 |
+
out_channels: Number of channels of the output tensor.
|
| 949 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
| 950 |
+
(default: True)
|
| 951 |
+
additional_bias: The additional bias, which is independent from the
|
| 952 |
+
bias parameter. (default: 0.0)
|
| 953 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
| 954 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
| 955 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
| 956 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
| 957 |
+
(default: `lrelu`)
|
| 958 |
+
|
| 959 |
+
Raises:
|
| 960 |
+
NotImplementedError: If the `activation_type` is not supported.
|
| 961 |
+
"""
|
| 962 |
+
super().__init__()
|
| 963 |
+
weight_shape = (out_channels, in_channels)
|
| 964 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
| 965 |
+
if use_wscale:
|
| 966 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
| 967 |
+
self.wscale = wscale * lr_mul
|
| 968 |
+
else:
|
| 969 |
+
self.weight = nn.Parameter(
|
| 970 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
| 971 |
+
self.wscale = lr_mul
|
| 972 |
+
|
| 973 |
+
if add_bias:
|
| 974 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
| 975 |
+
else:
|
| 976 |
+
self.bias = None
|
| 977 |
+
self.bscale = lr_mul
|
| 978 |
+
self.additional_bias = additional_bias
|
| 979 |
+
|
| 980 |
+
if activation_type == 'linear':
|
| 981 |
+
self.activate = nn.Identity()
|
| 982 |
+
self.activate_scale = 1.0
|
| 983 |
+
elif activation_type == 'lrelu':
|
| 984 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 985 |
+
self.activate_scale = np.sqrt(2.0)
|
| 986 |
+
else:
|
| 987 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
| 988 |
+
f'`{activation_type}`!')
|
| 989 |
+
|
| 990 |
+
def forward(self, x):
|
| 991 |
+
if x.ndim != 2:
|
| 992 |
+
x = x.view(x.shape[0], -1)
|
| 993 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
| 994 |
+
x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
|
| 995 |
+
x = self.activate(x + self.additional_bias) * self.activate_scale
|
| 996 |
+
return x
|
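One detail worth calling out before the next file: the style branch above builds its affine layer with `additional_bias=1.0` and a linear activation, so a freshly initialized block modulates with scales centered at 1. A quick check, assuming this module is importable as defined here:

import torch

affine = DenseBlock(in_channels=512, out_channels=64,
                    additional_bias=1.0, activation_type='linear')
style = affine(torch.zeros(4, 512))   # zero w-code, zero-initialized bias
print(style.mean().item())            # 1.0, i.e. modulation starts near identity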
models/stylegan_discriminator.py
ADDED
|
@@ -0,0 +1,530 @@
|
| 1 |
+
# python3.7
|
| 2 |
+
"""Contains the implementation of discriminator described in StyleGAN.
|
| 3 |
+
|
| 4 |
+
Paper: https://arxiv.org/pdf/1812.04948.pdf
|
| 5 |
+
|
| 6 |
+
Official TensorFlow implementation: https://github.com/NVlabs/stylegan
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
__all__ = ['StyleGANDiscriminator']
|
| 16 |
+
|
| 17 |
+
# Resolutions allowed.
|
| 18 |
+
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
|
| 19 |
+
|
| 20 |
+
# Initial resolution.
|
| 21 |
+
_INIT_RES = 4
|
| 22 |
+
|
| 23 |
+
# Fused-scale options allowed.
|
| 24 |
+
_FUSED_SCALE_ALLOWED = [True, False, 'auto']
|
| 25 |
+
|
| 26 |
+
# Minimal resolution for `auto` fused-scale strategy.
|
| 27 |
+
_AUTO_FUSED_SCALE_MIN_RES = 128
|
| 28 |
+
|
| 29 |
+
# Default gain factor for weight scaling.
|
| 30 |
+
_WSCALE_GAIN = np.sqrt(2.0)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class StyleGANDiscriminator(nn.Module):
|
| 34 |
+
"""Defines the discriminator network in StyleGAN.
|
| 35 |
+
|
| 36 |
+
NOTE: The discriminator takes images with `RGB` channel order and pixel
|
| 37 |
+
range [-1, 1] as inputs.
|
| 38 |
+
|
| 39 |
+
Settings for the network:
|
| 40 |
+
|
| 41 |
+
(1) resolution: The resolution of the input image.
|
| 42 |
+
(2) image_channels: Number of channels of the input image. (default: 3)
|
| 43 |
+
(3) label_size: Size of the additional label for conditional generation.
|
| 44 |
+
(default: 0)
|
| 45 |
+
(4) fused_scale: Whether to fuse `conv2d` and `downsample` together,
|
| 46 |
+
resulting in `conv2d` with strides. (default: `auto`)
|
| 47 |
+
(5) use_wscale: Whether to use weight scaling. (default: True)
|
| 48 |
+
(6) minibatch_std_group_size: Group size for the minibatch standard
|
| 49 |
+
deviation layer. 0 means disable. (default: 4)
|
| 50 |
+
(7) minibatch_std_channels: Number of new channels after the minibatch
|
| 51 |
+
standard deviation layer. (default: 1)
|
| 52 |
+
(8) fmaps_base: Factor to control number of feature maps for each layer.
|
| 53 |
+
(default: 16 << 10)
|
| 54 |
+
(9) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self,
|
| 58 |
+
resolution,
|
| 59 |
+
image_channels=3,
|
| 60 |
+
label_size=0,
|
| 61 |
+
fused_scale='auto',
|
| 62 |
+
use_wscale=True,
|
| 63 |
+
minibatch_std_group_size=4,
|
| 64 |
+
minibatch_std_channels=1,
|
| 65 |
+
fmaps_base=16 << 10,
|
| 66 |
+
fmaps_max=512):
|
| 67 |
+
"""Initializes with basic settings.
|
| 68 |
+
|
| 69 |
+
Raises:
|
| 70 |
+
ValueError: If the `resolution` is not supported, or `fused_scale`
|
| 71 |
+
is not supported.
|
| 72 |
+
"""
|
| 73 |
+
super().__init__()
|
| 74 |
+
|
| 75 |
+
if resolution not in _RESOLUTIONS_ALLOWED:
|
| 76 |
+
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
|
| 77 |
+
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
|
| 78 |
+
if fused_scale not in _FUSED_SCALE_ALLOWED:
|
| 79 |
+
raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
|
| 80 |
+
f'Options allowed: {_FUSED_SCALE_ALLOWED}.')
|
| 81 |
+
|
| 82 |
+
self.init_res = _INIT_RES
|
| 83 |
+
self.init_res_log2 = int(np.log2(self.init_res))
|
| 84 |
+
self.resolution = resolution
|
| 85 |
+
self.final_res_log2 = int(np.log2(self.resolution))
|
| 86 |
+
self.image_channels = image_channels
|
| 87 |
+
self.label_size = label_size
|
| 88 |
+
self.fused_scale = fused_scale
|
| 89 |
+
self.use_wscale = use_wscale
|
| 90 |
+
self.minibatch_std_group_size = minibatch_std_group_size
|
| 91 |
+
self.minibatch_std_channels = minibatch_std_channels
|
| 92 |
+
self.fmaps_base = fmaps_base
|
| 93 |
+
self.fmaps_max = fmaps_max
|
| 94 |
+
|
| 95 |
+
# Level of detail (used for progressive training).
|
| 96 |
+
self.register_buffer('lod', torch.zeros(()))
|
| 97 |
+
self.pth_to_tf_var_mapping = {'lod': 'lod'}
|
| 98 |
+
|
| 99 |
+
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
|
| 100 |
+
res = 2 ** res_log2
|
| 101 |
+
block_idx = self.final_res_log2 - res_log2
|
| 102 |
+
|
| 103 |
+
# Input convolution layer for each resolution.
|
| 104 |
+
self.add_module(
|
| 105 |
+
f'input{block_idx}',
|
| 106 |
+
ConvBlock(in_channels=self.image_channels,
|
| 107 |
+
out_channels=self.get_nf(res),
|
| 108 |
+
kernel_size=1,
|
| 109 |
+
padding=0,
|
| 110 |
+
use_wscale=self.use_wscale))
|
| 111 |
+
self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
|
| 112 |
+
f'FromRGB_lod{block_idx}/weight')
|
| 113 |
+
self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
|
| 114 |
+
f'FromRGB_lod{block_idx}/bias')
|
| 115 |
+
|
| 116 |
+
# Convolution block for each resolution (except the last one).
|
| 117 |
+
if res != self.init_res:
|
| 118 |
+
if self.fused_scale == 'auto':
|
| 119 |
+
fused_scale = (res >= _AUTO_FUSED_SCALE_MIN_RES)
|
| 120 |
+
else:
|
| 121 |
+
fused_scale = self.fused_scale
|
| 122 |
+
self.add_module(
|
| 123 |
+
f'layer{2 * block_idx}',
|
| 124 |
+
ConvBlock(in_channels=self.get_nf(res),
|
| 125 |
+
out_channels=self.get_nf(res),
|
| 126 |
+
use_wscale=self.use_wscale))
|
| 127 |
+
tf_layer0_name = 'Conv0'
|
| 128 |
+
self.add_module(
|
| 129 |
+
f'layer{2 * block_idx + 1}',
|
| 130 |
+
ConvBlock(in_channels=self.get_nf(res),
|
| 131 |
+
out_channels=self.get_nf(res // 2),
|
| 132 |
+
downsample=True,
|
| 133 |
+
fused_scale=fused_scale,
|
| 134 |
+
use_wscale=self.use_wscale))
|
| 135 |
+
tf_layer1_name = 'Conv1_down'
|
| 136 |
+
|
| 137 |
+
# Convolution block for last resolution.
|
| 138 |
+
else:
|
| 139 |
+
self.add_module(
|
| 140 |
+
f'layer{2 * block_idx}',
|
| 141 |
+
ConvBlock(in_channels=self.get_nf(res),
|
| 142 |
+
out_channels=self.get_nf(res),
|
| 143 |
+
use_wscale=self.use_wscale,
|
| 144 |
+
minibatch_std_group_size=minibatch_std_group_size,
|
| 145 |
+
minibatch_std_channels=minibatch_std_channels))
|
| 146 |
+
tf_layer0_name = 'Conv'
|
| 147 |
+
self.add_module(
|
| 148 |
+
f'layer{2 * block_idx + 1}',
|
| 149 |
+
DenseBlock(in_channels=self.get_nf(res) * res * res,
|
| 150 |
+
out_channels=self.get_nf(res // 2),
|
| 151 |
+
use_wscale=self.use_wscale))
|
| 152 |
+
tf_layer1_name = 'Dense0'
|
| 153 |
+
|
| 154 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
|
| 155 |
+
f'{res}x{res}/{tf_layer0_name}/weight')
|
| 156 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
|
| 157 |
+
f'{res}x{res}/{tf_layer0_name}/bias')
|
| 158 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
|
| 159 |
+
f'{res}x{res}/{tf_layer1_name}/weight')
|
| 160 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
|
| 161 |
+
f'{res}x{res}/{tf_layer1_name}/bias')
|
| 162 |
+
|
| 163 |
+
# Final dense block.
|
| 164 |
+
self.add_module(
|
| 165 |
+
f'layer{2 * block_idx + 2}',
|
| 166 |
+
DenseBlock(in_channels=self.get_nf(res // 2),
|
| 167 |
+
out_channels=max(self.label_size, 1),
|
| 168 |
+
use_wscale=self.use_wscale,
|
| 169 |
+
wscale_gain=1.0,
|
| 170 |
+
activation_type='linear'))
|
| 171 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
|
| 172 |
+
f'{res}x{res}/Dense1/weight')
|
| 173 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
|
| 174 |
+
f'{res}x{res}/Dense1/bias')
|
| 175 |
+
|
| 176 |
+
self.downsample = DownsamplingLayer()
|
| 177 |
+
|
| 178 |
+
def get_nf(self, res):
|
| 179 |
+
"""Gets number of feature maps according to current resolution."""
|
| 180 |
+
return min(self.fmaps_base // res, self.fmaps_max)
|
| 181 |
+
|
| 182 |
+
def forward(self, image, label=None, lod=None, **_unused_kwargs):
|
| 183 |
+
expected_shape = (self.image_channels, self.resolution, self.resolution)
|
| 184 |
+
if image.ndim != 4 or image.shape[1:] != expected_shape:
|
| 185 |
+
raise ValueError(f'The input tensor should be with shape '
|
| 186 |
+
f'[batch_size, channel, height, width], where '
|
| 187 |
+
f'`channel` equals to {self.image_channels}, '
|
| 188 |
+
f'`height`, `width` equal to {self.resolution}!\n'
|
| 189 |
+
f'But `{image.shape}` is received!')
|
| 190 |
+
|
| 191 |
+
lod = self.lod.cpu().tolist() if lod is None else lod
|
| 192 |
+
if lod + self.init_res_log2 > self.final_res_log2:
|
| 193 |
+
raise ValueError(f'Maximum level-of-detail (lod) is '
|
| 194 |
+
f'{self.final_res_log2 - self.init_res_log2}, '
|
| 195 |
+
f'but `{lod}` is received!')
|
| 196 |
+
|
| 197 |
+
if self.label_size:
|
| 198 |
+
if label is None:
|
| 199 |
+
raise ValueError(f'Model requires an additional label '
|
| 200 |
+
f'(with size {self.label_size}) as input, '
|
| 201 |
+
f'but no label is received!')
|
| 202 |
+
batch_size = image.shape[0]
|
| 203 |
+
if label.ndim != 2 or label.shape != (batch_size, self.label_size):
|
| 204 |
+
raise ValueError(f'Input label should be with shape '
|
| 205 |
+
f'[batch_size, label_size], where '
|
| 206 |
+
f'`batch_size` equals to that of '
|
| 207 |
+
f'images ({image.shape[0]}) and '
|
| 208 |
+
f'`label_size` equals to {self.label_size}!\n'
|
| 209 |
+
f'But `{label.shape}` is received!')
|
| 210 |
+
|
| 211 |
+
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
|
| 212 |
+
block_idx = current_lod = self.final_res_log2 - res_log2
|
| 213 |
+
if current_lod <= lod < current_lod + 1:
|
| 214 |
+
x = self.__getattr__(f'input{block_idx}')(image)
|
| 215 |
+
elif current_lod - 1 < lod < current_lod:
|
| 216 |
+
alpha = lod - np.floor(lod)
|
| 217 |
+
x = (self.__getattr__(f'input{block_idx}')(image) * alpha +
|
| 218 |
+
x * (1 - alpha))
|
| 219 |
+
if lod < current_lod + 1:
|
| 220 |
+
x = self.__getattr__(f'layer{2 * block_idx}')(x)
|
| 221 |
+
x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
|
| 222 |
+
if lod > current_lod:
|
| 223 |
+
image = self.downsample(image)
|
| 224 |
+
x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)
|
| 225 |
+
|
| 226 |
+
if self.label_size:
|
| 227 |
+
x = torch.sum(x * label, dim=1, keepdim=True)
|
| 228 |
+
|
| 229 |
+
return x
|
| 230 |
+
|
| 231 |
+
|
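A minimal smoke test of the class above, assuming it can be instantiated as defined in this file:

import torch

D = StyleGANDiscriminator(resolution=256)
images = torch.randn(2, 3, 256, 256)   # RGB images in [-1, 1]
scores = D(images)                     # lod defaults to the registered buffer (0)
print(scores.shape)                    # torch.Size([2, 1])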
| 232 |
+
class MiniBatchSTDLayer(nn.Module):
|
| 233 |
+
"""Implements the minibatch standard deviation layer."""
|
| 234 |
+
|
| 235 |
+
def __init__(self, group_size=4, new_channels=1, epsilon=1e-8):
|
| 236 |
+
super().__init__()
|
| 237 |
+
self.group_size = group_size
|
| 238 |
+
self.new_channels = new_channels
|
| 239 |
+
self.epsilon = epsilon
|
| 240 |
+
|
| 241 |
+
def forward(self, x):
|
| 242 |
+
if self.group_size <= 1:
|
| 243 |
+
return x
|
| 244 |
+
ng = min(self.group_size, x.shape[0])
|
| 245 |
+
nc = self.new_channels
|
| 246 |
+
temp_c = x.shape[1] // nc # [NCHW]
|
| 247 |
+
y = x.view(ng, -1, nc, temp_c, x.shape[2], x.shape[3]) # [GMncHW]
|
| 248 |
+
y = y - torch.mean(y, dim=0, keepdim=True) # [GMncHW]
|
| 249 |
+
y = torch.mean(y ** 2, dim=0) # [MncHW]
|
| 250 |
+
y = torch.sqrt(y + self.epsilon) # [MncHW]
|
| 251 |
+
y = torch.mean(y, dim=[2, 3, 4], keepdim=True) # [Mn111]
|
| 252 |
+
y = torch.mean(y, dim=2) # [Mn11]
|
| 253 |
+
y = y.repeat(ng, 1, x.shape[2], x.shape[3]) # [NnHW]
|
| 254 |
+
return torch.cat([x, y], dim=1)
|
| 255 |
+
|
| 256 |
+
|
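In plain terms, the layer above appends a feature map holding the average per-group standard deviation, which lets the discriminator notice a collapsed batch. A toy illustration with hypothetical shapes and a single group:

import torch

x = torch.randn(4, 16, 8, 8)                      # (N, C, H, W)
y = x - x.mean(dim=0, keepdim=True)               # deviation within the group
std = (y.pow(2).mean(dim=0) + 1e-8).sqrt()        # per-feature std over the batch
stat = std.mean()                                  # one scalar summary
extra = stat * torch.ones(4, 1, 8, 8)              # broadcast as an extra channel
out = torch.cat([x, extra], dim=1)                 # (4, 17, 8, 8)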
| 257 |
+
class DownsamplingLayer(nn.Module):
|
| 258 |
+
"""Implements the downsampling layer.
|
| 259 |
+
|
| 260 |
+
Basically, this layer can be used to downsample feature maps with average
|
| 261 |
+
pooling.
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
def __init__(self, scale_factor=2):
|
| 265 |
+
super().__init__()
|
| 266 |
+
self.scale_factor = scale_factor
|
| 267 |
+
|
| 268 |
+
def forward(self, x):
|
| 269 |
+
if self.scale_factor <= 1:
|
| 270 |
+
return x
|
| 271 |
+
return F.avg_pool2d(x,
|
| 272 |
+
kernel_size=self.scale_factor,
|
| 273 |
+
stride=self.scale_factor,
|
| 274 |
+
padding=0)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class Blur(torch.autograd.Function):
|
| 278 |
+
"""Defines blur operation with customized gradient computation."""
|
| 279 |
+
|
| 280 |
+
@staticmethod
|
| 281 |
+
def forward(ctx, x, kernel):
|
| 282 |
+
ctx.save_for_backward(kernel)
|
| 283 |
+
y = F.conv2d(input=x,
|
| 284 |
+
weight=kernel,
|
| 285 |
+
bias=None,
|
| 286 |
+
stride=1,
|
| 287 |
+
padding=1,
|
| 288 |
+
groups=x.shape[1])
|
| 289 |
+
return y
|
| 290 |
+
|
| 291 |
+
@staticmethod
|
| 292 |
+
def backward(ctx, dy):
|
| 293 |
+
kernel, = ctx.saved_tensors
|
| 294 |
+
dx = BlurBackPropagation.apply(dy, kernel)
|
| 295 |
+
return dx, None, None
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class BlurBackPropagation(torch.autograd.Function):
|
| 299 |
+
"""Defines the back propagation of blur operation.
|
| 300 |
+
|
| 301 |
+
NOTE: This is used to speed up the backward of gradient penalty.
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
@staticmethod
|
| 305 |
+
def forward(ctx, dy, kernel):
|
| 306 |
+
ctx.save_for_backward(kernel)
|
| 307 |
+
dx = F.conv2d(input=dy,
|
| 308 |
+
weight=kernel.flip((2, 3)),
|
| 309 |
+
bias=None,
|
| 310 |
+
stride=1,
|
| 311 |
+
padding=1,
|
| 312 |
+
groups=dy.shape[1])
|
| 313 |
+
return dx
|
| 314 |
+
|
| 315 |
+
@staticmethod
|
| 316 |
+
def backward(ctx, ddx):
|
| 317 |
+
kernel, = ctx.saved_tensors
|
| 318 |
+
ddy = F.conv2d(input=ddx,
|
| 319 |
+
weight=kernel,
|
| 320 |
+
bias=None,
|
| 321 |
+
stride=1,
|
| 322 |
+
padding=1,
|
| 323 |
+
groups=ddx.shape[1])
|
| 324 |
+
return ddy, None, None
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class BlurLayer(nn.Module):
|
| 328 |
+
"""Implements the blur layer."""
|
| 329 |
+
|
| 330 |
+
def __init__(self,
|
| 331 |
+
channels,
|
| 332 |
+
kernel=(1, 2, 1),
|
| 333 |
+
normalize=True):
|
| 334 |
+
super().__init__()
|
| 335 |
+
kernel = np.array(kernel, dtype=np.float32).reshape(1, -1)
|
| 336 |
+
kernel = kernel.T.dot(kernel)
|
| 337 |
+
if normalize:
|
| 338 |
+
kernel = kernel / np.sum(kernel)
|
| 339 |
+
kernel = kernel[np.newaxis, np.newaxis]
|
| 340 |
+
kernel = np.tile(kernel, [channels, 1, 1, 1])
|
| 341 |
+
self.register_buffer('kernel', torch.from_numpy(kernel))
|
| 342 |
+
|
| 343 |
+
def forward(self, x):
|
| 344 |
+
return Blur.apply(x, self.kernel)
|
| 345 |
+
|
| 346 |
+
|
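For reference, the (1, 2, 1) kernel above becomes a normalized 3x3 binomial filter applied depthwise; a short sketch of that construction, outside of any module:

import numpy as np

k = np.array([1., 2., 1.], dtype=np.float32).reshape(1, -1)
k2d = k.T.dot(k)              # 3x3 outer product: [[1,2,1],[2,4,2],[1,2,1]]
k2d /= k2d.sum()              # normalize so the blur preserves overall brightness
# Tiled to (channels, 1, 3, 3) and used with groups=channels, it blurs each
# feature map independently.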
| 347 |
+
class ConvBlock(nn.Module):
|
| 348 |
+
"""Implements the convolutional block.
|
| 349 |
+
|
| 350 |
+
Basically, this block executes minibatch standard deviation layer (if
|
| 351 |
+
needed), convolutional layer, activation layer, and downsampling layer (
|
| 352 |
+
if needed) in sequence.
|
| 353 |
+
"""
|
| 354 |
+
|
| 355 |
+
def __init__(self,
|
| 356 |
+
in_channels,
|
| 357 |
+
out_channels,
|
| 358 |
+
kernel_size=3,
|
| 359 |
+
stride=1,
|
| 360 |
+
padding=1,
|
| 361 |
+
add_bias=True,
|
| 362 |
+
downsample=False,
|
| 363 |
+
fused_scale=False,
|
| 364 |
+
use_wscale=True,
|
| 365 |
+
wscale_gain=_WSCALE_GAIN,
|
| 366 |
+
lr_mul=1.0,
|
| 367 |
+
activation_type='lrelu',
|
| 368 |
+
minibatch_std_group_size=0,
|
| 369 |
+
minibatch_std_channels=1):
|
| 370 |
+
"""Initializes with block settings.
|
| 371 |
+
|
| 372 |
+
Args:
|
| 373 |
+
in_channels: Number of channels of the input tensor.
|
| 374 |
+
out_channels: Number of channels of the output tensor.
|
| 375 |
+
kernel_size: Size of the convolutional kernels. (default: 3)
|
| 376 |
+
stride: Stride parameter for convolution operation. (default: 1)
|
| 377 |
+
padding: Padding parameter for convolution operation. (default: 1)
|
| 378 |
+
add_bias: Whether to add bias onto the convolutional result.
|
| 379 |
+
(default: True)
|
| 380 |
+
downsample: Whether to downsample the result after convolution.
|
| 381 |
+
(default: False)
|
| 382 |
+
fused_scale: Whether to fuse `conv2d` and `downsample` together,
|
| 383 |
+
resulting in `conv2d` with strides. (default: False)
|
| 384 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
| 385 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
| 386 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
| 387 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
| 388 |
+
(default: `lrelu`)
|
| 389 |
+
minibatch_std_group_size: Group size for the minibatch standard
|
| 390 |
+
deviation layer. 0 means disable. (default: 0)
|
| 391 |
+
minibatch_std_channels: Number of new channels after the minibatch
|
| 392 |
+
standard deviation layer. (default: 1)
|
| 393 |
+
|
| 394 |
+
Raises:
|
| 395 |
+
NotImplementedError: If the `activation_type` is not supported.
|
| 396 |
+
"""
|
| 397 |
+
super().__init__()
|
| 398 |
+
|
| 399 |
+
if minibatch_std_group_size > 1:
|
| 400 |
+
in_channels = in_channels + minibatch_std_channels
|
| 401 |
+
self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size,
|
| 402 |
+
new_channels=minibatch_std_channels)
|
| 403 |
+
else:
|
| 404 |
+
self.mbstd = nn.Identity()
|
| 405 |
+
|
| 406 |
+
if downsample:
|
| 407 |
+
self.blur = BlurLayer(channels=in_channels)
|
| 408 |
+
else:
|
| 409 |
+
self.blur = nn.Identity()
|
| 410 |
+
|
| 411 |
+
if downsample and not fused_scale:
|
| 412 |
+
self.downsample = DownsamplingLayer()
|
| 413 |
+
else:
|
| 414 |
+
self.downsample = nn.Identity()
|
| 415 |
+
|
| 416 |
+
if downsample and fused_scale:
|
| 417 |
+
self.use_stride = True
|
| 418 |
+
self.stride = 2
|
| 419 |
+
self.padding = 1
|
| 420 |
+
else:
|
| 421 |
+
self.use_stride = False
|
| 422 |
+
self.stride = stride
|
| 423 |
+
self.padding = padding
|
| 424 |
+
|
| 425 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
| 426 |
+
fan_in = kernel_size * kernel_size * in_channels
|
| 427 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
| 428 |
+
if use_wscale:
|
| 429 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
| 430 |
+
self.wscale = wscale * lr_mul
|
| 431 |
+
else:
|
| 432 |
+
self.weight = nn.Parameter(
|
| 433 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
| 434 |
+
self.wscale = lr_mul
|
| 435 |
+
|
| 436 |
+
if add_bias:
|
| 437 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
| 438 |
+
self.bscale = lr_mul
|
| 439 |
+
else:
|
| 440 |
+
self.bias = None
|
| 441 |
+
|
| 442 |
+
if activation_type == 'linear':
|
| 443 |
+
self.activate = nn.Identity()
|
| 444 |
+
elif activation_type == 'lrelu':
|
| 445 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 446 |
+
else:
|
| 447 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
| 448 |
+
f'`{activation_type}`!')
|
| 449 |
+
|
| 450 |
+
def forward(self, x):
|
| 451 |
+
x = self.mbstd(x)
|
| 452 |
+
x = self.blur(x)
|
| 453 |
+
weight = self.weight * self.wscale
|
| 454 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
| 455 |
+
if self.use_stride:
|
| 456 |
+
weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
|
| 457 |
+
weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
|
| 458 |
+
weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
|
| 459 |
+
x = F.conv2d(x,
|
| 460 |
+
weight=weight,
|
| 461 |
+
bias=bias,
|
| 462 |
+
stride=self.stride,
|
| 463 |
+
padding=self.padding)
|
| 464 |
+
x = self.downsample(x)
|
| 465 |
+
x = self.activate(x)
|
| 466 |
+
return x
|
| 467 |
+
|
| 468 |
+
|
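The `use_stride` branch above folds the 2x downsampling into the convolution: padding the 3x3 kernel and averaging its four shifted copies gives a 4x4 kernel whose stride-2 convolution matches conv followed by 2x2 average pooling, up to border handling. A sketch of just that weight transform, with illustrative shapes:

import torch
import torch.nn.functional as F

weight = torch.randn(8, 4, 3, 3)                        # (out, in, 3, 3)
w = F.pad(weight, (1, 1, 1, 1), 'constant', 0.0)        # (out, in, 5, 5)
w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] +
     w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25       # (out, in, 4, 4)
# A conv2d with this `w`, stride=2 and padding=1 replaces conv + downsample.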
| 469 |
+
class DenseBlock(nn.Module):
|
| 470 |
+
"""Implements the dense block.
|
| 471 |
+
|
| 472 |
+
Basically, this block executes a fully-connected layer and an activation layer.
|
| 473 |
+
"""
|
| 474 |
+
|
| 475 |
+
def __init__(self,
|
| 476 |
+
in_channels,
|
| 477 |
+
out_channels,
|
| 478 |
+
add_bias=True,
|
| 479 |
+
use_wscale=True,
|
| 480 |
+
wscale_gain=_WSCALE_GAIN,
|
| 481 |
+
lr_mul=1.0,
|
| 482 |
+
activation_type='lrelu'):
|
| 483 |
+
"""Initializes with block settings.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
in_channels: Number of channels of the input tensor.
|
| 487 |
+
out_channels: Number of channels of the output tensor.
|
| 488 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
| 489 |
+
(default: True)
|
| 490 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
| 491 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
| 492 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
| 493 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
| 494 |
+
(default: `lrelu`)
|
| 495 |
+
|
| 496 |
+
Raises:
|
| 497 |
+
NotImplementedError: If the `activation_type` is not supported.
|
| 498 |
+
"""
|
| 499 |
+
super().__init__()
|
| 500 |
+
weight_shape = (out_channels, in_channels)
|
| 501 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
| 502 |
+
if use_wscale:
|
| 503 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
| 504 |
+
self.wscale = wscale * lr_mul
|
| 505 |
+
else:
|
| 506 |
+
self.weight = nn.Parameter(
|
| 507 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
| 508 |
+
self.wscale = lr_mul
|
| 509 |
+
|
| 510 |
+
if add_bias:
|
| 511 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
| 512 |
+
self.bscale = lr_mul
|
| 513 |
+
else:
|
| 514 |
+
self.bias = None
|
| 515 |
+
|
| 516 |
+
if activation_type == 'linear':
|
| 517 |
+
self.activate = nn.Identity()
|
| 518 |
+
elif activation_type == 'lrelu':
|
| 519 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 520 |
+
else:
|
| 521 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
| 522 |
+
f'`{activation_type}`!')
|
| 523 |
+
|
| 524 |
+
def forward(self, x):
|
| 525 |
+
if x.ndim != 2:
|
| 526 |
+
x = x.view(x.shape[0], -1)
|
| 527 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
| 528 |
+
x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
|
| 529 |
+
x = self.activate(x)
|
| 530 |
+
return x
|
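Before the generator file: the `get_nf` rule used in both networks, `min(fmaps_base // res, fmaps_max)`, yields the usual StyleGAN channel schedule. A quick way to print it for the defaults used here:

fmaps_base, fmaps_max = 16 << 10, 512
for res in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
    print(res, min(fmaps_base // res, fmaps_max))
# 4..32 -> 512, 64 -> 256, 128 -> 128, 256 -> 64, 512 -> 32, 1024 -> 16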
models/stylegan_generator.py
ADDED
|
@@ -0,0 +1,869 @@
|
| 1 |
+
# python3.7
|
| 2 |
+
"""Contains the implementation of generator described in StyleGAN.
|
| 3 |
+
|
| 4 |
+
Paper: https://arxiv.org/pdf/1812.04948.pdf
|
| 5 |
+
|
| 6 |
+
Official TensorFlow implementation: https://github.com/NVlabs/stylegan
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from .sync_op import all_gather
|
| 16 |
+
|
| 17 |
+
__all__ = ['StyleGANGenerator']
|
| 18 |
+
|
| 19 |
+
# Resolutions allowed.
|
| 20 |
+
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
|
| 21 |
+
|
| 22 |
+
# Initial resolution.
|
| 23 |
+
_INIT_RES = 4
|
| 24 |
+
|
| 25 |
+
# Fused-scale options allowed.
|
| 26 |
+
_FUSED_SCALE_ALLOWED = [True, False, 'auto']
|
| 27 |
+
|
| 28 |
+
# Minimal resolution for `auto` fused-scale strategy.
|
| 29 |
+
_AUTO_FUSED_SCALE_MIN_RES = 128
|
| 30 |
+
|
| 31 |
+
# Default gain factor for weight scaling.
|
| 32 |
+
_WSCALE_GAIN = np.sqrt(2.0)
|
| 33 |
+
_STYLEMOD_WSCALE_GAIN = 1.0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class StyleGANGenerator(nn.Module):
|
| 37 |
+
"""Defines the generator network in StyleGAN.
|
| 38 |
+
|
| 39 |
+
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
| 40 |
+
[-1, 1].
|
| 41 |
+
|
| 42 |
+
Settings for the mapping network:
|
| 43 |
+
|
| 44 |
+
(1) z_space_dim: Dimension of the input latent space, Z. (default: 512)
|
| 45 |
+
(2) w_space_dim: Dimension of the output latent space, W. (default: 512)
|
| 46 |
+
(3) label_size: Size of the additional label for conditional generation.
|
| 47 |
+
(default: 0)
|
| 48 |
+
(4) mapping_layers: Number of layers of the mapping network. (default: 8)
|
| 49 |
+
(5) mapping_fmaps: Number of hidden channels of the mapping network.
|
| 50 |
+
(default: 512)
|
| 51 |
+
(6) mapping_lr_mul: Learning rate multiplier for the mapping network.
|
| 52 |
+
(default: 0.01)
|
| 53 |
+
(7) repeat_w: Whether to repeat the w-code for different layers. (default: True)
|
| 54 |
+
|
| 55 |
+
Settings for the synthesis network:
|
| 56 |
+
|
| 57 |
+
(1) resolution: The resolution of the output image.
|
| 58 |
+
(2) image_channels: Number of channels of the output image. (default: 3)
|
| 59 |
+
(3) final_tanh: Whether to use `tanh` to control the final pixel range.
|
| 60 |
+
(default: False)
|
| 61 |
+
(4) const_input: Whether to use a constant in the first convolutional layer.
|
| 62 |
+
(default: True)
|
| 63 |
+
(5) fused_scale: Whether to fuse `upsample` and `conv2d` together,
|
| 64 |
+
resulting in `conv2d_transpose`. (default: `auto`)
|
| 65 |
+
(6) use_wscale: Whether to use weight scaling. (default: True)
|
| 66 |
+
(7) fmaps_base: Factor to control number of feature maps for each layer.
|
| 67 |
+
(default: 16 << 10)
|
| 68 |
+
(8) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self,
|
| 72 |
+
resolution,
|
| 73 |
+
z_space_dim=512,
|
| 74 |
+
w_space_dim=512,
|
| 75 |
+
label_size=0,
|
| 76 |
+
mapping_layers=8,
|
| 77 |
+
mapping_fmaps=512,
|
| 78 |
+
mapping_lr_mul=0.01,
|
| 79 |
+
repeat_w=True,
|
| 80 |
+
image_channels=3,
|
| 81 |
+
final_tanh=False,
|
| 82 |
+
const_input=True,
|
| 83 |
+
fused_scale='auto',
|
| 84 |
+
use_wscale=True,
|
| 85 |
+
fmaps_base=16 << 10,
|
| 86 |
+
fmaps_max=512):
|
| 87 |
+
"""Initializes with basic settings.
|
| 88 |
+
|
| 89 |
+
Raises:
|
| 90 |
+
ValueError: If the `resolution` is not supported, or `fused_scale`
|
| 91 |
+
is not supported.
|
| 92 |
+
"""
|
| 93 |
+
super().__init__()
|
| 94 |
+
|
| 95 |
+
if resolution not in _RESOLUTIONS_ALLOWED:
|
| 96 |
+
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
|
| 97 |
+
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
|
| 98 |
+
if fused_scale not in _FUSED_SCALE_ALLOWED:
|
| 99 |
+
raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
|
| 100 |
+
f'Options allowed: {_FUSED_SCALE_ALLOWED}.')
|
| 101 |
+
|
| 102 |
+
self.init_res = _INIT_RES
|
| 103 |
+
self.resolution = resolution
|
| 104 |
+
self.z_space_dim = z_space_dim
|
| 105 |
+
self.w_space_dim = w_space_dim
|
| 106 |
+
self.label_size = label_size
|
| 107 |
+
self.mapping_layers = mapping_layers
|
| 108 |
+
self.mapping_fmaps = mapping_fmaps
|
| 109 |
+
self.mapping_lr_mul = mapping_lr_mul
|
| 110 |
+
self.repeat_w = repeat_w
|
| 111 |
+
self.image_channels = image_channels
|
| 112 |
+
self.final_tanh = final_tanh
|
| 113 |
+
self.const_input = const_input
|
| 114 |
+
self.fused_scale = fused_scale
|
| 115 |
+
self.use_wscale = use_wscale
|
| 116 |
+
self.fmaps_base = fmaps_base
|
| 117 |
+
self.fmaps_max = fmaps_max
|
| 118 |
+
|
| 119 |
+
self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
|
| 120 |
+
|
| 121 |
+
if self.repeat_w:
|
| 122 |
+
self.mapping_space_dim = self.w_space_dim
|
| 123 |
+
else:
|
| 124 |
+
self.mapping_space_dim = self.w_space_dim * self.num_layers
|
| 125 |
+
self.mapping = MappingModule(input_space_dim=self.z_space_dim,
|
| 126 |
+
hidden_space_dim=self.mapping_fmaps,
|
| 127 |
+
final_space_dim=self.mapping_space_dim,
|
| 128 |
+
label_size=self.label_size,
|
| 129 |
+
num_layers=self.mapping_layers,
|
| 130 |
+
use_wscale=self.use_wscale,
|
| 131 |
+
lr_mul=self.mapping_lr_mul)
|
| 132 |
+
|
| 133 |
+
self.truncation = TruncationModule(w_space_dim=self.w_space_dim,
|
| 134 |
+
num_layers=self.num_layers,
|
| 135 |
+
repeat_w=self.repeat_w)
|
| 136 |
+
|
| 137 |
+
self.synthesis = SynthesisModule(resolution=self.resolution,
|
| 138 |
+
init_resolution=self.init_res,
|
| 139 |
+
w_space_dim=self.w_space_dim,
|
| 140 |
+
image_channels=self.image_channels,
|
| 141 |
+
final_tanh=self.final_tanh,
|
| 142 |
+
const_input=self.const_input,
|
| 143 |
+
fused_scale=self.fused_scale,
|
| 144 |
+
use_wscale=self.use_wscale,
|
| 145 |
+
fmaps_base=self.fmaps_base,
|
| 146 |
+
fmaps_max=self.fmaps_max)
|
| 147 |
+
|
| 148 |
+
self.pth_to_tf_var_mapping = {}
|
| 149 |
+
for key, val in self.mapping.pth_to_tf_var_mapping.items():
|
| 150 |
+
self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
|
| 151 |
+
for key, val in self.truncation.pth_to_tf_var_mapping.items():
|
| 152 |
+
self.pth_to_tf_var_mapping[f'truncation.{key}'] = val
|
| 153 |
+
for key, val in self.synthesis.pth_to_tf_var_mapping.items():
|
| 154 |
+
self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val
|
| 155 |
+
|
| 156 |
+
def forward(self,
|
| 157 |
+
z,
|
| 158 |
+
label=None,
|
| 159 |
+
lod=None,
|
| 160 |
+
w_moving_decay=0.995,
|
| 161 |
+
style_mixing_prob=0.9,
|
| 162 |
+
trunc_psi=None,
|
| 163 |
+
trunc_layers=None,
|
| 164 |
+
randomize_noise=False,
|
| 165 |
+
**_unused_kwargs):
|
| 166 |
+
mapping_results = self.mapping(z, label)
|
| 167 |
+
w = mapping_results['w']
|
| 168 |
+
|
| 169 |
+
if self.training and w_moving_decay < 1:
|
| 170 |
+
batch_w_avg = all_gather(w).mean(dim=0)
|
| 171 |
+
self.truncation.w_avg.copy_(
|
| 172 |
+
self.truncation.w_avg * w_moving_decay +
|
| 173 |
+
batch_w_avg * (1 - w_moving_decay))
|
| 174 |
+
|
| 175 |
+
if self.training and style_mixing_prob > 0:
|
| 176 |
+
new_z = torch.randn_like(z)
|
| 177 |
+
new_w = self.mapping(new_z, label)['w']
|
| 178 |
+
lod = self.synthesis.lod.cpu().tolist() if lod is None else lod
|
| 179 |
+
current_layers = self.num_layers - int(lod) * 2
|
| 180 |
+
if np.random.uniform() < style_mixing_prob:
|
| 181 |
+
mixing_cutoff = np.random.randint(1, current_layers)
|
| 182 |
+
w = self.truncation(w)
|
| 183 |
+
new_w = self.truncation(new_w)
|
| 184 |
+
w[:, mixing_cutoff:] = new_w[:, mixing_cutoff:]
|
| 185 |
+
|
| 186 |
+
wp = self.truncation(w, trunc_psi, trunc_layers)
|
| 187 |
+
synthesis_results = self.synthesis(wp, lod, randomize_noise)
|
| 188 |
+
|
| 189 |
+
return {**mapping_results, **synthesis_results}
|
| 190 |
+
|
| 191 |
+
|
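A minimal sketch of how the three sub-modules wired together above are used, assuming the classes defined below in this file are importable as written:

import torch

G = StyleGANGenerator(resolution=256).eval()
z = torch.randn(4, 512)
w = G.mapping(z)['w']                                  # Z -> W
wp = G.truncation(w, trunc_psi=0.7, trunc_layers=8)    # layer-wise, truncated codes
print(wp.shape)                                        # torch.Size([4, 14, 512])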
| 192 |
+
class MappingModule(nn.Module):
|
| 193 |
+
"""Implements the latent space mapping module.
|
| 194 |
+
|
| 195 |
+
Basically, this module executes several dense layers in sequence.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
def __init__(self,
|
| 199 |
+
input_space_dim=512,
|
| 200 |
+
hidden_space_dim=512,
|
| 201 |
+
final_space_dim=512,
|
| 202 |
+
label_size=0,
|
| 203 |
+
num_layers=8,
|
| 204 |
+
normalize_input=True,
|
| 205 |
+
use_wscale=True,
|
| 206 |
+
lr_mul=0.01):
|
| 207 |
+
super().__init__()
|
| 208 |
+
|
| 209 |
+
self.input_space_dim = input_space_dim
|
| 210 |
+
self.hidden_space_dim = hidden_space_dim
|
| 211 |
+
self.final_space_dim = final_space_dim
|
| 212 |
+
self.label_size = label_size
|
| 213 |
+
self.num_layers = num_layers
|
| 214 |
+
self.normalize_input = normalize_input
|
| 215 |
+
self.use_wscale = use_wscale
|
| 216 |
+
self.lr_mul = lr_mul
|
| 217 |
+
|
| 218 |
+
self.norm = PixelNormLayer() if self.normalize_input else nn.Identity()
|
| 219 |
+
|
| 220 |
+
self.pth_to_tf_var_mapping = {}
|
| 221 |
+
for i in range(num_layers):
|
| 222 |
+
dim_mul = 2 if label_size else 1
|
| 223 |
+
in_channels = (input_space_dim * dim_mul if i == 0 else
|
| 224 |
+
hidden_space_dim)
|
| 225 |
+
out_channels = (final_space_dim if i == (num_layers - 1) else
|
| 226 |
+
hidden_space_dim)
|
| 227 |
+
self.add_module(f'dense{i}',
|
| 228 |
+
DenseBlock(in_channels=in_channels,
|
| 229 |
+
out_channels=out_channels,
|
| 230 |
+
use_wscale=self.use_wscale,
|
| 231 |
+
lr_mul=self.lr_mul))
|
| 232 |
+
self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
|
| 233 |
+
self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
|
| 234 |
+
if label_size:
|
| 235 |
+
self.label_weight = nn.Parameter(
|
| 236 |
+
torch.randn(label_size, input_space_dim))
|
| 237 |
+
self.pth_to_tf_var_mapping[f'label_weight'] = f'LabelConcat/weight'
|
| 238 |
+
|
| 239 |
+
def forward(self, z, label=None):
|
| 240 |
+
if z.ndim != 2 or z.shape[1] != self.input_space_dim:
|
| 241 |
+
raise ValueError(f'Input latent code should be with shape '
|
| 242 |
+
f'[batch_size, input_dim], where '
|
| 243 |
+
f'`input_dim` equals to {self.input_space_dim}!\n'
|
| 244 |
+
f'But `{z.shape}` is received!')
|
| 245 |
+
if self.label_size:
|
| 246 |
+
if label is None:
|
| 247 |
+
raise ValueError(f'Model requires an additional label '
|
| 248 |
+
f'(with size {self.label_size}) as input, '
|
| 249 |
+
f'but no label is received!')
|
| 250 |
+
if label.ndim != 2 or label.shape != (z.shape[0], self.label_size):
|
| 251 |
+
raise ValueError(f'Input label should be with shape '
|
| 252 |
+
f'[batch_size, label_size], where '
|
| 253 |
+
f'`batch_size` equals to that of '
|
| 254 |
+
f'latent codes ({z.shape[0]}) and '
|
| 255 |
+
f'`label_size` equals to {self.label_size}!\n'
|
| 256 |
+
f'But `{label.shape}` is received!')
|
| 257 |
+
embedding = torch.matmul(label, self.label_weight)
|
| 258 |
+
z = torch.cat((z, embedding), dim=1)
|
| 259 |
+
|
| 260 |
+
z = self.norm(z)
|
| 261 |
+
w = z
|
| 262 |
+
for i in range(self.num_layers):
|
| 263 |
+
w = self.__getattr__(f'dense{i}')(w)
|
| 264 |
+
results = {
|
| 265 |
+
'z': z,
|
| 266 |
+
'label': label,
|
| 267 |
+
'w': w,
|
| 268 |
+
}
|
| 269 |
+
if self.label_size:
|
| 270 |
+
results['embedding'] = embedding
|
| 271 |
+
return results
|
| 272 |
+
|
| 273 |
+
|
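For the conditional path above, the label is first embedded with `label_weight` and concatenated to z before normalization. A short sketch with a hypothetical 10-class setup, assuming the helper layers defined later in this file (e.g. `PixelNormLayer`) are available:

import torch

mapping = MappingModule(input_space_dim=512, label_size=10)
z = torch.randn(4, 512)
label = torch.eye(10)[torch.randint(0, 10, (4,))]   # one-hot labels, shape (4, 10)
w = mapping(z, label)['w']                           # embedding is concatenated to z
print(w.shape)                                       # torch.Size([4, 512])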
| 274 |
+
class TruncationModule(nn.Module):
|
| 275 |
+
"""Implements the truncation module.
|
| 276 |
+
|
| 277 |
+
Truncation is executed as follows:
|
| 278 |
+
|
| 279 |
+
For layers in range [0, truncation_layers), the truncated w-code is computed
|
| 280 |
+
as
|
| 281 |
+
|
| 282 |
+
w_new = w_avg + (w - w_avg) * truncation_psi
|
| 283 |
+
|
| 284 |
+
To disable truncation, please set
|
| 285 |
+
(1) truncation_psi = 1.0 (None) OR
|
| 286 |
+
(2) truncation_layers = 0 (None)
|
| 287 |
+
|
| 288 |
+
NOTE: The returned tensor is layer-wise style codes.
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
def __init__(self, w_space_dim, num_layers, repeat_w=True):
|
| 292 |
+
super().__init__()
|
| 293 |
+
|
| 294 |
+
self.num_layers = num_layers
|
| 295 |
+
self.w_space_dim = w_space_dim
|
| 296 |
+
self.repeat_w = repeat_w
|
| 297 |
+
|
| 298 |
+
if self.repeat_w:
|
| 299 |
+
self.register_buffer('w_avg', torch.zeros(w_space_dim))
|
| 300 |
+
else:
|
| 301 |
+
self.register_buffer('w_avg', torch.zeros(num_layers * w_space_dim))
|
| 302 |
+
self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
|
| 303 |
+
|
| 304 |
+
def forward(self, w, trunc_psi=None, trunc_layers=None):
|
| 305 |
+
if w.ndim == 2:
|
| 306 |
+
if self.repeat_w and w.shape[1] == self.w_space_dim:
|
| 307 |
+
w = w.view(-1, 1, self.w_space_dim)
|
| 308 |
+
wp = w.repeat(1, self.num_layers, 1)
|
| 309 |
+
else:
|
| 310 |
+
assert w.shape[1] == self.w_space_dim * self.num_layers
|
| 311 |
+
wp = w.view(-1, self.num_layers, self.w_space_dim)
|
| 312 |
+
else:
|
| 313 |
+
wp = w
|
| 314 |
+
assert wp.ndim == 3
|
| 315 |
+
assert wp.shape[1:] == (self.num_layers, self.w_space_dim)
|
| 316 |
+
|
| 317 |
+
trunc_psi = 1.0 if trunc_psi is None else trunc_psi
|
| 318 |
+
trunc_layers = 0 if trunc_layers is None else trunc_layers
|
| 319 |
+
if trunc_psi < 1.0 and trunc_layers > 0:
|
| 320 |
+
layer_idx = np.arange(self.num_layers).reshape(1, -1, 1)
|
| 321 |
+
coefs = np.ones_like(layer_idx, dtype=np.float32)
|
| 322 |
+
coefs[layer_idx < trunc_layers] *= trunc_psi
|
| 323 |
+
coefs = torch.from_numpy(coefs).to(wp)
|
| 324 |
+
w_avg = self.w_avg.view(1, -1, self.w_space_dim)
|
| 325 |
+
wp = w_avg + (wp - w_avg) * coefs
|
| 326 |
+
return wp
|
| 327 |
+
|
| 328 |
+
|
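A numeric check of the truncation rule documented above, w_new = w_avg + (w - w_avg) * truncation_psi, applied only to the first `trunc_layers` codes; the numbers are made up:

import torch

w_avg = torch.zeros(512)
wp = torch.full((1, 14, 512), 2.0)                   # 14 layer-wise codes, all equal to 2
psi, trunc_layers = 0.5, 8
wp[:, :trunc_layers] = w_avg + (wp[:, :trunc_layers] - w_avg) * psi
print(wp[0, 0, 0].item(), wp[0, 13, 0].item())       # 1.0 2.0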
| 329 |
+
class SynthesisModule(nn.Module):
|
| 330 |
+
"""Implements the image synthesis module.
|
| 331 |
+
|
| 332 |
+
Basically, this module executes several convolutional layers in sequence.
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
def __init__(self,
|
| 336 |
+
resolution=1024,
|
| 337 |
+
init_resolution=4,
|
| 338 |
+
w_space_dim=512,
|
| 339 |
+
image_channels=3,
|
| 340 |
+
final_tanh=False,
|
| 341 |
+
const_input=True,
|
| 342 |
+
fused_scale='auto',
|
| 343 |
+
use_wscale=True,
|
| 344 |
+
fmaps_base=16 << 10,
|
| 345 |
+
fmaps_max=512):
|
| 346 |
+
super().__init__()
|
| 347 |
+
|
| 348 |
+
self.init_res = init_resolution
|
| 349 |
+
self.init_res_log2 = int(np.log2(self.init_res))
|
| 350 |
+
self.resolution = resolution
|
| 351 |
+
self.final_res_log2 = int(np.log2(self.resolution))
|
        self.w_space_dim = w_space_dim
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.const_input = const_input
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max

        self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2

        # Level of detail (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            res = 2 ** res_log2
            block_idx = res_log2 - self.init_res_log2

            # First convolution layer for each resolution.
            layer_name = f'layer{2 * block_idx}'
            if res == self.init_res:
                if self.const_input:
                    self.add_module(layer_name,
                                    ConvBlock(in_channels=self.get_nf(res),
                                              out_channels=self.get_nf(res),
                                              resolution=self.init_res,
                                              w_space_dim=self.w_space_dim,
                                              position='const_init',
                                              use_wscale=self.use_wscale))
                    tf_layer_name = 'Const'
                    self.pth_to_tf_var_mapping[f'{layer_name}.const'] = (
                        f'{res}x{res}/{tf_layer_name}/const')
                else:
                    self.add_module(layer_name,
                                    ConvBlock(in_channels=self.w_space_dim,
                                              out_channels=self.get_nf(res),
                                              resolution=self.init_res,
                                              w_space_dim=self.w_space_dim,
                                              kernel_size=self.init_res,
                                              padding=self.init_res - 1,
                                              use_wscale=self.use_wscale))
                    tf_layer_name = 'Dense'
                    self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                        f'{res}x{res}/{tf_layer_name}/weight')
            else:
                if self.fused_scale == 'auto':
                    fused_scale = (res >= _AUTO_FUSED_SCALE_MIN_RES)
                else:
                    fused_scale = self.fused_scale
                self.add_module(layer_name,
                                ConvBlock(in_channels=self.get_nf(res // 2),
                                          out_channels=self.get_nf(res),
                                          resolution=res,
                                          w_space_dim=self.w_space_dim,
                                          upsample=True,
                                          fused_scale=fused_scale,
                                          use_wscale=self.use_wscale))
                tf_layer_name = 'Conv0_up'
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')
            self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
                f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
            self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
                f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
            self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.weight'] = (
                f'{res}x{res}/{tf_layer_name}/Noise/weight')
            self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.noise'] = (
                f'noise{2 * block_idx}')

            # Second convolution layer for each resolution.
            layer_name = f'layer{2 * block_idx + 1}'
            self.add_module(layer_name,
                            ConvBlock(in_channels=self.get_nf(res),
                                      out_channels=self.get_nf(res),
                                      resolution=res,
                                      w_space_dim=self.w_space_dim,
                                      use_wscale=self.use_wscale))
            tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
            self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')
            self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
                f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
            self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
                f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
            self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.weight'] = (
                f'{res}x{res}/{tf_layer_name}/Noise/weight')
            self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.noise'] = (
                f'noise{2 * block_idx + 1}')

            # Output convolution layer for each resolution.
            self.add_module(f'output{block_idx}',
                            ConvBlock(in_channels=self.get_nf(res),
                                      out_channels=self.image_channels,
                                      resolution=res,
                                      w_space_dim=self.w_space_dim,
                                      position='last',
                                      kernel_size=1,
                                      padding=0,
                                      use_wscale=self.use_wscale,
                                      wscale_gain=1.0,
                                      activation_type='linear'))
            self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
            self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')

        self.upsample = UpsamplingLayer()
        self.final_activate = nn.Tanh() if final_tanh else nn.Identity()

    def get_nf(self, res):
        """Gets number of feature maps according to current resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, wp, lod=None, randomize_noise=False):
        if wp.ndim != 3 or wp.shape[1:] != (self.num_layers, self.w_space_dim):
            raise ValueError(f'Input tensor should be with shape '
                             f'[batch_size, num_layers, w_space_dim], where '
                             f'`num_layers` equals to {self.num_layers}, and '
                             f'`w_space_dim` equals to {self.w_space_dim}!\n'
                             f'But `{wp.shape}` is received!')

        lod = self.lod.cpu().tolist() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-detail (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        results = {'wp': wp}
        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            current_lod = self.final_res_log2 - res_log2
            if lod < current_lod + 1:
                block_idx = res_log2 - self.init_res_log2
                if block_idx == 0:
                    if self.const_input:
                        x, style = self.layer0(None, wp[:, 0], randomize_noise)
                    else:
                        x = wp[:, 0].view(-1, self.w_space_dim, 1, 1)
                        x, style = self.layer0(x, wp[:, 0], randomize_noise)
                else:
                    x, style = self.__getattr__(f'layer{2 * block_idx}')(
                        x, wp[:, 2 * block_idx])
                results[f'style{2 * block_idx:02d}'] = style
                x, style = self.__getattr__(f'layer{2 * block_idx + 1}')(
                    x, wp[:, 2 * block_idx + 1])
                results[f'style{2 * block_idx + 1:02d}'] = style
            if current_lod - 1 < lod <= current_lod:
                image = self.__getattr__(f'output{block_idx}')(x, None)
            elif current_lod < lod < current_lod + 1:
                alpha = np.ceil(lod) - lod
                image = (self.__getattr__(f'output{block_idx}')(x, None) * alpha
                         + self.upsample(image) * (1 - alpha))
            elif lod >= current_lod + 1:
                image = self.upsample(image)
        results['image'] = self.final_activate(image)
        return results


class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer."""

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.eps = epsilon

    def forward(self, x):
        norm = torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)
        return x / norm


class InstanceNormLayer(nn.Module):
    """Implements instance normalization layer."""

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.eps = epsilon

    def forward(self, x):
        if x.ndim != 4:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], '
                             f'but `{x.shape}` is received!')
        x = x - torch.mean(x, dim=[2, 3], keepdim=True)
        norm = torch.sqrt(
            torch.mean(x ** 2, dim=[2, 3], keepdim=True) + self.eps)
        return x / norm


class UpsamplingLayer(nn.Module):
    """Implements the upsampling layer.

    Basically, this layer can be used to upsample feature maps with nearest
    neighbor interpolation.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')


class Blur(torch.autograd.Function):
    """Defines blur operation with customized gradient computation."""

    @staticmethod
    def forward(ctx, x, kernel):
        ctx.save_for_backward(kernel)
        y = F.conv2d(input=x,
                     weight=kernel,
                     bias=None,
                     stride=1,
                     padding=1,
                     groups=x.shape[1])
        return y

    @staticmethod
    def backward(ctx, dy):
        kernel, = ctx.saved_tensors
        dx = F.conv2d(input=dy,
                      weight=kernel.flip((2, 3)),
                      bias=None,
                      stride=1,
                      padding=1,
                      groups=dy.shape[1])
        return dx, None, None


class BlurLayer(nn.Module):
    """Implements the blur layer."""

    def __init__(self,
                 channels,
                 kernel=(1, 2, 1),
                 normalize=True):
        super().__init__()
        kernel = np.array(kernel, dtype=np.float32).reshape(1, -1)
        kernel = kernel.T.dot(kernel)
        if normalize:
            kernel /= np.sum(kernel)
        kernel = kernel[np.newaxis, np.newaxis]
        kernel = np.tile(kernel, [channels, 1, 1, 1])
        self.register_buffer('kernel', torch.from_numpy(kernel))

    def forward(self, x):
        return Blur.apply(x, self.kernel)


class NoiseApplyingLayer(nn.Module):
    """Implements the noise applying layer."""

    def __init__(self, resolution, channels):
        super().__init__()
        self.res = resolution
        self.register_buffer('noise', torch.randn(1, 1, self.res, self.res))
        self.weight = nn.Parameter(torch.zeros(channels))

    def forward(self, x, randomize_noise=False):
        if x.ndim != 4:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], '
                             f'but `{x.shape}` is received!')
        if randomize_noise:
            noise = torch.randn(x.shape[0], 1, self.res, self.res).to(x)
        else:
            noise = self.noise
        return x + noise * self.weight.view(1, -1, 1, 1)


class StyleModLayer(nn.Module):
    """Implements the style modulation layer."""

    def __init__(self,
                 w_space_dim,
                 out_channels,
                 use_wscale=True):
        super().__init__()
        self.w_space_dim = w_space_dim
        self.out_channels = out_channels

        weight_shape = (self.out_channels * 2, self.w_space_dim)
        wscale = _STYLEMOD_WSCALE_GAIN / np.sqrt(self.w_space_dim)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        self.bias = nn.Parameter(torch.zeros(self.out_channels * 2))

    def forward(self, x, w):
        if w.ndim != 2 or w.shape[1] != self.w_space_dim:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, w_space_dim], where '
                             f'`w_space_dim` equals to {self.w_space_dim}!\n'
                             f'But `{w.shape}` is received!')
        style = F.linear(w, weight=self.weight * self.wscale, bias=self.bias)
        style_split = style.view(-1, 2, self.out_channels, 1, 1)
        x = x * (style_split[:, 0] + 1) + style_split[:, 1]
        return x, style


class ConvBlock(nn.Module):
    """Implements the normal convolutional block.

    Basically, this block executes upsampling layer (if needed), convolutional
    layer, blurring layer, noise applying layer, activation layer, instance
    normalization layer, and style modulation layer in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 resolution,
                 w_space_dim,
                 position=None,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 add_bias=True,
                 upsample=False,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 lr_mul=1.0,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            resolution: Resolution of the output tensor.
            w_space_dim: Dimension of W space for style modulation.
            position: Position of the layer. `const_init`, `last` would lead to
                different behavior. (default: None)
            kernel_size: Size of the convolutional kernels. (default: 3)
            stride: Stride parameter for convolution operation. (default: 1)
            padding: Padding parameter for convolution operation. (default: 1)
            add_bias: Whether to add bias onto the convolutional result.
                (default: True)
            upsample: Whether to upsample the input tensor before convolution.
                (default: False)
            fused_scale: Whether to fused `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. (default: False)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()

        self.position = position

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

        if self.position != 'last':
            self.apply_noise = NoiseApplyingLayer(resolution, out_channels)
            self.normalize = InstanceNormLayer()
            self.style = StyleModLayer(w_space_dim, out_channels, use_wscale)

        if self.position == 'const_init':
            self.const = nn.Parameter(
                torch.ones(1, in_channels, resolution, resolution))
            return

        self.blur = BlurLayer(out_channels) if upsample else nn.Identity()

        if upsample and not fused_scale:
            self.upsample = UpsamplingLayer()
        else:
            self.upsample = nn.Identity()

        if upsample and fused_scale:
            self.use_conv2d_transpose = True
            self.stride = 2
            self.padding = 1
        else:
            self.use_conv2d_transpose = False
            self.stride = stride
            self.padding = padding

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(*weight_shape) * wscale / lr_mul)
            self.wscale = lr_mul

    def forward(self, x, w, randomize_noise=False):
        if self.position != 'const_init':
            x = self.upsample(x)
            weight = self.weight * self.wscale
            if self.use_conv2d_transpose:
                weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0)
                weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                          weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
                weight = weight.permute(1, 0, 2, 3)
                x = F.conv_transpose2d(x,
                                       weight=weight,
                                       bias=None,
                                       stride=self.stride,
                                       padding=self.padding)
            else:
                x = F.conv2d(x,
                             weight=weight,
                             bias=None,
                             stride=self.stride,
                             padding=self.padding)
            x = self.blur(x)
        else:
            x = self.const.repeat(w.shape[0], 1, 1, 1)

        bias = self.bias * self.bscale if self.bias is not None else None

        if self.position == 'last':
            if bias is not None:
                x = x + bias.view(1, -1, 1, 1)
            return x

        x = self.apply_noise(x, randomize_noise)
        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        x = self.activate(x)
        x = self.normalize(x)
        x, style = self.style(x, w)
        return x, style


class DenseBlock(nn.Module):
    """Implements the dense block.

    Basically, this block executes fully-connected layer and activation layer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 add_bias=True,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 lr_mul=1.0,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
                (default: True)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()
        weight_shape = (out_channels, in_channels)
        wscale = wscale_gain / np.sqrt(in_channels)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(*weight_shape) * wscale / lr_mul)
            self.wscale = lr_mul

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        if x.ndim != 2:
            x = x.view(x.shape[0], -1)
        bias = self.bias * self.bscale if self.bias is not None else None
        x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
        x = self.activate(x)
        return x
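For reference, the style modulation performed by `StyleModLayer` reduces to an affine map from the `w` code to a per-channel scale and shift. The snippet below is a minimal standalone sketch of that math (it ignores the weight-scaling factor, and all shapes are illustrative rather than taken from any particular model).

import torch
import torch.nn.functional as F

batch, w_dim, channels, res = 2, 512, 64, 8
w = torch.randn(batch, w_dim)
weight = torch.randn(channels * 2, w_dim)  # stands in for StyleModLayer.weight
bias = torch.zeros(channels * 2)

# Map w to per-channel (scale, shift), then apply x * (scale + 1) + shift,
# mirroring the forward pass of StyleModLayer above.
style = F.linear(w, weight=weight, bias=bias)
scale, shift = style.view(batch, 2, channels, 1, 1).unbind(dim=1)
x = torch.randn(batch, channels, res, res)
out = x * (scale + 1) + shift
print(out.shape)  # torch.Size([2, 64, 8, 8])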
models/sync_op.py
ADDED
@@ -0,0 +1,18 @@
# python3.7
"""Contains the synchronizing operator."""

import torch
import torch.distributed as dist

__all__ = ['all_gather']


def all_gather(tensor):
    """Gathers tensor from all devices and does averaging."""
    if not dist.is_initialized():
        return tensor

    world_size = dist.get_world_size()
    tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, async_op=False)
    return torch.mean(torch.stack(tensor_list, dim=0), dim=0)
sefa.py
ADDED
@@ -0,0 +1,145 @@
"""SeFa."""

import os
import argparse
from tqdm import tqdm
import numpy as np

import torch

from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight
from utils import HtmlPageVisualizer


def parse_args():
    """Parses arguments."""
    parser = argparse.ArgumentParser(
        description='Discover semantics from the pre-trained weight.')
    parser.add_argument('model_name', type=str,
                        help='Name to the pre-trained model.')
    parser.add_argument('--save_dir', type=str, default='results',
                        help='Directory to save the visualization pages. '
                             '(default: %(default)s)')
    parser.add_argument('-L', '--layer_idx', type=str, default='all',
                        help='Indices of layers to interpret. '
                             '(default: %(default)s)')
    parser.add_argument('-N', '--num_samples', type=int, default=5,
                        help='Number of samples used for visualization. '
                             '(default: %(default)s)')
    parser.add_argument('-K', '--num_semantics', type=int, default=5,
                        help='Number of semantic boundaries corresponding to '
                             'the top-k eigen values. (default: %(default)s)')
    parser.add_argument('--start_distance', type=float, default=-3.0,
                        help='Start point for manipulation on each semantic. '
                             '(default: %(default)s)')
    parser.add_argument('--end_distance', type=float, default=3.0,
                        help='Ending point for manipulation on each semantic. '
                             '(default: %(default)s)')
    parser.add_argument('--step', type=int, default=11,
                        help='Manipulation step on each semantic. '
                             '(default: %(default)s)')
    parser.add_argument('--viz_size', type=int, default=256,
                        help='Size of images to visualize on the HTML page. '
                             '(default: %(default)s)')
    parser.add_argument('--trunc_psi', type=float, default=0.7,
                        help='Psi factor used for truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--trunc_layers', type=int, default=8,
                        help='Number of layers to perform truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed for sampling. (default: %(default)s)')
    parser.add_argument('--gpu_id', type=str, default='0',
                        help='GPU(s) to use. (default: %(default)s)')
    return parser.parse_args()


def main():
    """Main function."""
    args = parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    os.makedirs(args.save_dir, exist_ok=True)

    # Factorize weights.
    generator = load_generator(args.model_name)
    gan_type = parse_gan_type(generator)
    layers, boundaries, values = factorize_weight(generator, args.layer_idx)

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Prepare codes.
    codes = torch.randn(args.num_samples, generator.z_space_dim).cuda()
    if gan_type == 'pggan':
        codes = generator.layer0.pixel_norm(codes)
    elif gan_type in ['stylegan', 'stylegan2']:
        codes = generator.mapping(codes)['w']
        codes = generator.truncation(codes,
                                     trunc_psi=args.trunc_psi,
                                     trunc_layers=args.trunc_layers)
    codes = codes.detach().cpu().numpy()

    # Generate visualization pages.
    distances = np.linspace(args.start_distance, args.end_distance, args.step)
    num_sam = args.num_samples
    num_sem = args.num_semantics
    vizer_1 = HtmlPageVisualizer(num_rows=num_sem * (num_sam + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)
    vizer_2 = HtmlPageVisualizer(num_rows=num_sam * (num_sem + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)

    headers = [''] + [f'Distance {d:.2f}' for d in distances]
    vizer_1.set_headers(headers)
    vizer_2.set_headers(headers)
    for sem_id in range(num_sem):
        value = values[sem_id]
        vizer_1.set_cell(sem_id * (num_sam + 1), 0,
                         text=f'Semantic {sem_id:03d}<br>({value:.3f})',
                         highlight=True)
        for sam_id in range(num_sam):
            vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, 0,
                             text=f'Sample {sam_id:03d}')
    for sam_id in range(num_sam):
        vizer_2.set_cell(sam_id * (num_sem + 1), 0,
                         text=f'Sample {sam_id:03d}',
                         highlight=True)
        for sem_id in range(num_sem):
            value = values[sem_id]
            vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, 0,
                             text=f'Semantic {sem_id:03d}<br>({value:.3f})')

    for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
        code = codes[sam_id:sam_id + 1]
        for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
            boundary = boundaries[sem_id:sem_id + 1]
            for col_id, d in enumerate(distances, start=1):
                temp_code = code.copy()
                if gan_type == 'pggan':
                    temp_code += boundary * d
                    image = generator(to_tensor(temp_code))['image']
                elif gan_type in ['stylegan', 'stylegan2']:
                    temp_code[:, layers, :] += boundary * d
                    image = generator.synthesis(to_tensor(temp_code))['image']
                image = postprocess(image)[0]
                vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, col_id,
                                 image=image)
                vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, col_id,
                                 image=image)

    prefix = (f'{args.model_name}_'
              f'N{num_sam}_K{num_sem}_L{args.layer_idx}_seed{args.seed}')
    vizer_1.save(os.path.join(args.save_dir, f'{prefix}_sample_first.html'))
    vizer_2.save(os.path.join(args.save_dir, f'{prefix}_semantic_first.html'))


if __name__ == '__main__':
    main()
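The manipulation loop above relies on the closed-form factorization implemented by `factorize_weight` in utils.py: the concatenated style weights are column-normalized and the eigenvectors of W·Wᵀ serve as semantic boundaries. The snippet below is a minimal sketch of that step on a random stand-in matrix; the sizes are illustrative only, while the real script obtains the matrix from the pre-trained generator.

import numpy as np

w_space_dim, total_channels = 512, 3040  # illustrative sizes
weight = np.random.randn(w_space_dim, total_channels).astype(np.float32)

# Normalize each projection direction, then eigen-decompose W @ W^T,
# exactly as factorize_weight does.
weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
boundaries = eigen_vectors.T  # one semantic direction per row

# Editing a latent code then means shifting it along a top eigenvector.
code = np.random.randn(1, w_space_dim).astype(np.float32)
edited_code = code + 3.0 * boundaries[0:1]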
utils.py
ADDED
@@ -0,0 +1,509 @@
"""Utility functions."""

import base64
import os
import subprocess
import cv2
import numpy as np

import torch

from models import MODEL_ZOO
from models import build_generator
from models import parse_gan_type

__all__ = ['postprocess', 'load_generator', 'factorize_weight',
           'HtmlPageVisualizer']

CHECKPOINT_DIR = 'checkpoints'


def to_tensor(array):
    """Converts a `numpy.ndarray` to `torch.Tensor`.

    Args:
        array: The input array to convert.

    Returns:
        A `torch.Tensor` with dtype `torch.FloatTensor` on cuda device.
    """
    assert isinstance(array, np.ndarray)
    return torch.from_numpy(array).type(torch.FloatTensor).cuda()


def postprocess(images, min_val=-1.0, max_val=1.0):
    """Post-processes images from `torch.Tensor` to `numpy.ndarray`.

    Args:
        images: A `torch.Tensor` with shape `NCHW` to process.
        min_val: The minimum value of the input tensor. (default: -1.0)
        max_val: The maximum value of the input tensor. (default: 1.0)

    Returns:
        A `numpy.ndarray` with shape `NHWC` and pixel range [0, 255].
    """
    assert isinstance(images, torch.Tensor)
    images = images.detach().cpu().numpy()
    images = (images - min_val) * 255 / (max_val - min_val)
    images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
    images = images.transpose(0, 2, 3, 1)
    return images


def load_generator(model_name):
    """Loads pre-trained generator.

    Args:
        model_name: Name of the model. Should be a key in `models.MODEL_ZOO`.

    Returns:
        A generator, which is a `torch.nn.Module`, with pre-trained weights
        loaded.

    Raises:
        KeyError: If the input `model_name` is not in `models.MODEL_ZOO`.
    """
    if model_name not in MODEL_ZOO:
        raise KeyError(f'Unknown model name `{model_name}`!')

    model_config = MODEL_ZOO[model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.

    # Build generator.
    print(f'Building generator for model `{model_name}` ...')
    generator = build_generator(**model_config)
    print(f'Finish building generator.')

    # Load pre-trained weights.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    checkpoint_path = os.path.join(CHECKPOINT_DIR, model_name + '.pth')
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    if not os.path.exists(checkpoint_path):
        print(f'  Downloading checkpoint from `{url}` ...')
        subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
        print(f'  Finish downloading checkpoint.')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    generator = generator.cuda()
    generator.eval()
    print(f'Finish loading checkpoint.')
    return generator


def parse_indices(obj, min_val=None, max_val=None):
    """Parses indices.

    The input can be a list or a tuple or a string, which is either a comma
    separated list of numbers 'a, b, c', or a dash separated range 'a - c'.
    Space in the string will be ignored.

    Args:
        obj: The input object to parse indices from.
        min_val: If not `None`, this function will check that all indices are
            equal to or larger than this value. (default: None)
        max_val: If not `None`, this function will check that all indices are
            equal to or smaller than this value. (default: None)

    Returns:
        A list of integers.

    Raises:
        If the input is invalid, i.e., neither a list or tuple, nor a string.
    """
    if obj is None or obj == '':
        indices = []
    elif isinstance(obj, int):
        indices = [obj]
    elif isinstance(obj, (list, tuple, np.ndarray)):
        indices = list(obj)
    elif isinstance(obj, str):
        indices = []
        splits = obj.replace(' ', '').split(',')
        for split in splits:
            numbers = list(map(int, split.split('-')))
            if len(numbers) == 1:
                indices.append(numbers[0])
            elif len(numbers) == 2:
                indices.extend(list(range(numbers[0], numbers[1] + 1)))
            else:
                raise ValueError(f'Unable to parse the input!')
    else:
        raise ValueError(f'Invalid type of input: `{type(obj)}`!')

    assert isinstance(indices, list)
    indices = sorted(list(set(indices)))
    for idx in indices:
        assert isinstance(idx, int)
        if min_val is not None:
            assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!'
        if max_val is not None:
            assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!'

    return indices


def factorize_weight(generator, layer_idx='all'):
    """Factorizes the generator weight to get semantics boundaries.

    Args:
        generator: Generator to factorize.
        layer_idx: Indices of layers to interpret, especially for StyleGAN and
            StyleGAN2. (default: `all`)

    Returns:
        A tuple of (layers_to_interpret, semantic_boundaries, eigen_values).

    Raises:
        ValueError: If the generator type is not supported.
    """
    # Get GAN type.
    gan_type = parse_gan_type(generator)

    # Get layers.
    if gan_type == 'pggan':
        layers = [0]
    elif gan_type in ['stylegan', 'stylegan2']:
        if layer_idx == 'all':
            layers = list(range(generator.num_layers))
        else:
            layers = parse_indices(layer_idx,
                                   min_val=0,
                                   max_val=generator.num_layers - 1)

    # Factorize semantics from weight.
    weights = []
    for idx in layers:
        layer_name = f'layer{idx}'
        if gan_type == 'stylegan2' and idx == generator.num_layers - 1:
            layer_name = f'output{idx // 2}'
        if gan_type == 'pggan':
            weight = generator.__getattr__(layer_name).weight
            weight = weight.flip(2, 3).permute(1, 0, 2, 3).flatten(1)
        elif gan_type in ['stylegan', 'stylegan2']:
            weight = generator.synthesis.__getattr__(layer_name).style.weight.T
        weights.append(weight.cpu().detach().numpy())
    weight = np.concatenate(weights, axis=1).astype(np.float32)
    weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
    eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))

    return layers, eigen_vectors.T, eigen_values


def get_sortable_html_header(column_name_list, sort_by_ascending=False):
    """Gets header for sortable html page.

    Basically, the html page contains a sortable table, where user can sort the
    rows by a particular column by clicking the column head.

    Example:

    column_name_list = [name_1, name_2, name_3]
    header = get_sortable_html_header(column_name_list)
    footer = get_sortable_html_footer()
    sortable_table = ...
    html_page = header + sortable_table + footer

    Args:
        column_name_list: List of column header names.
        sort_by_ascending: Default sorting order. If set as `True`, the html
            page will be sorted by ascending order when the header is clicked
            for the first time.

    Returns:
        A string, which represents for the header for a sortable html page.
    """
    header = '\n'.join([
        '<script type="text/javascript">',
        'var column_idx;',
        'var sort_by_ascending = ' + str(sort_by_ascending).lower() + ';',
        '',
        'function sorting(tbody, column_idx){',
        '  this.column_idx = column_idx;',
        '  Array.from(tbody.rows)',
        '       .sort(compareCells)',
        '       .forEach(function(row) { tbody.appendChild(row); })',
        '  sort_by_ascending = !sort_by_ascending;',
        '}',
        '',
        'function compareCells(row_a, row_b) {',
        '  var val_a = row_a.cells[column_idx].innerText;',
        '  var val_b = row_b.cells[column_idx].innerText;',
        '  var flag = sort_by_ascending ? 1 : -1;',
        '  return flag * (val_a > val_b ? 1 : -1);',
        '}',
        '</script>',
        '',
        '<html>',
        '',
        '<head>',
        '<style>',
        '  table {',
        '    border-spacing: 0;',
        '    border: 1px solid black;',
        '  }',
        '  th {',
        '    cursor: pointer;',
        '  }',
        '  th, td {',
        '    text-align: left;',
        '    vertical-align: middle;',
        '    border-collapse: collapse;',
        '    border: 0.5px solid black;',
        '    padding: 8px;',
        '  }',
        '  tr:nth-child(even) {',
        '    background-color: #d2d2d2;',
        '  }',
        '</style>',
        '</head>',
        '',
        '<body>',
        '',
        '<table>',
        '<thead>',
        '<tr>',
        ''])
    for idx, name in enumerate(column_name_list):
        header += f'  <th onclick="sorting(tbody, {idx})">{name}</th>\n'
    header += '</tr>\n'
    header += '</thead>\n'
    header += '<tbody id="tbody">\n'

    return header


def get_sortable_html_footer():
    """Gets footer for sortable html page.

    Check function `get_sortable_html_header()` for more details.
    """
    return '</tbody>\n</table>\n\n</body>\n</html>\n'


def parse_image_size(obj):
    """Parses object to a pair of image size, i.e., (width, height).

    Args:
        obj: The input object to parse image size from.

    Returns:
        A two-element tuple, indicating image width and height respectively.

    Raises:
        If the input is invalid, i.e., neither a list or tuple, nor a string.
    """
    if obj is None or obj == '':
        width = height = 0
    elif isinstance(obj, int):
        width = height = obj
    elif isinstance(obj, (list, tuple, np.ndarray)):
        numbers = tuple(obj)
        if len(numbers) == 0:
            width = height = 0
        elif len(numbers) == 1:
            width = height = numbers[0]
        elif len(numbers) == 2:
            width = numbers[0]
            height = numbers[1]
        else:
            raise ValueError(f'At most two elements for image size.')
    elif isinstance(obj, str):
        splits = obj.replace(' ', '').split(',')
        numbers = tuple(map(int, splits))
        if len(numbers) == 0:
            width = height = 0
        elif len(numbers) == 1:
            width = height = numbers[0]
        elif len(numbers) == 2:
            width = numbers[0]
            height = numbers[1]
        else:
            raise ValueError(f'At most two elements for image size.')
    else:
        raise ValueError(f'Invalid type of input: {type(obj)}!')

    return (max(0, width), max(0, height))


def encode_image_to_html_str(image, image_size=None):
    """Encodes an image to html language.

    NOTE: Input image is always assumed to be with `RGB` channel order.

    Args:
        image: The input image to encode. Should be with `RGB` channel order.
        image_size: This field is used to resize the image before encoding. `0`
            disables resizing. (default: None)

    Returns:
        A string which represents the encoded image.
    """
    if image is None:
        return ''

    assert image.ndim == 3 and image.shape[2] in [1, 3]

    # Change channel order to `BGR`, which is opencv-friendly.
    image = image[:, :, ::-1]

    # Resize the image if needed.
    width, height = parse_image_size(image_size)
    if height or width:
        height = height or image.shape[0]
        width = width or image.shape[1]
        image = cv2.resize(image, (width, height))

    # Encode the image to html-format string.
    encoded_image = cv2.imencode('.jpg', image)[1].tostring()
    encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8')
    html_str = f'<img src="data:image/jpeg;base64, {encoded_image_base64}"/>'

    return html_str


def get_grid_shape(size, row=0, col=0, is_portrait=False):
    """Gets the shape of a grid based on the size.

    This function makes greatest effort on making the output grid square if
    neither `row` nor `col` is set. If `is_portrait` is set as `False`, the
    height will always be equal to or smaller than the width. For example, if
    input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`,
    output shape will be (3, 5). Otherwise, the height will always be equal to
    or larger than the width.

    Args:
        size: Size (height * width) of the target grid.
        is_portrait: Whether to return a portrait size of a landscape size.
            (default: False)

    Returns:
        A two-element tuple, representing height and width respectively.
    """
    assert isinstance(size, int)
    assert isinstance(row, int)
    assert isinstance(col, int)
    if size == 0:
        return (0, 0)

    if row > 0 and col > 0 and row * col != size:
        row = 0
        col = 0

    if row > 0 and size % row == 0:
        return (row, size // row)
    if col > 0 and size % col == 0:
        return (size // col, col)

    row = int(np.sqrt(size))
    while row > 0:
        if size % row == 0:
            col = size // row
            break
        row = row - 1

    return (col, row) if is_portrait else (row, col)


class HtmlPageVisualizer(object):
    """Defines the html page visualizer.

    This class can be used to visualize image results as html page. Basically,
    it is based on an html-format sorted table with helper functions
    `get_sortable_html_header()`, `get_sortable_html_footer()`, and
    `encode_image_to_html_str()`. To simplify the usage, specifying the
    following fields are enough to create a visualization page:

    (1) num_rows: Number of rows of the table (header-row exclusive).
    (2) num_cols: Number of columns of the table.
    (3) header contents (optional): Title of each column.

    NOTE: `grid_size` can be used to assign `num_rows` and `num_cols`
    automatically.

    Example:

    html = HtmlPageVisualizer(num_rows, num_cols)
    html.set_headers([...])
    for i in range(num_rows):
        for j in range(num_cols):
            html.set_cell(i, j, text=..., image=..., highlight=False)
    html.save('visualize.html')
    """

    def __init__(self,
                 num_rows=0,
                 num_cols=0,
                 grid_size=0,
                 is_portrait=True,
                 viz_size=None):
        if grid_size > 0:
            num_rows, num_cols = get_grid_shape(
                grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait)
        assert num_rows > 0 and num_cols > 0

        self.num_rows = num_rows
        self.num_cols = num_cols
        self.viz_size = parse_image_size(viz_size)
        self.headers = ['' for _ in range(self.num_cols)]
        self.cells = [[{
            'text': '',
            'image': '',
            'highlight': False,
        } for _ in range(self.num_cols)] for _ in range(self.num_rows)]

    def set_header(self, col_idx, content):
        """Sets the content of a particular header by column index."""
        self.headers[col_idx] = content

    def set_headers(self, contents):
        """Sets the contents of all headers."""
        if isinstance(contents, str):
            contents = [contents]
        assert isinstance(contents, (list, tuple))
        assert len(contents) == self.num_cols
        for col_idx, content in enumerate(contents):
            self.set_header(col_idx, content)

    def set_cell(self, row_idx, col_idx, text='', image=None, highlight=False):
        """Sets the content of a particular cell.

        Basically, a cell contains some text as well as an image. Both text and
        image can be empty.

        Args:
            row_idx: Row index of the cell to edit.
            col_idx: Column index of the cell to edit.
            text: Text to add into the target cell. (default: None)
            image: Image to show in the target cell. Should be with `RGB`
                channel order. (default: None)
            highlight: Whether to highlight this cell. (default: False)
        """
        self.cells[row_idx][col_idx]['text'] = text
        self.cells[row_idx][col_idx]['image'] = encode_image_to_html_str(
            image, self.viz_size)
        self.cells[row_idx][col_idx]['highlight'] = bool(highlight)

    def save(self, save_path):
        """Saves the html page."""
        html = ''
        for i in range(self.num_rows):
            html += f'<tr>\n'
            for j in range(self.num_cols):
                text = self.cells[i][j]['text']
                image = self.cells[i][j]['image']
                if self.cells[i][j]['highlight']:
                    color = ' bgcolor="#FF8888"'
                else:
                    color = ''
                if text:
                    html += f'  <td{color}>{text}<br><br>{image}</td>\n'
                else:
                    html += f'  <td{color}>{image}</td>\n'
            html += f'</tr>\n'

        header = get_sortable_html_header(self.headers)
        footer = get_sortable_html_footer()

        with open(save_path, 'w') as f:
            f.write(header + html + footer)
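Taken together, these helpers are enough to render a small grid of generated samples to an HTML page. The snippet below is a hypothetical end-to-end sketch, not part of the original files: the model key 'stylegan_animeface512' is only an assumed example entry in MODEL_ZOO, and a CUDA device is required because `load_generator` and `to_tensor` move everything to GPU.

import torch

from utils import load_generator, postprocess, HtmlPageVisualizer

generator = load_generator('stylegan_animeface512')  # assumed MODEL_ZOO key
with torch.no_grad():
    z = torch.randn(4, generator.z_space_dim).cuda()
    images = postprocess(generator(z)['image'])  # NHWC uint8 images

page = HtmlPageVisualizer(num_rows=1, num_cols=4, viz_size=256)
page.set_headers([f'Sample {i}' for i in range(4)])
for i in range(4):
    page.set_cell(0, i, image=images[i])
page.save('samples.html')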