StyleGAN2 Loss
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Loss functions."""
import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib.autosummary import autosummary
#----------------------------------------------------------------------------
# Logistic loss from the paper
# "Generative Adversarial Nets", Goodfellow et al. 2014
def G_logistic(G, D, opt, training_set, minibatch_size):
_ = opt
latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
labels = training_set.get_random_labels_tf(minibatch_size)
fake_images_out = G.get_output_for(latents, labels, is_training=True)
fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
loss = -tf.nn.softplus(fake_scores_out) # log(1-sigmoid(fake_scores_out)) # pylint: disable=invalid-unary-operand-type
return loss, None
def G_logistic_ns(G, D, opt, training_set, minibatch_size):
_ = opt
latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
labels = training_set.get_random_labels_tf(minibatch_size)
fake_images_out = G.get_output_for(latents, labels, is_training=True)
fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
loss = tf.nn.softplus(-fake_scores_out) # -log(sigmoid(fake_scores_out))
return loss, None
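# Aside (not in the original file): a quick NumPy check of the softplus/sigmoid
# identities cited in the comments above, namely softplus(x) = -log(1-sigmoid(x))
# and softplus(-x) = -log(sigmoid(x)).
def _check_softplus_identities():
    xs = np.linspace(-5.0, 5.0, 11)
    softplus = lambda v: np.maximum(v, 0) + np.log1p(np.exp(-np.abs(v)))  # numerically stable form
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    assert np.allclose(softplus(xs), -np.log(1.0 - sigmoid(xs)))
    assert np.allclose(softplus(-xs), -np.log(sigmoid(xs)))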
def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels):
_ = opt, training_set
latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
fake_images_out = G.get_output_for(latents, labels, is_training=True)
real_scores_out = D.get_output_for(reals, labels, is_training=True)
fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
real_scores_out = autosummary('Loss/scores/real', real_scores_out)
fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out)) # pylint: disable=invalid-unary-operand-type
return loss, None
#----------------------------------------------------------------------------
# R1 and R2 regularizers from the paper
# "Which Training Methods for GANs do actually Converge?", Mescheder et al. 2018
def D_logistic_r1(G, D, opt, training_set, minibatch_size, reals, labels, gamma=10.0):
_ = opt, training_set
latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
fake_images_out = G.get_output_for(latents, labels, is_training=True)
real_scores_out = D.get_output_for(reals, labels, is_training=True)
fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
real_scores_out = autosummary('Loss/scores/real', real_scores_out)
fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out)) # pylint: disable=invalid-unary-operand-type
with tf.name_scope('GradientPenalty'):
real_grads = tf.gradients(tf.reduce_sum(real_scores_out), [reals])[0]
gradient_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3])
gradient_penalty = autosummary('Loss/gradient_penalty', gradient_penalty)
reg = gradient_penalty * (gamma * 0.5)
return loss, reg
def D_logistic_r2(G, D, opt, training_set, minibatch_size, reals, labels, gamma=10.0):
_ = opt, training_set
latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
fake_images_out = G.get_output_for(latents, labels, is_training=True)
real_scores_out = D.get_output_for(reals, labels, is_training=True)
fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
real_scores_out = autosummary('Loss/scores/real', real_scores_out)
fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out)) # pylint: disable=invalid-unary-operand-type
with tf.name_scope('GradientPenalty'):
fake_grads = tf.gradients(tf.reduce_sum(fake_scores_out), [fake_images_out])[0]
gradient_penalty = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3])
gradient_penalty = autosummary('Loss/gradient_penalty', gradient_penalty)
reg = gradient_penalty * (gamma * 0.5)
return loss, reg
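# Aside (hedged sketch, not this codebase's API): the (loss, reg) pair returned by
# the functions above is intended for lazy regularization (StyleGAN2, Sec. 3.1):
# the regularizer is evaluated only every `reg_interval` minibatches, scaled by
# reg_interval to preserve its expected strength. All names below are hypothetical
# stand-ins, not functions from this file.
def _lazy_reg_step(step, compute_loss, compute_reg, apply_grads, reg_interval=16):
    apply_grads(compute_loss())                      # main loss: every step
    if compute_reg is not None and step % reg_interval == 0:
        apply_grads(compute_reg() * reg_interval)    # regularizer: every N steps, scaled by N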
#----------------------------------------------------------------------------
# WGAN loss from the paper
# "Wasserstein Generative Adversarial Networks", Arjovsky et al. 2017
def G_wgan(G, D, opt, training_set, minibatch_size):
_ = opt
latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
labels = training_set.get_random_labels_tf(minibatch_size)
fake_images_out = G.get_output_for(latents, labels, is_training=True)
fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
loss = -fake_scores_out
return loss, None
def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, wgan_epsilon=0.001):
_ = opt, training_set
latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
fake_images_out = G.get_output_for(latents, labels, is_training=True)
real_scores_out = D.get_output_for(reals, labels, is_training=True)
fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
real_scores_out = autosummary('Loss/scores/real', real_scores_out)
fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
loss = fake_scores_out - real_scores_out
with tf.name_scope('EpsilonPenalty'):
epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
loss += epsilon_penalty * wgan_epsilon
return loss, None
#----------------------------------------------------------------------------
# WGAN-GP loss from the paper
# "Improved Training of Wasserstein GANs", Gulrajani et al. 2017
def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, wgan_lambda=10.0, wgan_epsilon=0.001, wgan_target=1.0):
_ = opt, training_set
latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
fake_images_out = G.get_output_for(latents, labels, is_training=True)
real_scores_out = D.get_output_for(reals, labels, is_training=True)
fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
real_scores_out = autosummary('Loss/scores/real', real_scores_out)
fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
loss = fake_scores_out - real_scores_out
with tf.name_scope('EpsilonPenalty'):
epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
loss += epsilon_penalty * wgan_epsilon
with tf.name_scope('GradientPenalty'):
mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype)
mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)
mixed_scores_out = D.get_output_for(mixed_images_out, labels, is_training=True)
mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out)
mixed_grads = tf.gradients(tf.reduce_sum(mixed_scores_out), [mixed_images_out])[0]
mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
mixed_norms = autosummary('Loss/mixed_norms', mixed_norms)
gradient_penalty = tf.square(mixed_norms - wgan_target)
reg = gradient_penalty * (wgan_lambda / (wgan_target**2))
return loss, reg
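# Aside (toy NumPy check, not part of the original file): for a linear critic
# D(x) = sum(w * x), the gradient w.r.t. x equals w at every interpolated sample,
# so the gradient penalty above reduces to a closed form that the mixing factors
# cannot change. This illustrates the penalty formula, not the full training op.
def _toy_wgan_gp(wgan_lambda=10.0, wgan_target=1.0):
    w = np.random.default_rng(0).normal(size=(3, 32, 32))  # toy critic weights [CHW]
    grad_norm = np.sqrt(np.sum(np.square(w)))              # ||dD/dx|| = ||w|| for all x
    return np.square(grad_norm - wgan_target) * (wgan_lambda / wgan_target**2)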
#----------------------------------------------------------------------------
# Non-saturating logistic loss with path length regularizer from the paper
# "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. 2019
def G_logistic_ns_pathreg(G, D, opt, training_set, minibatch_size, pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2.0):
_ = opt
latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
labels = training_set.get_random_labels_tf(minibatch_size)
fake_images_out, fake_dlatents_out = G.get_output_for(latents, labels, is_training=True, return_dlatents=True)
fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
loss = tf.nn.softplus(-fake_scores_out) # -log(sigmoid(fake_scores_out))
# Path length regularization.
with tf.name_scope('PathReg'):
# Evaluate the regularization term using a smaller minibatch to conserve memory.
if pl_minibatch_shrink > 1:
pl_minibatch = minibatch_size // pl_minibatch_shrink
pl_latents = tf.random_normal([pl_minibatch] + G.input_shapes[0][1:])
pl_labels = training_set.get_random_labels_tf(pl_minibatch)
fake_images_out, fake_dlatents_out = G.get_output_for(pl_latents, pl_labels, is_training=True, return_dlatents=True)
# Compute |J*y|.
pl_noise = tf.random_normal(tf.shape(fake_images_out)) / np.sqrt(np.prod(G.output_shape[2:]))
pl_grads = tf.gradients(tf.reduce_sum(fake_images_out * pl_noise), [fake_dlatents_out])[0]
pl_lengths = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(pl_grads), axis=2), axis=1))
pl_lengths = autosummary('Loss/pl_lengths', pl_lengths)
# Track exponential moving average of |J*y|.
with tf.control_dependencies(None):
pl_mean_var = tf.Variable(name='pl_mean', trainable=False, initial_value=0.0, dtype=tf.float32)
pl_mean = pl_mean_var + pl_decay * (tf.reduce_mean(pl_lengths) - pl_mean_var)
pl_update = tf.assign(pl_mean_var, pl_mean)
# Calculate (|J*y|-a)^2.
with tf.control_dependencies([pl_update]):
pl_penalty = tf.square(pl_lengths - pl_mean)
pl_penalty = autosummary('Loss/pl_penalty', pl_penalty)
# Apply weight.
#
# Note: The division in pl_noise decreases the weight by num_pixels, and the reduce_mean
# in pl_lengths decreases it by num_affine_layers. The effective weight then becomes:
#
# gamma_pl = pl_weight / num_pixels / num_affine_layers
# = 2 / (r^2) / (log2(r) * 2 - 2)
# = 1 / (r^2 * (log2(r) - 1))
# = ln(2) / (r^2 * (ln(r) - ln(2)))
#
reg = pl_penalty * pl_weight
return loss, reg
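# Aside (minimal NumPy sketch, not part of the original file): the PathReg
# bookkeeping above in isolation. pl_mean tracks an exponential moving average of
# E[|J^T y|], and each sample is penalized by its squared deviation from it.
def _pl_penalty_demo(pl_decay=0.01, pl_weight=2.0, num_steps=3):
    pl_mean, rng = 0.0, np.random.default_rng(0)
    for _step in range(num_steps):
        pl_lengths = rng.uniform(0.5, 1.5, size=8)           # stand-in for per-sample |J*y|
        pl_mean += pl_decay * (pl_lengths.mean() - pl_mean)  # EMA update (pl_update)
        reg = np.square(pl_lengths - pl_mean) * pl_weight    # (|J*y| - a)^2 * weight
    return reg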
#----------------------------------------------------------------------------
StyleGAN2-ADA Loss
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Loss functions."""
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib
from dnnlib.tflib.autosummary import autosummary
#----------------------------------------------------------------------------
# Report statistic for all interested parties (AdaptiveAugment and tfevents).
def report_stat(aug, name, value):
if aug is not None:
value = aug.report_stat(name, value)
value = autosummary(name, value)
return value
#----------------------------------------------------------------------------
# Report loss terms and collect them into EasyDict.
def report_loss(aug, G_loss, D_loss, G_reg=None, D_reg=None):
assert G_loss is not None and D_loss is not None
terms = dnnlib.EasyDict(G_reg=None, D_reg=None)
terms.G_loss = report_stat(aug, 'Loss/G/loss', G_loss)
terms.D_loss = report_stat(aug, 'Loss/D/loss', D_loss)
if G_reg is not None: terms.G_reg = report_stat(aug, 'Loss/G/reg', G_reg)
if D_reg is not None: terms.D_reg = report_stat(aug, 'Loss/D/reg', D_reg)
return terms
#----------------------------------------------------------------------------
# Evaluate G and return results as EasyDict.
def eval_G(G, latents, labels, return_dlatents=False):
r = dnnlib.EasyDict()
r.args = dnnlib.EasyDict()
r.args.is_training = True
if return_dlatents:
r.args.return_dlatents = True
r.images = G.get_output_for(latents, labels, **r.args)
r.dlatents = None
if return_dlatents:
r.images, r.dlatents = r.images
return r
#----------------------------------------------------------------------------
# Evaluate D and return results as EasyDict.
def eval_D(D, aug, images, labels, report=None, augment_inputs=True, return_aux=0):
r = dnnlib.EasyDict()
r.images_aug = images
r.labels_aug = labels
if augment_inputs and aug is not None:
r.images_aug, r.labels_aug = aug.apply(r.images_aug, r.labels_aug)
r.args = dnnlib.EasyDict()
r.args.is_training = True
if aug is not None:
r.args.augment_strength = aug.get_strength_var()
if return_aux > 0:
r.args.score_size = return_aux + 1
r.scores = D.get_output_for(r.images_aug, r.labels_aug, **r.args)
r.aux = None
if return_aux:
r.aux = r.scores[:, 1:]
r.scores = r.scores[:, :1]
if report is not None:
report_ops = [
report_stat(aug, 'Loss/scores/' + report, r.scores),
report_stat(aug, 'Loss/signs/' + report, tf.sign(r.scores)),
report_stat(aug, 'Loss/squares/' + report, tf.square(r.scores)),
]
with tf.control_dependencies(report_ops):
r.scores = tf.identity(r.scores)
return r
#----------------------------------------------------------------------------
# Non-saturating logistic loss with R1 and path length regularizers, used
# in the paper "Analyzing and Improving the Image Quality of StyleGAN".
def stylegan2(G, D, aug, fake_labels, real_images, real_labels, r1_gamma=10, pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2, **_kwargs):
# Evaluate networks for the main loss.
minibatch_size = tf.shape(fake_labels)[0]
fake_latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
G_fake = eval_G(G, fake_latents, fake_labels, return_dlatents=True)
D_fake = eval_D(D, aug, G_fake.images, fake_labels, report='fake')
D_real = eval_D(D, aug, real_images, real_labels, report='real')
# Non-saturating logistic loss from "Generative Adversarial Nets".
with tf.name_scope('Loss_main'):
G_loss = tf.nn.softplus(-D_fake.scores) # -log(sigmoid(D_fake.scores)), pylint: disable=invalid-unary-operand-type
D_loss = tf.nn.softplus(D_fake.scores) # -log(1 - sigmoid(D_fake.scores))
D_loss += tf.nn.softplus(-D_real.scores) # -log(sigmoid(D_real.scores)), pylint: disable=invalid-unary-operand-type
G_reg = 0
D_reg = 0
# R1 regularizer from "Which Training Methods for GANs do actually Converge?".
if r1_gamma != 0:
with tf.name_scope('Loss_R1'):
r1_grads = tf.gradients(tf.reduce_sum(D_real.scores), [real_images])[0]
r1_penalty = tf.reduce_sum(tf.square(r1_grads), axis=[1,2,3])
r1_penalty = report_stat(aug, 'Loss/r1_penalty', r1_penalty)
D_reg += r1_penalty * (r1_gamma * 0.5)
# Path length regularizer from "Analyzing and Improving the Image Quality of StyleGAN".
if pl_weight != 0:
with tf.name_scope('Loss_PL'):
# Evaluate the regularization term using a smaller minibatch to conserve memory.
G_pl = G_fake
if pl_minibatch_shrink > 1:
pl_minibatch_size = minibatch_size // pl_minibatch_shrink
pl_latents = fake_latents[:pl_minibatch_size]
pl_labels = fake_labels[:pl_minibatch_size]
G_pl = eval_G(G, pl_latents, pl_labels, return_dlatents=True)
# Compute |J*y|.
pl_noise = tf.random_normal(tf.shape(G_pl.images)) / np.sqrt(np.prod(G.output_shape[2:]))
pl_grads = tf.gradients(tf.reduce_sum(G_pl.images * pl_noise), [G_pl.dlatents])[0]
pl_lengths = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(pl_grads), axis=2), axis=1))
# Track exponential moving average of |J*y|.
with tf.control_dependencies(None):
pl_mean_var = tf.Variable(name='pl_mean', trainable=False, initial_value=0, dtype=tf.float32)
pl_mean = pl_mean_var + pl_decay * (tf.reduce_mean(pl_lengths) - pl_mean_var)
pl_update = tf.assign(pl_mean_var, pl_mean)
# Calculate (|J*y|-a)^2.
with tf.control_dependencies([pl_update]):
pl_penalty = tf.square(pl_lengths - pl_mean)
pl_penalty = report_stat(aug, 'Loss/pl_penalty', pl_penalty)
# Apply weight.
#
# Note: The division in pl_noise decreases the weight by num_pixels, and the reduce_mean
# in pl_lengths decreases it by num_affine_layers. The effective weight then becomes:
#
# gamma_pl = pl_weight / num_pixels / num_affine_layers
# = 2 / (r^2) / (log2(r) * 2 - 2)
# = 1 / (r^2 * (log2(r) - 1))
# = ln(2) / (r^2 * (ln(r) - ln(2)))
#
G_reg += tf.tile(pl_penalty, [pl_minibatch_shrink]) * pl_weight
return report_loss(aug, G_loss, D_loss, G_reg, D_reg)
#----------------------------------------------------------------------------
# Hybrid loss for the comparison methods used in the paper
# "Training Generative Adversarial Networks with Limited Data".
def cmethods(G, D, aug, fake_labels, real_images, real_labels,
r1_gamma=10, r2_gamma=0,
pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2,
bcr_real_weight=0, bcr_fake_weight=0, bcr_augment=None,
zcr_gen_weight=0, zcr_dis_weight=0, zcr_noise_std=0.1,
auxrot_alpha=0, auxrot_beta=0,
**_kwargs,
):
# Evaluate networks for the main loss.
minibatch_size = tf.shape(fake_labels)[0]
fake_latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
G_fake = eval_G(G, fake_latents, fake_labels)
D_fake = eval_D(D, aug, G_fake.images, fake_labels, report='fake')
D_real = eval_D(D, aug, real_images, real_labels, report='real')
# Non-saturating logistic loss from "Generative Adversarial Nets".
with tf.name_scope('Loss_main'):
G_loss = tf.nn.softplus(-D_fake.scores) # -log(sigmoid(D_fake.scores)), pylint: disable=invalid-unary-operand-type
D_loss = tf.nn.softplus(D_fake.scores) # -log(1 - sigmoid(D_fake.scores))
D_loss += tf.nn.softplus(-D_real.scores) # -log(sigmoid(D_real.scores)), pylint: disable=invalid-unary-operand-type
G_reg = 0
D_reg = 0
# R1 and R2 regularizers from "Which Training Methods for GANs do actually Converge?".
if r1_gamma != 0 or r2_gamma != 0:
with tf.name_scope('Loss_R1R2'):
if r1_gamma != 0:
r1_grads = tf.gradients(tf.reduce_sum(D_real.scores), [real_images])[0]
r1_penalty = tf.reduce_sum(tf.square(r1_grads), axis=[1,2,3])
r1_penalty = report_stat(aug, 'Loss/r1_penalty', r1_penalty)
D_reg += r1_penalty * (r1_gamma * 0.5)
if r2_gamma != 0:
r2_grads = tf.gradients(tf.reduce_sum(D_fake.scores), [G_fake.images])[0]
r2_penalty = tf.reduce_sum(tf.square(r2_grads), axis=[1,2,3])
r2_penalty = report_stat(aug, 'Loss/r2_penalty', r2_penalty)
D_reg += r2_penalty * (r2_gamma * 0.5)
# Path length regularizer from "Analyzing and Improving the Image Quality of StyleGAN".
if pl_weight != 0:
with tf.name_scope('Loss_PL'):
pl_minibatch_size = minibatch_size // pl_minibatch_shrink
pl_latents = fake_latents[:pl_minibatch_size]
pl_labels = fake_labels[:pl_minibatch_size]
G_pl = eval_G(G, pl_latents, pl_labels, return_dlatents=True)
pl_noise = tf.random_normal(tf.shape(G_pl.images)) / np.sqrt(np.prod(G.output_shape[2:]))
pl_grads = tf.gradients(tf.reduce_sum(G_pl.images * pl_noise), [G_pl.dlatents])[0]
pl_lengths = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(pl_grads), axis=2), axis=1))
with tf.control_dependencies(None):
pl_mean_var = tf.Variable(name='pl_mean', trainable=False, initial_value=0, dtype=tf.float32)
pl_mean = pl_mean_var + pl_decay * (tf.reduce_mean(pl_lengths) - pl_mean_var)
pl_update = tf.assign(pl_mean_var, pl_mean)
with tf.control_dependencies([pl_update]):
pl_penalty = tf.square(pl_lengths - pl_mean)
pl_penalty = report_stat(aug, 'Loss/pl_penalty', pl_penalty)
G_reg += tf.tile(pl_penalty, [pl_minibatch_shrink]) * pl_weight
# bCR regularizer from "Improved consistency regularization for GANs".
if (bcr_real_weight != 0 or bcr_fake_weight != 0) and bcr_augment is not None:
with tf.name_scope('Loss_bCR'):
if bcr_real_weight != 0:
bcr_real_images, bcr_real_labels = dnnlib.util.call_func_by_name(D_real.images_aug, D_real.labels_aug, **bcr_augment)
D_bcr_real = eval_D(D, aug, bcr_real_images, bcr_real_labels, report='real_bcr', augment_inputs=False)
bcr_real_penalty = tf.square(D_bcr_real.scores - D_real.scores)
bcr_real_penalty = report_stat(aug, 'Loss/bcr_penalty/real', bcr_real_penalty)
D_loss += bcr_real_penalty * bcr_real_weight # NOTE: Must not use lazy regularization for this term.
if bcr_fake_weight != 0:
bcr_fake_images, bcr_fake_labels = dnnlib.util.call_func_by_name(D_fake.images_aug, D_fake.labels_aug, **bcr_augment)
D_bcr_fake = eval_D(D, aug, bcr_fake_images, bcr_fake_labels, report='fake_bcr', augment_inputs=False)
bcr_fake_penalty = tf.square(D_bcr_fake.scores - D_fake.scores)
bcr_fake_penalty = report_stat(aug, 'Loss/bcr_penalty/fake', bcr_fake_penalty)
D_loss += bcr_fake_penalty * bcr_fake_weight # NOTE: Must not use lazy regularization for this term.
# zCR regularizer from "Improved consistency regularization for GANs".
if zcr_gen_weight != 0 or zcr_dis_weight != 0:
with tf.name_scope('Loss_zCR'):
zcr_fake_latents = fake_latents + tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) * zcr_noise_std
G_zcr = eval_G(G, zcr_fake_latents, fake_labels)
if zcr_gen_weight > 0:
zcr_gen_penalty = -tf.reduce_mean(tf.square(G_fake.images - G_zcr.images), axis=[1,2,3])
zcr_gen_penalty = report_stat(aug, 'Loss/zcr_gen_penalty', zcr_gen_penalty)
G_loss += zcr_gen_penalty * zcr_gen_weight
if zcr_dis_weight > 0:
D_zcr = eval_D(D, aug, G_zcr.images, fake_labels, report='fake_zcr', augment_inputs=False)
zcr_dis_penalty = tf.square(D_fake.scores - D_zcr.scores)
zcr_dis_penalty = report_stat(aug, 'Loss/zcr_dis_penalty', zcr_dis_penalty)
D_loss += zcr_dis_penalty * zcr_dis_weight
# Auxiliary rotation loss from "Self-supervised GANs via auxiliary rotation loss".
if auxrot_alpha != 0 or auxrot_beta != 0:
with tf.name_scope('Loss_AuxRot'):
idx = tf.range(minibatch_size * 4, dtype=tf.int32) // minibatch_size
b0 = tf.logical_or(tf.equal(idx, 0), tf.equal(idx, 1))
b1 = tf.logical_or(tf.equal(idx, 0), tf.equal(idx, 3))
b2 = tf.logical_or(tf.equal(idx, 0), tf.equal(idx, 2))
if auxrot_alpha != 0:
auxrot_fake = tf.tile(G_fake.images, [4, 1, 1, 1])
auxrot_fake = tf.where(b0, auxrot_fake, tf.reverse(auxrot_fake, [2]))
auxrot_fake = tf.where(b1, auxrot_fake, tf.reverse(auxrot_fake, [3]))
auxrot_fake = tf.where(b2, auxrot_fake, tf.transpose(auxrot_fake, [0, 1, 3, 2]))
D_auxrot_fake = eval_D(D, aug, auxrot_fake, fake_labels, return_aux=4)
G_loss += tf.nn.sparse_softmax_cross_entropy_with_logits(labels=idx, logits=D_auxrot_fake.aux) * auxrot_alpha
if auxrot_beta != 0:
auxrot_real = tf.tile(real_images, [4, 1, 1, 1])
auxrot_real = tf.where(b0, auxrot_real, tf.reverse(auxrot_real, [2]))
auxrot_real = tf.where(b1, auxrot_real, tf.reverse(auxrot_real, [3]))
auxrot_real = tf.where(b2, auxrot_real, tf.transpose(auxrot_real, [0, 1, 3, 2]))
D_auxrot_real = eval_D(D, aug, auxrot_real, real_labels, return_aux=4)
D_loss += tf.nn.sparse_softmax_cross_entropy_with_logits(labels=idx, logits=D_auxrot_real.aux) * auxrot_beta
return report_loss(aug, G_loss, D_loss, G_reg, D_reg)
#----------------------------------------------------------------------------
# WGAN-GP loss with epsilon penalty, used in the paper
# "Progressive Growing of GANs for Improved Quality, Stability, and Variation".
def wgangp(G, D, aug, fake_labels, real_images, real_labels, wgan_epsilon=0.001, wgan_lambda=10, wgan_target=1, **_kwargs):
minibatch_size = tf.shape(fake_labels)[0]
fake_latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
G_fake = eval_G(G, fake_latents, fake_labels)
D_fake = eval_D(D, aug, G_fake.images, fake_labels, report='fake')
D_real = eval_D(D, aug, real_images, real_labels, report='real')
# WGAN loss from "Wasserstein Generative Adversarial Networks".
with tf.name_scope('Loss_main'):
G_loss = -D_fake.scores # pylint: disable=invalid-unary-operand-type
D_loss = D_fake.scores - D_real.scores
# Epsilon penalty from "Progressive Growing of GANs for Improved Quality, Stability, and Variation"
with tf.name_scope('Loss_epsilon'):
epsilon_penalty = report_stat(aug, 'Loss/epsilon_penalty', tf.square(D_real.scores))
D_loss += epsilon_penalty * wgan_epsilon
# Gradient penalty from "Improved Training of Wasserstein GANs".
with tf.name_scope('Loss_GP'):
mix_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0, 1, dtype=G_fake.images.dtype)
mix_images = tflib.lerp(tf.cast(real_images, G_fake.images.dtype), G_fake.images, mix_factors)
mix_labels = real_labels # NOTE: Mixing is performed without respect to fake_labels.
D_mix = eval_D(D, aug, mix_images, mix_labels, report='mix')
mix_grads = tf.gradients(tf.reduce_sum(D_mix.scores), [mix_images])[0]
mix_norms = tf.sqrt(tf.reduce_sum(tf.square(mix_grads), axis=[1,2,3]))
mix_norms = report_stat(aug, 'Loss/mix_norms', mix_norms)
gradient_penalty = tf.square(mix_norms - wgan_target)
D_reg = gradient_penalty * (wgan_lambda / (wgan_target**2))
return report_loss(aug, G_loss, D_loss, None, D_reg)
#----------------------------------------------------------------------------
StyleGAN2 Network Code
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Network architectures used in the StyleGAN2 paper."""
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib
from dnnlib.tflib.ops.upfirdn_2d import upsample_2d, downsample_2d, upsample_conv_2d, conv_downsample_2d
from dnnlib.tflib.ops.fused_bias_act import fused_bias_act
# NOTE: Do not import any application-specific modules here!
# Specify all network parameters as kwargs.
#----------------------------------------------------------------------------
# Get/create weight tensor for a convolution or fully-connected layer.
def get_weight(shape, gain=1, use_wscale=True, lrmul=1, weight_var='weight'):
fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]
he_std = gain / np.sqrt(fan_in) # He init
# Equalized learning rate and custom learning rate multiplier.
if use_wscale:
init_std = 1.0 / lrmul
runtime_coef = he_std * lrmul
else:
init_std = he_std / lrmul
runtime_coef = lrmul
# Create variable.
init = tf.initializers.random_normal(0, init_std)
return tf.get_variable(weight_var, shape=shape, initializer=init) * runtime_coef
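# Aside (NumPy illustration, not part of the original file): with use_wscale the
# variable is stored with std 1/lrmul and multiplied by he_std*lrmul at runtime,
# so the effective weight always has He-init std regardless of lrmul.
def _equalized_lr_demo(shape=(3, 3, 512, 512), gain=1.0, lrmul=0.01):
    he_std = gain / np.sqrt(np.prod(shape[:-1]))
    stored = np.random.default_rng(0).normal(0, 1.0 / lrmul, size=shape)  # init_std
    effective = stored * (he_std * lrmul)                                 # runtime_coef
    assert abs(effective.std() - he_std) / he_std < 0.05                  # ~He-init std
    return effective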
#----------------------------------------------------------------------------
# Fully-connected layer.
def dense_layer(x, fmaps, gain=1, use_wscale=True, lrmul=1, weight_var='weight'):
if len(x.shape) > 2:
x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])])
w = get_weight([x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var)
w = tf.cast(w, x.dtype)
return tf.matmul(x, w)
#----------------------------------------------------------------------------
# Convolution layer with optional upsampling or downsampling.
def conv2d_layer(x, fmaps, kernel, up=False, down=False, resample_kernel=None, gain=1, use_wscale=True, lrmul=1, weight_var='weight'):
assert not (up and down)
assert kernel >= 1 and kernel % 2 == 1
w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var)
if up:
x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
elif down:
x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
else:
x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1,1,1,1], padding='SAME')
return x
#----------------------------------------------------------------------------
# Apply bias and activation func.
def apply_bias_act(x, act='linear', alpha=None, gain=None, lrmul=1, bias_var='bias'):
b = tf.get_variable(bias_var, shape=[x.shape[1]], initializer=tf.initializers.zeros()) * lrmul
return fused_bias_act(x, b=tf.cast(b, x.dtype), act=act, alpha=alpha, gain=gain)
#----------------------------------------------------------------------------
# Naive upsampling (nearest neighbor) and downsampling (average pooling).
def naive_upsample_2d(x, factor=2):
with tf.variable_scope('NaiveUpsample'):
_N, C, H, W = x.shape.as_list()
x = tf.reshape(x, [-1, C, H, 1, W, 1])
x = tf.tile(x, [1, 1, 1, factor, 1, factor])
return tf.reshape(x, [-1, C, H * factor, W * factor])
def naive_downsample_2d(x, factor=2):
with tf.variable_scope('NaiveDownsample'):
_N, C, H, W = x.shape.as_list()
x = tf.reshape(x, [-1, C, H // factor, factor, W // factor, factor])
return tf.reduce_mean(x, axis=[3,5])
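# Aside (NumPy round-trip check, not part of the original file): the same
# reshape/tile trick as NaiveUpsample followed by the NaiveDownsample mean
# recovers the input exactly, confirming nearest-neighbor up / average-pool down.
def _naive_resample_roundtrip(factor=2):
    x = np.arange(4.0).reshape(1, 1, 2, 2)                                # NCHW
    up = np.tile(x.reshape(1, 1, 2, 1, 2, 1), (1, 1, 1, factor, 1, factor))
    up = up.reshape(1, 1, 2 * factor, 2 * factor)                         # nearest-neighbor upsample
    down = up.reshape(1, 1, 2, factor, 2, factor).mean(axis=(3, 5))       # average-pool downsample
    assert np.array_equal(down, x)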
#----------------------------------------------------------------------------
# Modulated convolution layer.
def modulated_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, gain=1, use_wscale=True, lrmul=1, fused_modconv=True, weight_var='weight', mod_weight_var='mod_weight', mod_bias_var='mod_bias'):
assert not (up and down)
assert kernel >= 1 and kernel % 2 == 1
# Get weight.
w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var)
ww = w[np.newaxis] # [BkkIO] Introduce minibatch dimension.
# Modulate.
s = dense_layer(y, fmaps=x.shape[1].value, weight_var=mod_weight_var) # [BI] Transform incoming W to style.
s = apply_bias_act(s, bias_var=mod_bias_var) + 1 # [BI] Add bias (initially 1).
ww *= tf.cast(s[:, np.newaxis, np.newaxis, :, np.newaxis], w.dtype) # [BkkIO] Scale input feature maps.
# Demodulate.
if demodulate:
d = tf.rsqrt(tf.reduce_sum(tf.square(ww), axis=[1,2,3]) + 1e-8) # [BO] Scaling factor.
ww *= d[:, np.newaxis, np.newaxis, np.newaxis, :] # [BkkIO] Scale output feature maps.
# Reshape/scale input.
if fused_modconv:
x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]]) # Fused => reshape minibatch to convolution groups.
w = tf.reshape(tf.transpose(ww, [1, 2, 3, 0, 4]), [ww.shape[1], ww.shape[2], ww.shape[3], -1])
else:
x *= tf.cast(s[:, :, np.newaxis, np.newaxis], x.dtype) # [BIhw] Not fused => scale input activations.
# Convolution with optional up/downsampling.
if up:
x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
elif down:
x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
else:
x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1,1,1,1], padding='SAME')
# Reshape/scale output.
if fused_modconv:
x = tf.reshape(x, [-1, fmaps, x.shape[2], x.shape[3]]) # Fused => reshape convolution groups back to minibatch.
elif demodulate:
x *= tf.cast(d[:, :, np.newaxis, np.newaxis], x.dtype) # [BOhw] Not fused => scale output activations.
return x
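# Aside (NumPy sketch of the modulate/demodulate arithmetic above, without the
# convolution itself; not part of the original file): the styles scale the input
# channels of w, then demodulation rescales each per-sample output channel so its
# weights have unit L2 norm.
def _modulate_demodulate_demo(B=4, k=3, I=8, O=16):
    rng = np.random.default_rng(0)
    w = rng.normal(size=(k, k, I, O))
    s = rng.normal(size=(B, I)) + 1                                  # per-sample styles
    ww = w[np.newaxis] * s[:, None, None, :, None]                   # [BkkIO] modulate
    d = 1.0 / np.sqrt(np.sum(np.square(ww), axis=(1, 2, 3)) + 1e-8)  # [BO] scaling factor
    ww = ww * d[:, None, None, None, :]                              # [BkkIO] demodulate
    assert np.allclose(np.sum(np.square(ww), axis=(1, 2, 3)), 1.0, atol=1e-3)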
#----------------------------------------------------------------------------
# Minibatch standard deviation layer.
def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size.
s = x.shape # [NCHW] Input shape.
y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]]) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c.
y = tf.cast(y, tf.float32) # [GMncHW] Cast to FP32.
y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMncHW] Subtract mean over group.
y = tf.reduce_mean(tf.square(y), axis=0) # [MncHW] Calc variance over group.
y = tf.sqrt(y + 1e-8) # [MncHW] Calc stddev over group.
y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True) # [Mn111] Take average over fmaps and pixels.
y = tf.reduce_mean(y, axis=[2]) # [Mn11] Split channels into c channel groups
y = tf.cast(y, x.dtype) # [Mn11] Cast back to original data type.
y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [NnHW] Replicate over group and pixels.
return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap.
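# Aside (NumPy sketch for num_new_features=1, not part of the original file):
# per-group stddev, averaged over channels and pixels, then broadcast back to
# every sample in the group and appended as one extra feature map.
def _minibatch_stddev_demo(group_size=4):
    x = np.random.default_rng(0).normal(size=(8, 16, 4, 4))             # NCHW
    y = x.reshape(group_size, -1, 16, 4, 4)                             # [GMCHW] split into groups
    y = np.sqrt(np.square(y - y.mean(axis=0)).mean(axis=0) + 1e-8)      # [MCHW] stddev over group
    y = y.mean(axis=(1, 2, 3), keepdims=True)                           # [M111] average over fmaps+pixels
    y = np.tile(y, (group_size, 1, 4, 4))                               # [N1HW] replicate over group+pixels
    return np.concatenate([x, y], axis=1)                               # [N,C+1,H,W] append as new fmap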
#----------------------------------------------------------------------------
# Main generator network.
# Composed of two sub-networks (mapping and synthesis) that are defined below.
# Used in configs B-F (Table 1).
def G_main(
latents_in, # First input: Latent vectors (Z) [minibatch, latent_size].
labels_in, # Second input: Conditioning labels [minibatch, label_size].
truncation_psi = 0.5, # Style strength multiplier for the truncation trick. None = disable.
truncation_cutoff = None, # Number of layers for which to apply the truncation trick. None = disable.
truncation_psi_val = None, # Value for truncation_psi to use during validation.
truncation_cutoff_val = None, # Value for truncation_cutoff to use during validation.
dlatent_avg_beta = 0.995, # Decay for tracking the moving average of W during training. None = disable.
style_mixing_prob = 0.9, # Probability of mixing styles during training. None = disable.
is_training = False, # Network is under training? Enables and disables specific features.
is_validation = False, # Network is under validation? Chooses which value to use for truncation_psi.
return_dlatents = False, # Return dlatents in addition to the images?
is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation.
components = dnnlib.EasyDict(), # Container for sub-networks. Retained between calls.
mapping_func = 'G_mapping', # Build func name for the mapping network.
synthesis_func = 'G_synthesis_stylegan2', # Build func name for the synthesis network.
**kwargs): # Arguments for sub-networks (mapping and synthesis).
# Validate arguments.
assert not is_training or not is_validation
assert isinstance(components, dnnlib.EasyDict)
if is_validation:
truncation_psi = truncation_psi_val
truncation_cutoff = truncation_cutoff_val
if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
truncation_psi = None
if is_training:
truncation_cutoff = None
if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
dlatent_avg_beta = None
if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
style_mixing_prob = None
# Setup components.
if 'synthesis' not in components:
components.synthesis = tflib.Network('G_synthesis', func_name=globals()[synthesis_func], **kwargs)
num_layers = components.synthesis.input_shape[1]
dlatent_size = components.synthesis.input_shape[2]
if 'mapping' not in components:
components.mapping = tflib.Network('G_mapping', func_name=globals()[mapping_func], dlatent_broadcast=num_layers, **kwargs)
# Setup variables.
lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)
# Evaluate mapping network.
dlatents = components.mapping.get_output_for(latents_in, labels_in, is_training=is_training, **kwargs)
dlatents = tf.cast(dlatents, tf.float32)
# Update moving average of W.
if dlatent_avg_beta is not None:
with tf.variable_scope('DlatentAvg'):
batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
with tf.control_dependencies([update_op]):
dlatents = tf.identity(dlatents)
# Perform style mixing regularization.
if style_mixing_prob is not None:
with tf.variable_scope('StyleMix'):
latents2 = tf.random_normal(tf.shape(latents_in))
dlatents2 = components.mapping.get_output_for(latents2, labels_in, is_training=is_training, **kwargs)
dlatents2 = tf.cast(dlatents2, tf.float32)
layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
mixing_cutoff = tf.cond(
tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
lambda: cur_layers)
dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)
# Apply truncation trick.
if truncation_psi is not None:
with tf.variable_scope('Truncation'):
layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
layer_psi = np.ones(layer_idx.shape, dtype=np.float32)
if truncation_cutoff is None:
layer_psi *= truncation_psi
else:
layer_psi = tf.where(layer_idx < truncation_cutoff, layer_psi * truncation_psi, layer_psi)
dlatents = tflib.lerp(dlatent_avg, dlatents, layer_psi)
# Evaluate synthesis network.
deps = []
if 'lod' in components.synthesis.vars:
deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
with tf.control_dependencies(deps):
images_out = components.synthesis.get_output_for(dlatents, is_training=is_training, force_clean_graph=is_template_graph, **kwargs)
# Return requested outputs.
images_out = tf.identity(images_out, name='images_out')
if return_dlatents:
return images_out, dlatents
return images_out
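# Aside (NumPy sketch, not part of the original file): the truncation trick above
# in isolation. Each per-layer dlatent is pulled toward the running average by
# psi, i.e. tflib.lerp(dlatent_avg, dlatents, psi) with layer_psi = psi everywhere.
def _truncation_demo(num_layers=18, dlatent_size=512, psi=0.5):
    w = np.random.default_rng(0).normal(size=(1, num_layers, dlatent_size))
    w_avg = np.zeros(dlatent_size)                    # stand-in for the dlatent_avg variable
    return w_avg + (w - w_avg) * psi                  # truncated dlatents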
#----------------------------------------------------------------------------
# Mapping network.
# Transforms the input latent code (z) to the disentangled latent code (w).
# Used in configs B-F (Table 1).
def G_mapping(
latents_in, # First input: Latent vectors (Z) [minibatch, latent_size].
labels_in, # Second input: Conditioning labels [minibatch, label_size].
latent_size = 512, # Latent vector (Z) dimensionality.
label_size = 0, # Label dimensionality, 0 if no labels.
dlatent_size = 512, # Disentangled latent (W) dimensionality.
dlatent_broadcast = None, # Output disentangled latent (W) as [minibatch, dlatent_size] or [minibatch, dlatent_broadcast, dlatent_size].
mapping_layers = 8, # Number of mapping layers.
mapping_fmaps = 512, # Number of activations in the mapping layers.
mapping_lrmul = 0.01, # Learning rate multiplier for the mapping layers.
mapping_nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
normalize_latents = True, # Normalize latent vectors (Z) before feeding them to the mapping layers?
dtype = 'float32', # Data type to use for activations and outputs.
**_kwargs): # Ignore unrecognized keyword args.
act = mapping_nonlinearity
# Inputs.
latents_in.set_shape([None, latent_size])
labels_in.set_shape([None, label_size])
latents_in = tf.cast(latents_in, dtype)
labels_in = tf.cast(labels_in, dtype)
x = latents_in
# Embed labels and concatenate them with latents.
if label_size:
with tf.variable_scope('LabelConcat'):
w = tf.get_variable('weight', shape=[label_size, latent_size], initializer=tf.initializers.random_normal())
y = tf.matmul(labels_in, tf.cast(w, dtype))
x = tf.concat([x, y], axis=1)
# Normalize latents.
if normalize_latents:
with tf.variable_scope('Normalize'):
x *= tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + 1e-8)
# Mapping layers.
for layer_idx in range(mapping_layers):
with tf.variable_scope('Dense%d' % layer_idx):
fmaps = dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps
x = apply_bias_act(dense_layer(x, fmaps=fmaps, lrmul=mapping_lrmul), act=act, lrmul=mapping_lrmul)
# Broadcast.
if dlatent_broadcast is not None:
with tf.variable_scope('Broadcast'):
x = tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1])
# Output.
assert x.dtype == tf.as_dtype(dtype)
return tf.identity(x, name='dlatents_out')
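# Aside (NumPy check, not part of the original file): after the Normalize step
# above, every latent vector has unit mean square across its channels.
def _normalize_latents_demo():
    z = np.random.default_rng(0).normal(size=(4, 512))
    z = z / np.sqrt(np.square(z).mean(axis=1, keepdims=True) + 1e-8)
    assert np.allclose(np.square(z).mean(axis=1), 1.0)
    return z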
#----------------------------------------------------------------------------
# StyleGAN synthesis network with revised architecture (Figure 2d).
# Implements progressive growing, but no skip connections or residual nets (Figure 7).
# Used in configs B-D (Table 1).
def G_synthesis_stylegan_revised(
dlatents_in, # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
dlatent_size = 512, # Disentangled latent (W) dimensionality.
num_channels = 3, # Number of output color channels.
resolution = 1024, # Output resolution.
fmap_base = 16 << 10, # Overall multiplier for the number of feature maps.
fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.
fmap_min = 1, # Minimum number of feature maps in any layer.
fmap_max = 512, # Maximum number of feature maps in any layer.
randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
dtype = 'float32', # Data type to use for activations and outputs.
resample_kernel = [1,3,3,1], # Low-pass filter to apply when resampling activations. None = no filtering.
fused_modconv = True, # Implement modulated_conv2d_layer() as a single fused op?
structure = 'auto', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.
is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation.
force_clean_graph = False, # True = construct a clean graph that looks nice in TensorBoard, False = default behavior.
**_kwargs): # Ignore unrecognized keyword args.
resolution_log2 = int(np.log2(resolution))
assert resolution == 2**resolution_log2 and resolution >= 4
def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
if is_template_graph: force_clean_graph = True
if force_clean_graph: randomize_noise = False
if structure == 'auto': structure = 'linear' if force_clean_graph else 'recursive'
act = nonlinearity
num_layers = resolution_log2 * 2 - 2
images_out = None
# Primary inputs.
dlatents_in.set_shape([None, num_layers, dlatent_size])
dlatents_in = tf.cast(dlatents_in, dtype)
lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype)
# Noise inputs.
noise_inputs = []
for layer_idx in range(num_layers - 1):
res = (layer_idx + 5) // 2
shape = [1, 1, 2**res, 2**res]
noise_inputs.append(tf.get_variable('noise%d' % layer_idx, shape=shape, initializer=tf.initializers.random_normal(), trainable=False))
# Single convolution layer with all the bells and whistles.
def layer(x, layer_idx, fmaps, kernel, up=False):
x = modulated_conv2d_layer(x, dlatents_in[:, layer_idx], fmaps=fmaps, kernel=kernel, up=up, resample_kernel=resample_kernel, fused_modconv=fused_modconv)
if randomize_noise:
noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
else:
noise = tf.cast(noise_inputs[layer_idx], x.dtype)
noise_strength = tf.get_variable('noise_strength', shape=[], initializer=tf.initializers.zeros())
x += noise * tf.cast(noise_strength, x.dtype)
return apply_bias_act(x, act=act)
# Early layers.
with tf.variable_scope('4x4'):
with tf.variable_scope('Const'):
x = tf.get_variable('const', shape=[1, nf(1), 4, 4], initializer=tf.initializers.random_normal())
x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
with tf.variable_scope('Conv'):
x = layer(x, layer_idx=0, fmaps=nf(1), kernel=3)
# Building blocks for remaining layers.
def block(res, x): # res = 3..resolution_log2
with tf.variable_scope('%dx%d' % (2**res, 2**res)):
with tf.variable_scope('Conv0_up'):
x = layer(x, layer_idx=res*2-5, fmaps=nf(res-1), kernel=3, up=True)
with tf.variable_scope('Conv1'):
x = layer(x, layer_idx=res*2-4, fmaps=nf(res-1), kernel=3)
return x
def torgb(res, x): # res = 2..resolution_log2
with tf.variable_scope('ToRGB_lod%d' % (resolution_log2 - res)):
return apply_bias_act(modulated_conv2d_layer(x, dlatents_in[:, res*2-3], fmaps=num_channels, kernel=1, demodulate=False, fused_modconv=fused_modconv))
# Fixed structure: simple and efficient, but does not support progressive growing.
if structure == 'fixed':
for res in range(3, resolution_log2 + 1):
x = block(res, x)
images_out = torgb(resolution_log2, x)
# Linear structure: simple but inefficient.
if structure == 'linear':
images_out = torgb(2, x)
for res in range(3, resolution_log2 + 1):
lod = resolution_log2 - res
x = block(res, x)
img = torgb(res, x)
with tf.variable_scope('Upsample_lod%d' % lod):
images_out = upsample_2d(images_out)
with tf.variable_scope('Grow_lod%d' % lod):
images_out = tflib.lerp_clip(img, images_out, lod_in - lod)
# Recursive structure: complex but efficient.
if structure == 'recursive':
def cset(cur_lambda, new_cond, new_lambda):
return lambda: tf.cond(new_cond, new_lambda, cur_lambda)
def grow(x, res, lod):
y = block(res, x)
img = lambda: naive_upsample_2d(torgb(res, y), factor=2**lod)
img = cset(img, (lod_in > lod), lambda: naive_upsample_2d(tflib.lerp(torgb(res, y), upsample_2d(torgb(res - 1, x)), lod_in - lod), factor=2**lod))
if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1))
return img()
images_out = grow(x, 3, resolution_log2 - 3)
assert images_out.dtype == tf.as_dtype(dtype)
return tf.identity(images_out, name='images_out')
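# Aside (worked example, not part of the original file): the feature-map counts
# produced by nf() under the defaults above (fmap_base=16<<10, fmap_decay=1.0,
# fmap_max=512). The count stays at the 512 cap through stage 5 and halves at
# each stage after that.
def _nf_table(fmap_base=16 << 10, fmap_decay=1.0, fmap_min=1, fmap_max=512):
    nf = lambda stage: np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
    return [int(nf(stage)) for stage in range(1, 10)]  # [512, 512, 512, 512, 512, 256, 128, 64, 32]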
#----------------------------------------------------------------------------
# StyleGAN2 synthesis network (Figure 7).
# Implements skip connections and residual nets (Figure 7), but no progressive growing.
# Used in configs E-F (Table 1).
def G_synthesis_stylegan2(
dlatents_in, # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
dlatent_size = 512, # Disentangled latent (W) dimensionality.
num_channels = 3, # Number of output color channels.
resolution = 1024, # Output resolution.
fmap_base = 16 << 10, # Overall multiplier for the number of feature maps.
fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.
fmap_min = 1, # Minimum number of feature maps in any layer.
fmap_max = 512, # Maximum number of feature maps in any layer.
randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
dtype = 'float32', # Data type to use for activations and outputs.
resample_kernel = [1,3,3,1], # Low-pass filter to apply when resampling activations. None = no filtering.
fused_modconv = True, # Implement modulated_conv2d_layer() as a single fused op?
**_kwargs): # Ignore unrecognized keyword args.
resolution_log2 = int(np.log2(resolution))
assert resolution == 2**resolution_log2 and resolution >= 4
def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
assert architecture in ['orig', 'skip', 'resnet']
act = nonlinearity
num_layers = resolution_log2 * 2 - 2
images_out = None
# Primary inputs.
dlatents_in.set_shape([None, num_layers, dlatent_size])
dlatents_in = tf.cast(dlatents_in, dtype)
# Noise inputs.
noise_inputs = []
for layer_idx in range(num_layers - 1):
res = (layer_idx + 5) // 2
shape = [1, 1, 2**res, 2**res]
noise_inputs.append(tf.get_variable('noise%d' % layer_idx, shape=shape, initializer=tf.initializers.random_normal(), trainable=False))
# Single convolution layer with all the bells and whistles.
def layer(x, layer_idx, fmaps, kernel, up=False):
x = modulated_conv2d_layer(x, dlatents_in[:, layer_idx], fmaps=fmaps, kernel=kernel, up=up, resample_kernel=resample_kernel, fused_modconv=fused_modconv)
if randomize_noise:
noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
else:
noise = tf.cast(noise_inputs[layer_idx], x.dtype)
noise_strength = tf.get_variable('noise_strength', shape=[], initializer=tf.initializers.zeros())
x += noise * tf.cast(noise_strength, x.dtype)
return apply_bias_act(x, act=act)
# Building blocks for main layers.
def block(x, res): # res = 3..resolution_log2
t = x
with tf.variable_scope('Conv0_up'):
x = layer(x, layer_idx=res*2-5, fmaps=nf(res-1), kernel=3, up=True)
with tf.variable_scope('Conv1'):
x = layer(x, layer_idx=res*2-4, fmaps=nf(res-1), kernel=3)
if architecture == 'resnet':
with tf.variable_scope('Skip'):
t = conv2d_layer(t, fmaps=nf(res-1), kernel=1, up=True, resample_kernel=resample_kernel)
x = (x + t) * (1 / np.sqrt(2))
return x
def upsample(y):
with tf.variable_scope('Upsample'):
return upsample_2d(y, k=resample_kernel)
def torgb(x, y, res): # res = 2..resolution_log2
with tf.variable_scope('ToRGB'):
t = apply_bias_act(modulated_conv2d_layer(x, dlatents_in[:, res*2-3], fmaps=num_channels, kernel=1, demodulate=False, fused_modconv=fused_modconv))
return t if y is None else y + t
# Early layers.
y = None
with tf.variable_scope('4x4'):
with tf.variable_scope('Const'):
x = tf.get_variable('const', shape=[1, nf(1), 4, 4], initializer=tf.initializers.random_normal())
x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
with tf.variable_scope('Conv'):
x = layer(x, layer_idx=0, fmaps=nf(1), kernel=3)
if architecture == 'skip':
y = torgb(x, y, 2)
# Main layers.
for res in range(3, resolution_log2 + 1):
with tf.variable_scope('%dx%d' % (2**res, 2**res)):
x = block(x, res)
if architecture == 'skip':
y = upsample(y)
if architecture == 'skip' or res == resolution_log2:
y = torgb(x, y, res)
images_out = y
assert images_out.dtype == tf.as_dtype(dtype)
return tf.identity(images_out, name='images_out')
#----------------------------------------------------------------------------
# Original StyleGAN discriminator.
# Used in configs B-D (Table 1).
def D_stylegan(
images_in, # First input: Images [minibatch, channel, height, width].
labels_in, # Second input: Labels [minibatch, label_size].
num_channels = 3, # Number of input color channels. Overridden based on dataset.
resolution = 1024, # Input resolution. Overridden based on dataset.
label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
fmap_base = 16 << 10, # Overall multiplier for the number of feature maps.
fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.
fmap_min = 1, # Minimum number of feature maps in any layer.
fmap_max = 512, # Maximum number of feature maps in any layer.
nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable.
mbstd_num_features = 1, # Number of features for the minibatch standard deviation layer.
dtype = 'float32', # Data type to use for activations and outputs.
resample_kernel = [1,3,3,1], # Low-pass filter to apply when resampling activations. None = no filtering.
structure = 'auto', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.
is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation.
**_kwargs): # Ignore unrecognized keyword args.
resolution_log2 = int(np.log2(resolution))
assert resolution == 2**resolution_log2 and resolution >= 4
def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
if structure == 'auto': structure = 'linear' if is_template_graph else 'recursive'
act = nonlinearity
images_in.set_shape([None, num_channels, resolution, resolution])
labels_in.set_shape([None, label_size])
images_in = tf.cast(images_in, dtype)
labels_in = tf.cast(labels_in, dtype)
lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)
# Building blocks for spatial layers.
def fromrgb(x, res): # res = 2..resolution_log2
with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)):
return apply_bias_act(conv2d_layer(x, fmaps=nf(res-1), kernel=1), act=act)
def block(x, res): # res = 2..resolution_log2
with tf.variable_scope('%dx%d' % (2**res, 2**res)):
with tf.variable_scope('Conv0'):
x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-1), kernel=3), act=act)
with tf.variable_scope('Conv1_down'):
x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
return x
# Fixed structure: simple and efficient, but does not support progressive growing.
if structure == 'fixed':
x = fromrgb(images_in, resolution_log2)
for res in range(resolution_log2, 2, -1):
x = block(x, res)
# Linear structure: simple but inefficient.
if structure == 'linear':
img = images_in
x = fromrgb(img, resolution_log2)
for res in range(resolution_log2, 2, -1):
lod = resolution_log2 - res
x = block(x, res)
with tf.variable_scope('Downsample_lod%d' % lod):
img = downsample_2d(img)
y = fromrgb(img, res - 1)
with tf.variable_scope('Grow_lod%d' % lod):
x = tflib.lerp_clip(x, y, lod_in - lod)
# Recursive structure: complex but efficient.
if structure == 'recursive':
def cset(cur_lambda, new_cond, new_lambda):
return lambda: tf.cond(new_cond, new_lambda, cur_lambda)
def grow(res, lod):
x = lambda: fromrgb(naive_downsample_2d(images_in, factor=2**lod), res)
if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1))
x = block(x(), res); y = lambda: x
y = cset(y, (lod_in > lod), lambda: tflib.lerp(x, fromrgb(naive_downsample_2d(images_in, factor=2**(lod+1)), res - 1), lod_in - lod))
return y()
x = grow(3, resolution_log2 - 3)
# Final layers at 4x4 resolution.
with tf.variable_scope('4x4'):
if mbstd_group_size > 1:
with tf.variable_scope('MinibatchStddev'):
x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
with tf.variable_scope('Conv'):
x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
with tf.variable_scope('Dense0'):
x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)
# Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
with tf.variable_scope('Output'):
x = apply_bias_act(dense_layer(x, fmaps=max(labels_in.shape[1], 1)))
if labels_in.shape[1] > 0:
x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
scores_out = x
# Output.
assert scores_out.dtype == tf.as_dtype(dtype)
scores_out = tf.identity(scores_out, name='scores_out')
return scores_out
#----------------------------------------------------------------------------
# StyleGAN2 discriminator (Figure 7).
# Implements skip connections and residual nets (Figure 7), but no progressive growing.
# Used in configs E-F (Table 1).
def D_stylegan2(
images_in, # First input: Images [minibatch, channel, height, width].
labels_in, # Second input: Labels [minibatch, label_size].
num_channels = 3, # Number of input color channels. Overridden based on dataset.
resolution = 1024, # Input resolution. Overridden based on dataset.
label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
fmap_base = 16 << 10, # Overall multiplier for the number of feature maps.
fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.
fmap_min = 1, # Minimum number of feature maps in any layer.
fmap_max = 512, # Maximum number of feature maps in any layer.
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable.
mbstd_num_features = 1, # Number of features for the minibatch standard deviation layer.
dtype = 'float32', # Data type to use for activations and outputs.
resample_kernel = [1,3,3,1], # Low-pass filter to apply when resampling activations. None = no filtering.
**_kwargs): # Ignore unrecognized keyword args.
resolution_log2 = int(np.log2(resolution))
assert resolution == 2**resolution_log2 and resolution >= 4
def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
assert architecture in ['orig', 'skip', 'resnet']
act = nonlinearity
images_in.set_shape([None, num_channels, resolution, resolution])
labels_in.set_shape([None, label_size])
images_in = tf.cast(images_in, dtype)
labels_in = tf.cast(labels_in, dtype)
# Building blocks for main layers.
def fromrgb(x, y, res): # res = 2..resolution_log2
with tf.variable_scope('FromRGB'):
t = apply_bias_act(conv2d_layer(y, fmaps=nf(res-1), kernel=1), act=act)
return t if x is None else x + t
def block(x, res): # res = 2..resolution_log2
t = x
with tf.variable_scope('Conv0'):
x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-1), kernel=3), act=act)
with tf.variable_scope('Conv1_down'):
x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
if architecture == 'resnet':
with tf.variable_scope('Skip'):
t = conv2d_layer(t, fmaps=nf(res-2), kernel=1, down=True, resample_kernel=resample_kernel)
x = (x + t) * (1 / np.sqrt(2))
return x
def downsample(y):
with tf.variable_scope('Downsample'):
return downsample_2d(y, k=resample_kernel)
# Main layers.
x = None
y = images_in
for res in range(resolution_log2, 2, -1):
with tf.variable_scope('%dx%d' % (2**res, 2**res)):
if architecture == 'skip' or res == resolution_log2:
x = fromrgb(x, y, res)
x = block(x, res)
if architecture == 'skip':
y = downsample(y)
# Final layers.
with tf.variable_scope('4x4'):
if architecture == 'skip':
x = fromrgb(x, y, 2)
if mbstd_group_size > 1:
with tf.variable_scope('MinibatchStddev'):
x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
with tf.variable_scope('Conv'):
x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
with tf.variable_scope('Dense0'):
x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)
# Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
with tf.variable_scope('Output'):
x = apply_bias_act(dense_layer(x, fmaps=max(labels_in.shape[1], 1)))
if labels_in.shape[1] > 0:
x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
scores_out = x
# Output.
assert scores_out.dtype == tf.as_dtype(dtype)
scores_out = tf.identity(scores_out, name='scores_out')
return scores_out
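# Aside (NumPy check, not part of the original file): the (x + t) * (1/sqrt(2))
# residual scaling used in the resnet blocks keeps unit variance when the two
# branches are independent and unit-variance.
def _residual_scaling_demo(n=10**6):
    x, t = np.random.default_rng(0).normal(size=(2, n))
    out = (x + t) / np.sqrt(2)
    assert abs(out.var() - 1.0) < 0.01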
#----------------------------------------------------------------------------
StyleGAN2-ADA Network Code
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Network architectures from the paper
"Training Generative Adversarial Networks with Limited Data"."""
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib
from dnnlib.tflib.ops.upfirdn_2d import upsample_2d, downsample_2d, upsample_conv_2d, conv_downsample_2d
from dnnlib.tflib.ops.fused_bias_act import fused_bias_act
# NOTE: Do not import any application-specific modules here!
# Specify all network parameters as kwargs.
#----------------------------------------------------------------------------
# Get/create weight tensor for convolution or fully-connected layer.
def get_weight(shape, gain=1, equalized_lr=True, lrmul=1, weight_var='weight', trainable=True, use_spectral_norm=False):
fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] for conv2d, [in, out] for fully-connected.
he_std = gain / np.sqrt(fan_in) # He init.
# Apply equalized learning rate from the paper
# "Progressive Growing of GANs for Improved Quality, Stability, and Variation".
if equalized_lr:
init_std = 1.0 / lrmul
runtime_coef = he_std * lrmul
else:
init_std = he_std / lrmul
runtime_coef = lrmul
# Create variable.
init = tf.initializers.random_normal(0, init_std)
w = tf.get_variable(weight_var, shape=shape, initializer=init, trainable=trainable) * runtime_coef
if use_spectral_norm:
w = apply_spectral_norm(w, state_var=weight_var+'_sn')
return w
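A quick sanity check on the equalized learning rate trick (my own NumPy sketch, not from the original file): with equalized_lr the variable is initialized at std 1/lrmul and scaled by he_std * lrmul at runtime, so the effective weights start at He std in both branches; what changes is the size of the stored variable, and therefore of its gradient steps, relative to the output.
import numpy as np
shape = [3, 3, 512, 512]                          # [kernel, kernel, fmaps_in, fmaps_out]
gain, lrmul = 1, 1
fan_in = np.prod(shape[:-1])
he_std = gain / np.sqrt(fan_in)
w_var = np.random.randn(*shape) * (1.0 / lrmul)   # init_std = 1/lrmul
w_eff = w_var * (he_std * lrmul)                  # runtime_coef
print(he_std, w_eff.std())                        # both ~0.0147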
#----------------------------------------------------------------------------
# Bias and activation function.
def apply_bias_act(x, act='linear', gain=None, lrmul=1, clamp=None, bias_var='bias', trainable=True):
b = tf.get_variable(bias_var, shape=[x.shape[1]], initializer=tf.initializers.zeros(), trainable=trainable) * lrmul
return fused_bias_act(x, b=tf.cast(b, x.dtype), act=act, gain=gain, clamp=clamp)
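fused_bias_act fuses three pointwise steps into one custom op. Conceptually it behaves like this NumPy reference for the lrelu case (a sketch that ignores the custom-op internals; alpha=0.2 and gain=sqrt(2) are the lrelu defaults I assume here):
import numpy as np
def bias_act_ref(x, b, alpha=0.2, gain=np.sqrt(2), clamp=None):
    y = x + b.reshape(1, -1, 1, 1)        # add per-channel bias (NCHW)
    y = np.where(y >= 0, y, y * alpha)    # leaky ReLU
    y = y * gain                          # post-activation gain
    return y if clamp is None else np.clip(y, -clamp, clamp)
print(bias_act_ref(np.random.randn(2, 4, 8, 8), np.zeros(4)).shape)  # (2, 4, 8, 8)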
#----------------------------------------------------------------------------
# Fully-connected layer.
def dense_layer(x, fmaps, lrmul=1, weight_var='weight', trainable=True, use_spectral_norm=False):
if len(x.shape) > 2:
x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])])
w = get_weight([x.shape[1].value, fmaps], lrmul=lrmul, weight_var=weight_var, trainable=trainable, use_spectral_norm=use_spectral_norm)
w = tf.cast(w, x.dtype)
return tf.matmul(x, w)
#----------------------------------------------------------------------------
# 2D convolution op with optional upsampling, downsampling, and padding.
def conv2d(x, w, up=False, down=False, resample_kernel=None, padding=0):
assert not (up and down)
kernel = w.shape[0].value
assert w.shape[1].value == kernel
assert kernel >= 1 and kernel % 2 == 1
w = tf.cast(w, x.dtype)
if up:
x = upsample_conv_2d(x, w, data_format='NCHW', k=resample_kernel, padding=padding)
elif down:
x = conv_downsample_2d(x, w, data_format='NCHW', k=resample_kernel, padding=padding)
else:
padding_mode = {0: 'SAME', -(kernel // 2): 'VALID'}[padding]
x = tf.nn.conv2d(x, w, data_format='NCHW', strides=[1,1,1,1], padding=padding_mode)
return x
#----------------------------------------------------------------------------
# 2D convolution layer.
def conv2d_layer(x, fmaps, kernel, up=False, down=False, resample_kernel=None, lrmul=1, trainable=True, use_spectral_norm=False):
w = get_weight([kernel, kernel, x.shape[1].value, fmaps], lrmul=lrmul, trainable=trainable, use_spectral_norm=use_spectral_norm)
return conv2d(x, tf.cast(w, x.dtype), up=up, down=down, resample_kernel=resample_kernel)
#----------------------------------------------------------------------------
# Modulated 2D convolution layer from the paper
# "Analyzing and Improving Image Quality of StyleGAN".
def modulated_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, lrmul=1, fused_modconv=False, trainable=True, use_spectral_norm=False):
assert not (up and down)
assert kernel >= 1 and kernel % 2 == 1
# Get weight.
wshape = [kernel, kernel, x.shape[1].value, fmaps]
w = get_weight(wshape, lrmul=lrmul, trainable=trainable, use_spectral_norm=use_spectral_norm)
if x.dtype.name == 'float16' and not fused_modconv and demodulate:
w *= np.sqrt(1 / np.prod(wshape[:-1])) / tf.reduce_max(tf.abs(w), axis=[0,1,2]) # Pre-normalize to avoid float16 overflow.
ww = w[np.newaxis] # [BkkIO] Introduce minibatch dimension.
# Modulate.
s = dense_layer(y, fmaps=x.shape[1].value, weight_var='mod_weight', trainable=trainable, use_spectral_norm=use_spectral_norm) # [BI] Transform incoming W to style.
s = apply_bias_act(s, bias_var='mod_bias', trainable=trainable) + 1 # [BI] Add bias (initially 1).
if x.dtype.name == 'float16' and not fused_modconv and demodulate:
s *= 1 / tf.reduce_max(tf.abs(s)) # Pre-normalize to avoid float16 overflow.
ww *= tf.cast(s[:, np.newaxis, np.newaxis, :, np.newaxis], w.dtype) # [BkkIO] Scale input feature maps.
# Demodulate.
if demodulate:
d = tf.rsqrt(tf.reduce_sum(tf.square(ww), axis=[1,2,3]) + 1e-8) # [BO] Scaling factor.
ww *= d[:, np.newaxis, np.newaxis, np.newaxis, :] # [BkkIO] Scale output feature maps.
# Reshape/scale input.
if fused_modconv:
x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]]) # Fused => reshape minibatch to convolution groups.
w = tf.reshape(tf.transpose(ww, [1, 2, 3, 0, 4]), [ww.shape[1], ww.shape[2], ww.shape[3], -1])
else:
x *= tf.cast(s[:, :, np.newaxis, np.newaxis], x.dtype) # [BIhw] Not fused => scale input activations.
# 2D convolution.
x = conv2d(x, tf.cast(w, x.dtype), up=up, down=down, resample_kernel=resample_kernel)
# Reshape/scale output.
if fused_modconv:
x = tf.reshape(x, [-1, fmaps, x.shape[2], x.shape[3]]) # Fused => reshape convolution groups back to minibatch.
elif demodulate:
x *= tf.cast(d[:, :, np.newaxis, np.newaxis], x.dtype) # [BOhw] Not fused => scale output activations.
return x
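The modulate/demodulate arithmetic is easier to see in NumPy. A minimal sketch with assumed toy shapes (not the original TF code): the styles s scale the input feature maps of a per-sample copy of the weight, and demodulation rescales each output channel of that weight back to unit L2 norm.
import numpy as np
B, k, I, O = 4, 3, 8, 16
w = np.random.randn(k, k, I, O) / np.sqrt(k * k * I)      # [kkIO] base weight
s = np.random.randn(B, I) + 1                             # [BI] styles (bias starts at 1)
ww = w[None] * s[:, None, None, :, None]                  # [BkkIO] modulate input fmaps
d = 1.0 / np.sqrt((ww ** 2).sum(axis=(1, 2, 3)) + 1e-8)   # [BO] demodulation factor
ww = ww * d[:, None, None, None, :]                       # [BkkIO] demodulate
print(np.sqrt((ww ** 2).sum(axis=(1, 2, 3))))             # ~1 for every (sample, out-channel)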
#----------------------------------------------------------------------------
# Normalize 2nd raw moment of the given activation tensor along specified axes.
def normalize_2nd_moment(x, axis=1, eps=1e-8):
return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=axis, keepdims=True) + eps)
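For example (NumPy, same formula): whatever the input scale, the mean square along the chosen axis comes back to ~1.
import numpy as np
z = np.random.randn(4, 512) * 3.7
z_norm = z / np.sqrt((z ** 2).mean(axis=1, keepdims=True) + 1e-8)
print((z_norm ** 2).mean(axis=1))   # ~[1. 1. 1. 1.]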
#----------------------------------------------------------------------------
# Minibatch standard deviation layer from the paper
# "Progressive Growing of GANs for Improved Quality, Stability, and Variation".
def minibatch_stddev_layer(x, group_size=None, num_new_features=1):
if group_size is None:
group_size = tf.shape(x)[0]
else:
group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size.
G = group_size
F = num_new_features
_N, C, H, W = x.shape.as_list()
c = C // F
y = tf.cast(x, tf.float32) # [NCHW] Cast to FP32.
y = tf.reshape(y, [G, -1, F, c, H, W]) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
y -= tf.reduce_mean(y, axis=0) # [GnFcHW] Subtract mean over group.
y = tf.reduce_mean(tf.square(y), axis=0) # [nFcHW] Calc variance over group.
y = tf.sqrt(y + 1e-8) # [nFcHW] Calc stddev over group.
y = tf.reduce_mean(y, axis=[2,3,4]) # [nF] Take average over channels and pixels.
y = tf.cast(y, x.dtype) # [nF] Cast back to original data type.
y = tf.reshape(y, [-1, F, 1, 1]) # [nF11] Add missing dimensions.
y = tf.tile(y, [G, 1, H, W]) # [NFHW] Replicate over group and pixels.
return tf.concat([x, y], axis=1) # [NCHW] Append to input as new channels.
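Shape-wise, the layer appends F extra channels holding one group-level statistic each. A NumPy sketch with assumed toy shapes, mirroring the reshapes above (not the original TF code):
import numpy as np
N, C, H, W, G, F = 8, 6, 4, 4, 4, 1
x = np.random.randn(N, C, H, W)
y = x.reshape(G, -1, F, C // F, H, W)             # [GnFcHW]
y = y - y.mean(axis=0)                            # subtract group mean
y = np.sqrt((y ** 2).mean(axis=0) + 1e-8)         # stddev over the group
y = y.mean(axis=(2, 3, 4)).reshape(-1, F, 1, 1)   # [nF11] average over c, H, W
y = np.tile(y, (G, 1, H, W))                      # [NFHW] replicate over group and pixels
print(np.concatenate([x, y], axis=1).shape)       # (8, 7, 4, 4): C grew by F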
#----------------------------------------------------------------------------
# Spectral normalization from the paper
# "Spectral Normalization for Generative Adversarial Networks".
def apply_spectral_norm(w, state_var='sn', iterations=1, eps=1e-8):
fmaps = w.shape[-1].value
w_mat = tf.reshape(w, [-1, fmaps])
u_var = tf.get_variable(state_var, shape=[1,fmaps], initializer=tf.initializers.random_normal(), trainable=False)
u = u_var
for _ in range(iterations):
v = tf.matmul(u, w_mat, transpose_b=True)
v *= tf.rsqrt(tf.reduce_sum(tf.square(v)) + eps)
u = tf.matmul(v, w_mat)
sigma_inv = tf.rsqrt(tf.reduce_sum(tf.square(u)) + eps)
u *= sigma_inv
with tf.control_dependencies([tf.assign(u_var, u)]):
return w * sigma_inv
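The update above is one step of power iteration. Run for more steps (a NumPy sketch with shapes of my choosing), the estimate converges to the true largest singular value, which is exactly what dividing by sigma normalizes away:
import numpy as np
w = np.random.randn(3 * 3 * 64, 128)       # weight reshaped to [-1, fmaps]
u = np.random.randn(1, 128)
for _ in range(50):                        # the layer itself uses iterations=1 per call
    v = u @ w.T
    v /= np.sqrt((v ** 2).sum() + 1e-8)
    u = v @ w
sigma = np.sqrt((u ** 2).sum())
print(sigma, np.linalg.svd(w, compute_uv=False)[0])   # nearly equal
w_sn = w / sigma                           # spectral norm of w_sn is ~1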
#----------------------------------------------------------------------------
# Main generator network.
# Composed of two sub-networks (mapping and synthesis) that are defined below.
def G_main(
latents_in, # First input: Latent vectors (Z) [minibatch, latent_size].
labels_in, # Second input: Conditioning labels [minibatch, label_size].
# Evaluation mode.
is_training = False, # Network is under training? Enables and disables specific features.
is_validation = False, # Network is under validation? Chooses which value to use for truncation_psi.
return_dlatents = False, # Return dlatents (W) in addition to the images?
# Truncation & style mixing.
truncation_psi = 0.5, # Style strength multiplier for the truncation trick. None = disable.
truncation_cutoff = None, # Number of layers for which to apply the truncation trick. None = disable.
truncation_psi_val = None, # Value for truncation_psi to use during validation.
truncation_cutoff_val = None, # Value for truncation_cutoff to use during validation.
dlatent_avg_beta = 0.995, # Decay for tracking the moving average of W during training. None = disable.
style_mixing_prob = 0.9, # Probability of mixing styles during training. None = disable.
# Sub-networks.
components = dnnlib.EasyDict(), # Container for sub-networks. Retained between calls.
mapping_func = 'G_mapping', # Build func name for the mapping network.
synthesis_func = 'G_synthesis', # Build func name for the synthesis network.
is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation.
**kwargs, # Arguments for sub-networks (mapping and synthesis).
):
# Validate arguments.
assert not is_training or not is_validation
assert isinstance(components, dnnlib.EasyDict)
if is_validation:
truncation_psi = truncation_psi_val
truncation_cutoff = truncation_cutoff_val
if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
truncation_psi = None
if is_training:
truncation_cutoff = None
if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
dlatent_avg_beta = None
if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
style_mixing_prob = None
# Setup components.
if 'synthesis' not in components:
components.synthesis = tflib.Network('G_synthesis', func_name=globals()[synthesis_func], **kwargs)
num_layers = components.synthesis.input_shape[1]
dlatent_size = components.synthesis.input_shape[2]
if 'mapping' not in components:
components.mapping = tflib.Network('G_mapping', func_name=globals()[mapping_func], dlatent_broadcast=num_layers, **kwargs)
# Evaluate mapping network.
dlatents = components.mapping.get_output_for(latents_in, labels_in, is_training=is_training, **kwargs)
dlatents = tf.cast(dlatents, tf.float32)
# Update moving average of W.
dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)
if dlatent_avg_beta is not None:
with tf.variable_scope('DlatentAvg'):
batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
with tf.control_dependencies([update_op]):
dlatents = tf.identity(dlatents)
# Perform style mixing regularization.
if style_mixing_prob is not None:
with tf.variable_scope('StyleMix'):
latents2 = tf.random_normal(tf.shape(latents_in))
dlatents2 = components.mapping.get_output_for(latents2, labels_in, is_training=is_training, **kwargs)
dlatents2 = tf.cast(dlatents2, tf.float32)
layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
mixing_cutoff = tf.cond(
tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
lambda: tf.random_uniform([], 1, num_layers, dtype=tf.int32),
lambda: num_layers)
dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)
# Apply truncation.
if truncation_psi is not None:
with tf.variable_scope('Truncation'):
layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
layer_psi = np.ones(layer_idx.shape, dtype=np.float32)
if truncation_cutoff is None:
layer_psi *= truncation_psi
else:
layer_psi = tf.where(layer_idx < truncation_cutoff, layer_psi * truncation_psi, layer_psi)
dlatents = tflib.lerp(dlatent_avg, dlatents, layer_psi)
# Evaluate synthesis network.
images_out = components.synthesis.get_output_for(dlatents, is_training=is_training, force_clean_graph=is_template_graph, **kwargs)
images_out = tf.identity(images_out, name='images_out')
if return_dlatents:
return images_out, dlatents
return images_out
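The truncation step is a per-layer lerp toward the tracked average W (tflib.lerp(a, b, t) = a + (b - a) * t). A NumPy sketch with assumed sizes: psi = 1 is a no-op, psi = 0 collapses every truncated layer to the average, and truncation_cutoff limits the effect to the first layers.
import numpy as np
num_layers, dlatent_size, psi, cutoff = 18, 512, 0.5, 8
w = np.random.randn(2, num_layers, dlatent_size)
w_avg = np.zeros(dlatent_size)                       # stands in for dlatent_avg
layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
layer_psi = np.where(layer_idx < cutoff, psi, 1.0)
w_trunc = w_avg + (w - w_avg) * layer_psi            # lerp(w_avg, w, layer_psi)
print(abs(w_trunc[:, :cutoff]).mean(), abs(w_trunc[:, cutoff:]).mean())  # ~0.4 vs ~0.8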
#----------------------------------------------------------------------------
# Generator mapping network.
def G_mapping(
latents_in, # First input: Latent vectors (Z) [minibatch, latent_size].
labels_in, # Second input: Conditioning labels [minibatch, label_size].
# Input & output dimensions.
latent_size = 512, # Latent vector (Z) dimensionality.
label_size = 0, # Label dimensionality, 0 if no labels.
dlatent_size = 512, # Disentangled latent (W) dimensionality.
dlatent_broadcast = None, # Output disentangled latent (W) as [minibatch, dlatent_size] or [minibatch, dlatent_broadcast, dlatent_size].
# Internal details.
mapping_layers = 8, # Number of mapping layers.
mapping_fmaps = None, # Number of activations in the mapping layers, None = same as dlatent_size.
mapping_lrmul = 0.01, # Learning rate multiplier for the mapping layers.
mapping_nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
normalize_latents = True, # Normalize latent vectors (Z) before feeding them to the mapping layers?
label_fmaps = None, # Label embedding dimensionality, None = same as latent_size.
dtype = 'float32', # Data type to use for intermediate activations and outputs.
**_kwargs, # Ignore unrecognized keyword args.
):
# Inputs.
latents_in.set_shape([None, latent_size])
labels_in.set_shape([None, label_size])
latents_in = tf.cast(latents_in, dtype)
labels_in = tf.cast(labels_in, dtype)
x = latents_in
# Normalize latents.
if normalize_latents:
with tf.variable_scope('Normalize'):
x = normalize_2nd_moment(x)
# Embed labels, normalize, and concatenate with latents.
if label_size > 0:
with tf.variable_scope('LabelEmbed'):
fmaps = label_fmaps if label_fmaps is not None else latent_size
y = labels_in
y = apply_bias_act(dense_layer(y, fmaps=fmaps))
y = normalize_2nd_moment(y)
x = tf.concat([x, y], axis=1)
# Mapping layers.
for layer_idx in range(mapping_layers):
with tf.variable_scope(f'Dense{layer_idx}'):
fmaps = mapping_fmaps if mapping_fmaps is not None and layer_idx < mapping_layers - 1 else dlatent_size
x = apply_bias_act(dense_layer(x, fmaps=fmaps, lrmul=mapping_lrmul), act=mapping_nonlinearity, lrmul=mapping_lrmul)
# Broadcast.
if dlatent_broadcast is not None:
with tf.variable_scope('Broadcast'):
x = tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1])
# Output.
assert x.dtype == tf.as_dtype(dtype)
return tf.identity(x, name='dlatents_out')
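The broadcast at the end simply repeats the single w per sample once per synthesis layer, e.g. (NumPy):
import numpy as np
w = np.random.randn(4, 512)                     # [minibatch, dlatent_size]
w_bc = np.tile(w[:, np.newaxis], (1, 18, 1))    # 18 layers for 1024x1024 output
print(w_bc.shape)                               # (4, 18, 512)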
#----------------------------------------------------------------------------
# Generator synthesis network.
def G_synthesis(
dlatents_in, # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
# Input & output dimensions.
dlatent_size = 512, # Disentangled latent (W) dimensionality.
num_channels = 3, # Number of output color channels.
resolution = 1024, # Output resolution.
# Capacity.
fmap_base = 16384, # Overall multiplier for the number of feature maps.
fmap_decay = 1, # Log2 feature map reduction when doubling the resolution.
fmap_min = 1, # Minimum number of feature maps in any layer.
fmap_max = 512, # Maximum number of feature maps in any layer.
fmap_const = None, # Number of feature maps in the constant input layer. None = default.
# Internal details.
use_noise = True, # Enable noise inputs?
randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
dtype = 'float32', # Data type to use for intermediate activations and outputs.
num_fp16_res = 0, # Use FP16 for the N highest resolutions, regardless of dtype.
conv_clamp = None, # Clamp the output of convolution layers to [-conv_clamp, +conv_clamp], None = disable clamping.
resample_kernel = [1,3,3,1], # Low-pass filter to apply when resampling activations, None = box filter.
fused_modconv = False, # Implement modulated_conv2d_layer() using grouped convolution?
**_kwargs, # Ignore unrecognized keyword args.
):
resolution_log2 = int(np.log2(resolution))
assert resolution == 2**resolution_log2 and resolution >= 4
def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
assert architecture in ['orig', 'skip', 'resnet']
act = nonlinearity
num_layers = resolution_log2 * 2 - 2
# Disentangled latent (W).
dlatents_in.set_shape([None, num_layers, dlatent_size])
dlatents_in = tf.cast(dlatents_in, dtype)
# Noise inputs.
noise_inputs = []
if use_noise:
for layer_idx in range(num_layers - 1):
res = (layer_idx + 5) // 2
shape = [1, 1, 2**res, 2**res]
noise_inputs.append(tf.get_variable(f'noise{layer_idx}', shape=shape, initializer=tf.initializers.random_normal(), trainable=False))
# Single convolution layer with all the bells and whistles.
def layer(x, layer_idx, fmaps, kernel, up=False):
x = modulated_conv2d_layer(x, dlatents_in[:, layer_idx], fmaps=fmaps, kernel=kernel, up=up, resample_kernel=resample_kernel, fused_modconv=fused_modconv)
if use_noise:
if randomize_noise:
noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
else:
noise = tf.cast(noise_inputs[layer_idx], x.dtype)
noise_strength = tf.get_variable('noise_strength', shape=[], initializer=tf.initializers.zeros())
x += noise * tf.cast(noise_strength, x.dtype)
return apply_bias_act(x, act=act, clamp=conv_clamp)
# Main block for one resolution.
def block(x, res): # res = 3..resolution_log2
x = tf.cast(x, 'float16' if res > resolution_log2 - num_fp16_res else dtype)
t = x
with tf.variable_scope('Conv0_up'):
x = layer(x, layer_idx=res*2-5, fmaps=nf(res-1), kernel=3, up=True)
with tf.variable_scope('Conv1'):
x = layer(x, layer_idx=res*2-4, fmaps=nf(res-1), kernel=3)
if architecture == 'resnet':
with tf.variable_scope('Skip'):
t = conv2d_layer(t, fmaps=nf(res-1), kernel=1, up=True, resample_kernel=resample_kernel)
x = (x + t) * (1 / np.sqrt(2))
return x
# Upsampling block.
def upsample(y):
with tf.variable_scope('Upsample'):
return upsample_2d(y, k=resample_kernel)
# ToRGB block.
def torgb(x, y, res): # res = 2..resolution_log2
with tf.variable_scope('ToRGB'):
t = modulated_conv2d_layer(x, dlatents_in[:, res*2-3], fmaps=num_channels, kernel=1, demodulate=False, fused_modconv=fused_modconv)
t = apply_bias_act(t, clamp=conv_clamp)
t = tf.cast(t, dtype)
if y is not None:
t += tf.cast(y, t.dtype)
return t
# Layers for 4x4 resolution.
y = None
with tf.variable_scope('4x4'):
with tf.variable_scope('Const'):
fmaps = fmap_const if fmap_const is not None else nf(1)
x = tf.get_variable('const', shape=[1, fmaps, 4, 4], initializer=tf.initializers.random_normal())
x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
with tf.variable_scope('Conv'):
x = layer(x, layer_idx=0, fmaps=nf(1), kernel=3)
if architecture == 'skip':
y = torgb(x, y, 2)
# Layers for >=8x8 resolutions.
for res in range(3, resolution_log2 + 1):
with tf.variable_scope(f'{2**res}x{2**res}'):
x = block(x, res)
if architecture == 'skip':
y = upsample(y)
if architecture == 'skip' or res == resolution_log2:
y = torgb(x, y, res)
images_out = y
assert images_out.dtype == tf.as_dtype(dtype)
return tf.identity(images_out, name='images_out')
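To see how layer indices line up with resolutions: num_layers = 2*log2(resolution) - 2, and the noise input for layer_idx attaches at resolution 2**((layer_idx + 5) // 2). A small standalone check:
resolution_log2 = 10                        # 1024x1024
num_layers = resolution_log2 * 2 - 2        # 18
for layer_idx in range(num_layers - 1):     # 17 noise inputs
    res = (layer_idx + 5) // 2
    print(layer_idx, f'{2**res}x{2**res}')
# layer 0 -> 4x4 (the Conv after Const), layers 1-2 -> 8x8, ..., 15-16 -> 1024x1024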
#----------------------------------------------------------------------------
# Discriminator.
def D_main(
images_in, # First input: Images [minibatch, channel, height, width].
labels_in, # Second input: Conditioning labels [minibatch, label_size].
# Input dimensions.
num_channels = 3, # Number of input color channels. Overridden based on dataset.
resolution = 1024, # Input resolution. Overridden based on dataset.
label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
# Capacity.
fmap_base = 16384, # Overall multiplier for the number of feature maps.
fmap_decay = 1, # Log2 feature map reduction when doubling the resolution.
fmap_min = 1, # Minimum number of feature maps in any layer.
fmap_max = 512, # Maximum number of feature maps in any layer.
# Internal details.
mapping_layers = 0, # Number of additional mapping layers for the conditioning labels.
mapping_fmaps = None, # Number of activations in the mapping layers, None = default.
mapping_lrmul = 0.1, # Learning rate multiplier for the mapping layers.
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
mbstd_group_size = None, # Group size for the minibatch standard deviation layer, None = entire minibatch.
mbstd_num_features = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
dtype = 'float32', # Data type to use for intermediate activations and outputs.
num_fp16_res = 0, # Use FP16 for the N highest resolutions, regardless of dtype.
conv_clamp = None, # Clamp the output of convolution layers to [-conv_clamp, +conv_clamp], None = disable clamping.
resample_kernel = [1,3,3,1], # Low-pass filter to apply when resampling activations, None = box filter.
# Comparison methods.
augment_strength = 0, # AdaptiveAugment.get_strength_var() for pagan & adropout.
use_pagan = False, # pagan: Enable?
pagan_num = 16, # pagan: Number of active bits with augment_strength=1.
pagan_fade = 0.5, # pagan: Relative duration of fading in new bits.
score_size = 1, # auxrot: Number of scalars to output. Can vary between evaluations.
score_max = 1, # auxrot: Maximum number of scalars to output. Must be set at construction time.
use_spectral_norm = False, # spectralnorm: Enable?
adaptive_dropout = 0, # adropout: Standard deviation to use with augment_strength=1, 0 = disable.
freeze_layers = 0, # Freeze-D: Number of layers to freeze.
**_kwargs, # Ignore unrecognized keyword args.
):
resolution_log2 = int(np.log2(resolution))
assert resolution == 2**resolution_log2 and resolution >= 4
def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
assert architecture in ['orig', 'skip', 'resnet']
if mapping_fmaps is None:
mapping_fmaps = nf(0)
act = nonlinearity
# Inputs.
images_in.set_shape([None, num_channels, resolution, resolution])
labels_in.set_shape([None, label_size])
images_in = tf.cast(images_in, dtype)
labels_in = tf.cast(labels_in, dtype)
# Label embedding and mapping.
if label_size > 0:
y = labels_in
with tf.variable_scope('LabelEmbed'):
y = apply_bias_act(dense_layer(y, fmaps=mapping_fmaps))
y = normalize_2nd_moment(y)
for idx in range(mapping_layers):
with tf.variable_scope(f'Mapping{idx}'):
y = apply_bias_act(dense_layer(y, fmaps=mapping_fmaps, lrmul=mapping_lrmul), act=act, lrmul=mapping_lrmul)
labels_in = y
# Adaptive multiplicative dropout.
def adrop(x):
if adaptive_dropout != 0:
s = [tf.shape(x)[0], x.shape[1]] + [1] * (x.shape.rank - 2)
x *= tf.cast(tf.exp(tf.random_normal(s) * (augment_strength * adaptive_dropout)), x.dtype)
return x
# Freeze-D.
cur_layer_idx = 0
def is_next_layer_trainable():
nonlocal cur_layer_idx
trainable = (cur_layer_idx >= freeze_layers)
cur_layer_idx += 1
return trainable
# Construct PA-GAN bit vector.
pagan_bits = None
pagan_signs = None
if use_pagan:
with tf.variable_scope('PAGAN'):
idx = tf.range(pagan_num, dtype=tf.float32)
active = (augment_strength * pagan_num - idx - 1) / max(pagan_fade, 1e-8) + 1
prob = tf.clip_by_value(active[np.newaxis, :], 0, 1) * 0.5
rnd = tf.random_uniform([tf.shape(images_in)[0], pagan_num])
pagan_bits = tf.cast(rnd < prob, dtype=tf.float32)
pagan_signs = tf.reduce_prod(1 - pagan_bits * 2, axis=1, keepdims=True)
# FromRGB block.
def fromrgb(x, y, res): # res = 2..resolution_log2
with tf.variable_scope('FromRGB'):
trainable = is_next_layer_trainable()
t = tf.cast(y, 'float16' if res > resolution_log2 - num_fp16_res else dtype)
t = adrop(conv2d_layer(t, fmaps=nf(res-1), kernel=1, trainable=trainable))
if pagan_bits is not None:
with tf.variable_scope('PAGAN'):
t += dense_layer(tf.cast(pagan_bits, t.dtype), fmaps=nf(res-1), trainable=trainable)[:, :, np.newaxis, np.newaxis]
t = apply_bias_act(t, act=act, clamp=conv_clamp, trainable=trainable)
if x is not None:
t += tf.cast(x, t.dtype)
return t
# Main block for one resolution.
def block(x, res): # res = 2..resolution_log2
x = tf.cast(x, 'float16' if res > resolution_log2 - num_fp16_res else dtype)
t = x
with tf.variable_scope('Conv0'):
trainable = is_next_layer_trainable()
x = apply_bias_act(adrop(conv2d_layer(x, fmaps=nf(res-1), kernel=3, trainable=trainable, use_spectral_norm=use_spectral_norm)), act=act, clamp=conv_clamp, trainable=trainable)
with tf.variable_scope('Conv1_down'):
trainable = is_next_layer_trainable()
x = apply_bias_act(adrop(conv2d_layer(x, fmaps=nf(res-2), kernel=3, down=True, resample_kernel=resample_kernel, trainable=trainable, use_spectral_norm=use_spectral_norm)), act=act, clamp=conv_clamp, trainable=trainable)
if architecture == 'resnet':
with tf.variable_scope('Skip'):
trainable = is_next_layer_trainable()
t = adrop(conv2d_layer(t, fmaps=nf(res-2), kernel=1, down=True, resample_kernel=resample_kernel, trainable=trainable))
x = (x + t) * (1 / np.sqrt(2))
return x
# Downsampling block.
def downsample(y):
with tf.variable_scope('Downsample'):
return downsample_2d(y, k=resample_kernel)
# Layers for >=8x8 resolutions.
x = None
y = images_in
for res in range(resolution_log2, 2, -1):
with tf.variable_scope(f'{2**res}x{2**res}'):
if architecture == 'skip' or res == resolution_log2:
x = fromrgb(x, y, res)
x = block(x, res)
if architecture == 'skip':
y = downsample(y)
# Layers for 4x4 resolution.
with tf.variable_scope('4x4'):
if architecture == 'skip':
x = fromrgb(x, y, 2)
x = tf.cast(x, dtype)
if mbstd_num_features > 0:
with tf.variable_scope('MinibatchStddev'):
x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
with tf.variable_scope('Conv'):
trainable = is_next_layer_trainable()
x = apply_bias_act(adrop(conv2d_layer(x, fmaps=nf(1), kernel=3, trainable=trainable, use_spectral_norm=use_spectral_norm)), act=act, clamp=conv_clamp, trainable=trainable)
with tf.variable_scope('Dense0'):
trainable = is_next_layer_trainable()
x = apply_bias_act(adrop(dense_layer(x, fmaps=nf(0), trainable=trainable)), act=act, trainable=trainable)
# Output layer (always trainable).
with tf.variable_scope('Output'):
if label_size > 0:
assert score_max == 1
x = apply_bias_act(dense_layer(x, fmaps=mapping_fmaps))
x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True) / np.sqrt(mapping_fmaps)
else:
x = apply_bias_act(dense_layer(x, fmaps=score_max))
if pagan_signs is not None:
assert score_max == 1
x *= pagan_signs
scores_out = x[:, :score_size]
# Output.
assert scores_out.dtype == tf.as_dtype(dtype)
scores_out = tf.identity(scores_out, name='scores_out')
return scores_out
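Note how Freeze-D works here: the discriminator is built from the highest resolution down, so freeze_layers = N marks the first N FromRGB/Conv/Skip layers it encounters (the ones nearest the input image) as non-trainable. A plain-Python sketch of the counter:
freeze_layers = 3
cur_layer_idx = 0
def is_next_layer_trainable():
    global cur_layer_idx
    trainable = cur_layer_idx >= freeze_layers
    cur_layer_idx += 1
    return trainable
print([is_next_layer_trainable() for _ in range(6)])  # [False, False, False, True, True, True]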
#----------------------------------------------------------------------------
Source: github.com/NVlabs/stylegan2-ada