import numpy as np
import torch
from torch import nn

from gans.models.BigGAN import BigGAN
from gans.models.StyleGAN2.model import Generator as StyleGAN2Generator

class ConditionedBigGAN(nn.Module):
    """BigGAN generator restricted to a fixed pool of target class labels."""

    def __init__(self, big_gan, target_classes=(239,)):
        super(ConditionedBigGAN, self).__init__()
        self.big_gan = big_gan

        self.set_classes(target_classes)
        self.dim_z = self.big_gan.dim_z

    def set_classes(self, target_classes):
        # Kept as a frozen nn.Parameter so the labels move with the module
        # (e.g. across .cuda() calls and state_dict round-trips).
        self.target_classes = nn.Parameter(torch.tensor(target_classes, dtype=torch.int64),
                                           requires_grad=False)

    def mixed_classes(self, batch_size):
        # A single (0-dim) class is repeated; several classes are sampled uniformly.
        if self.target_classes.dim() == 0:
            return self.target_classes.repeat(batch_size).cuda()
        else:
            return torch.from_numpy(
                np.random.choice(self.target_classes.cpu().numpy(), [batch_size])).cuda()

    def forward(self, z, classes=None):
        if classes is None:
            classes = self.mixed_classes(z.shape[0]).to(z.device)
        # BigGAN consumes the shared class embedding, not the raw label.
        return self.big_gan(z, self.big_gan.shared(classes))

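# Usage sketch (illustrative, not part of the original module): given a
# ConditionedBigGAN instance `G` on a CUDA device,
#
#   z = torch.randn(8, G.dim_z).cuda()
#   images = G(z)  # labels drawn uniformly from G.target_classes
#
# passing `classes` explicitly bypasses the mixed sampling.
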
class StyleGAN2Wrapper(nn.Module):
    """Makes the StyleGAN2 generator callable as G(z) or G(w, w_space=True)."""

    def __init__(self, g, shift_in_w):
        super(StyleGAN2Wrapper, self).__init__()
        self.style_gan2 = g
        self.dim_z = 512
        # Latent shifts are applied either in w-space or directly in z-space.
        self.dim_shift = self.style_gan2.style_dim if shift_in_w else self.dim_z
        self.shift_in_w = shift_in_w

    def forward(self, input, w_space=False, noise=None):
        if not isinstance(input, list):
            input = [input]
        # The underlying generator returns (image, latents); keep only the image.
        return self.style_gan2(input, input_is_latent=w_space, noise=noise)[0]

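# Usage sketch (illustrative; assumes `wrapper` built via make_stylegan2 below
# and that the bundled generator exposes rosinality's get_latent helper):
#
#   z = torch.randn(1, wrapper.dim_z).cuda()
#   img = wrapper(z)                          # z-space input
#   w = wrapper.style_gan2.get_latent(z)      # map z through the style MLP
#   img = wrapper(w, w_space=True)            # w-space input
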
def make_biggan_config(resolution):
    """Generator kwargs for the pretrained BigGAN at a given resolution."""
    attn_dict = {128: '64', 256: '128', 512: '64'}
    dim_z_dict = {128: 120, 256: 140, 512: 128}
    config = {
        'G_param': 'SN', 'D_param': 'SN',
        'G_ch': 96, 'D_ch': 96,
        'D_wide': True, 'G_shared': True,
        'shared_dim': 128, 'dim_z': dim_z_dict[resolution],
        'hier': True, 'cross_replica': False,
        'mybn': False, 'G_activation': nn.ReLU(inplace=True),
        'G_attn': attn_dict[resolution],
        'norm_style': 'bn',
        'G_init': 'ortho', 'skip_init': True, 'no_optim': True,
        'G_fp16': False, 'G_mixed_precision': False,
        'accumulate_stats': False, 'num_standing_accumulations': 16,
        'G_eval_mode': True,
        'BN_eps': 1e-04, 'SN_eps': 1e-04,
        'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution,
        'n_classes': 1000}
    return config

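# For reference (sketch): make_biggan_config(128) yields dim_z == 120 and
# G_attn == '64'; only the 128/256/512 resolutions are covered by the dicts above.
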
def make_big_gan(weights, target_classes=None, resolution=128, n_classes=1000):
    config = make_biggan_config(resolution)
    config['n_classes'] = n_classes
    G = BigGAN.Generator(**config)
    # strict=False tolerates checkpoints whose keys do not match exactly.
    G.load_state_dict(torch.load(weights, map_location=torch.device('cpu')), strict=False)

    if target_classes is None:
        # Default to the full label set when no subset is requested.
        target_classes = np.arange(0, n_classes, 1)
    return ConditionedBigGAN(G, target_classes).cuda().eval()

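# Usage sketch (illustrative; the checkpoint path is hypothetical):
#
#   G = make_big_gan('weights/G_ema.pth', target_classes=[239], resolution=128)
#   z = torch.randn(4, G.dim_z).cuda()
#   with torch.no_grad():
#       images = G(z)  # (4, 3, 128, 128), values roughly in [-1, 1]
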
def make_stylegan2(resolution, weights, shift_in_w=True, target_key='g_ema', g_kwargs=None):
    # Avoid a mutable default argument for the extra generator kwargs.
    g_kwargs = {} if g_kwargs is None else g_kwargs
    G = StyleGAN2Generator(resolution, 512, 8, **g_kwargs)
    state_dict = torch.load(weights)
    G.load_state_dict(state_dict[target_key] if target_key is not None else state_dict)
    G.cuda().eval()

    return StyleGAN2Wrapper(G, shift_in_w=shift_in_w)
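# Usage sketch (illustrative; the checkpoint path is hypothetical):
#
#   wrapper = make_stylegan2(1024, 'weights/stylegan2-ffhq.pt')
#   z = torch.randn(2, wrapper.dim_z).cuda()
#   with torch.no_grad():
#       faces = wrapper(z)  # (2, 3, 1024, 1024)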