|
4 | 4 |
|
5 | 5 | ## Users can get the diverse models of pytorch_gan_zoo by calling |
6 | 6 | hub_model = hub.load( |
7 | | - '??/pytorch_gan_zoo:master', |
| 7 | + 'facebookresearch/pytorch_gan_zoo:master', |
8 | 8 | $MODEL_NAME, # |
9 | 9 | config = None, |
10 | 10 | useGPU = True, |
11 | 11 | pretrained=False) # (Not pretrained models online yet) |
12 | 12 |
|
13 | | -Available model'names are [DCGAN, PGAN]. |
| 13 | +Available model'names are [DCGAN, PGAN, StyleGAN]. |
14 | 14 | The config option should be a dictionnary defining the training parameters of |
15 | 15 | the model. See ??/pytorch_gan_zoo/models/trainer/standard_configurations to see |
16 | 16 | all possible options |
@@ -99,6 +99,28 @@ def PGAN(pretrained=False, *args, **kwargs): |
99 | 99 | return model |
100 | 100 |
|
101 | 101 |
|
| 102 | +def StyleGAN(pretrained=False, *args, **kwargs): |
| 103 | + """ |
| 104 | + NVIDIA StyleGAN |
| 105 | + pretrained (bool): load a 1024x1024 model trained on FlickrHQ |
| 106 | + """ |
| 107 | + from models.styleGAN import StyleGAN |
| 108 | + if 'config' not in kwargs or kwargs['config'] is None: |
| 109 | + kwargs['config'] = {} |
| 110 | + |
| 111 | + model = StyleGAN(useGPU=kwargs.get('useGPU', True), |
| 112 | + storeAVG=True, |
| 113 | + **kwargs['config']) |
| 114 | + |
| 115 | + checkpoint = 'https://dl.fbaipublicfiles.com/gan_zoo/StyleGAN/FFHQ_styleGAN-7cbdec00.pth' |
| 116 | + if pretrained: |
| 117 | + print("Loading default model : Flickr-HQ") |
| 118 | + state_dict = model_zoo.load_url(checkpoint, |
| 119 | + map_location='cpu') |
| 120 | + model.load_state_dict(state_dict) |
| 121 | + return model |
| 122 | + |
| 123 | + |
102 | 124 | def DCGAN(pretrained=False, *args, **kwargs): |
103 | 125 | """ |
104 | 126 | DCGAN basic model |
|
0 commit comments