Skip to content

Commit 16b2c0b

Browse files
author
Andrey Voynov
committed
added generative models-2 seminar
1 parent 2f47b5a commit 16b2c0b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+6752
-0
lines changed

seminar07-gen_models_2/gans/__init__.py

Whitespace-only changes.
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import numpy as np
2+
import torch
3+
from torch import nn
4+
5+
from gans.models.BigGAN import BigGAN
6+
from gans.models.StyleGAN2.model import Generator as StyleGAN2Generator
7+
8+
9+
class ConditionedBigGAN(nn.Module):
10+
def __init__(self, big_gan, target_classes=(239)):
11+
super(ConditionedBigGAN, self).__init__()
12+
self.big_gan = big_gan
13+
14+
self.set_classes(target_classes)
15+
self.dim_z = self.big_gan.dim_z
16+
17+
def set_classes(self, target_classes):
18+
self.target_classes = nn.Parameter(torch.tensor(target_classes, dtype=torch.int64),
19+
requires_grad=False)
20+
21+
def mixed_classes(self, batch_size):
22+
if len(self.target_classes.data.shape) == 0:
23+
return self.target_classes.repeat(batch_size).cuda()
24+
else:
25+
return torch.from_numpy(
26+
np.random.choice(self.target_classes.cpu(), [batch_size])).cuda()
27+
28+
def forward(self, z, classes=None):
29+
if classes is None:
30+
classes = self.mixed_classes(z.shape[0]).to(z.device)
31+
return self.big_gan(z, self.big_gan.shared(classes))
32+
33+
34+
class StyleGAN2Wrapper(nn.Module):
35+
def __init__(self, g, shift_in_w):
36+
super(StyleGAN2Wrapper, self).__init__()
37+
self.style_gan2 = g
38+
self.dim_z = 512
39+
self.dim_shift = self.style_gan2.style_dim if shift_in_w else self.dim_z
40+
self.shift_in_w = shift_in_w
41+
42+
def forward(self, input, w_space=False, noise=None):
43+
if not isinstance(input, list):
44+
input = [input]
45+
return self.style_gan2(input, input_is_latent=w_space, noise=noise)[0]
46+
47+
48+
def make_biggan_config(resolution):
49+
attn_dict = {128: '64', 256: '128', 512: '64'}
50+
dim_z_dict = {128: 120, 256: 140, 512: 128}
51+
config = {
52+
'G_param': 'SN', 'D_param': 'SN',
53+
'G_ch': 96, 'D_ch': 96,
54+
'D_wide': True, 'G_shared': True,
55+
'shared_dim': 128, 'dim_z': dim_z_dict[resolution],
56+
'hier': True, 'cross_replica': False,
57+
'mybn': False, 'G_activation': nn.ReLU(inplace=True),
58+
'G_attn': attn_dict[resolution],
59+
'norm_style': 'bn',
60+
'G_init': 'ortho', 'skip_init': True, 'no_optim': True,
61+
'G_fp16': False, 'G_mixed_precision': False,
62+
'accumulate_stats': False, 'num_standing_accumulations': 16,
63+
'G_eval_mode': True,
64+
'BN_eps': 1e-04, 'SN_eps': 1e-04,
65+
'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution,
66+
'n_classes': 1000}
67+
return config
68+
69+
70+
def make_big_gan(weights, target_classes=None, resolution=128, n_classes=1000):
71+
config = make_biggan_config(resolution)
72+
config['n_classes'] = n_classes
73+
G = BigGAN.Generator(**config)
74+
G.load_state_dict(torch.load(weights, map_location=torch.device('cpu')), strict=False)
75+
76+
if target_classes is None:
77+
target_classes = np.arange(0, n_classes, 1)
78+
return ConditionedBigGAN(G, target_classes).cuda().eval()
79+
80+
81+
def make_stylegan2(resolution, weights, shift_in_w=True, target_key='g_ema', g_kwargs={}):
82+
G = StyleGAN2Generator(resolution, 512, 8, **g_kwargs)
83+
G.load_state_dict(torch.load(weights)[target_key] if target_key is not None else \
84+
torch.load(weights))
85+
G.cuda().eval()
86+
87+
return StyleGAN2Wrapper(G, shift_in_w=shift_in_w)

0 commit comments

Comments
 (0)