diff --git a/examples/gan/dcgan.py b/examples/gan/dcgan.py
index 03669add630..0980f4dc085 100644
--- a/examples/gan/dcgan.py
+++ b/examples/gan/dcgan.py
@@ -62,7 +62,7 @@ class Generator(Net):
         nf (int): Number of filters in the second-to-last deconv layer
     """
 
-    def __init__(self, z_dim, nf):
+    def __init__(self, z_dim, nf, nc):
         super(Generator, self).__init__()
 
         self.net = nn.Sequential(
@@ -88,7 +88,7 @@ def __init__(self, z_dim, nf):
             nn.ReLU(inplace=True),
 
             # state size. (nf) x 32 x 32
-            nn.ConvTranspose2d(in_channels=nf, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.ConvTranspose2d(in_channels=nf, out_channels=nc, kernel_size=4, stride=2, padding=1, bias=False),
             nn.Tanh()
 
             # state size. (nc) x 64 x 64
@@ -107,13 +107,13 @@ class Discriminator(Net):
         nf (int): Number of filters in the first conv layer.
     """
 
-    def __init__(self, nf):
+    def __init__(self, nc, nf):
         super(Discriminator, self).__init__()
 
         self.net = nn.Sequential(
             # input is (nc) x 64 x 64
-            nn.Conv2d(in_channels=3, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.Conv2d(in_channels=nc, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
             nn.LeakyReLU(0.2, inplace=True),
 
             # state size. (nf) x 32 x 32
@@ -166,7 +166,6 @@ def check_dataset(dataset, dataroot):
         dataset (data.Dataset): torchvision Dataset object
     """
 
-    to_rgb = transforms.Lambda(lambda img: img.convert('RGB'))
     resize = transforms.Resize(64)
     crop = transforms.CenterCrop(64)
     to_tensor = transforms.ToTensor()
@@ -177,31 +176,35 @@ def check_dataset(dataset, dataroot):
                                                                            crop,
                                                                            to_tensor,
                                                                            normalize]))
+        nc = 3
 
     elif dataset == 'lsun':
         dataset = dset.LSUN(root=dataroot, classes=['bedroom_train'],
                             transform=transforms.Compose([resize,
                                                           crop,
                                                           to_tensor,
                                                           normalize]))
+        nc = 3
 
     elif dataset == 'cifar10':
         dataset = dset.CIFAR10(root=dataroot, download=True,
                                transform=transforms.Compose([resize,
                                                              to_tensor,
                                                              normalize]))
+        nc = 3
 
     elif dataset == 'mnist':
-        dataset = dset.MNIST(root=dataroot, download=True, transform=transforms.Compose([to_rgb,
-                                                                                         resize,
+        dataset = dset.MNIST(root=dataroot, download=True, transform=transforms.Compose([resize,
                                                                                          to_tensor,
                                                                                          normalize]))
+        nc = 1
 
     elif dataset == 'fake':
         dataset = dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)
+        nc = 3
 
     else:
         raise RuntimeError("Invalid dataset name: {}".format(dataset))
 
-    return dataset
+    return dataset, nc
 
 
 def main(dataset, dataroot,
@@ -216,9 +219,13 @@ def main(dataset, dataroot,
     # seed
     check_manual_seed(seed)
 
+    # data
+    dataset, num_channels = check_dataset(dataset, dataroot)
+    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)
+
     # networks
-    netG = Generator(z_dim, g_filters).to(device)
-    netD = Discriminator(d_filters).to(device)
+    netG = Generator(z_dim, g_filters, num_channels).to(device)
+    netD = Discriminator(num_channels, d_filters).to(device)
 
     # criterion
     bce = nn.BCELoss()
@@ -227,10 +234,6 @@ def main(dataset, dataroot,
     optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
     optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
 
-    # data
-    dataset = check_dataset(dataset, dataroot)
-    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)
-
     # load pre-trained models
     if saved_G:
         netG.load_state_dict(torch.load(saved_G))
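
A minimal usage sketch of the refactored API follows (hypothetical driver code, not part of the patch; it assumes the patched examples/gan/dcgan.py is importable as dcgan, and the values z_dim=100 and nf=64 are illustrative, not taken from the example's defaults):

import torch

from dcgan import Generator, Discriminator, check_dataset  # assumes examples/gan is on PYTHONPATH

# check_dataset now returns the dataset together with its channel count,
# so single-channel MNIST no longer needs the removed to_rgb conversion.
dataset, nc = check_dataset('mnist', './data')  # nc == 1 for MNIST, 3 for the RGB datasets

netG = Generator(z_dim=100, nf=64, nc=nc)  # final deconv layer now emits nc channels
netD = Discriminator(nc=nc, nf=64)         # first conv layer now consumes nc channels

noise = torch.randn(8, 100, 1, 1)  # (batch, z_dim, 1, 1) latent input
fake = netG(noise)                 # -> (8, nc, 64, 64), i.e. (8, 1, 64, 64) for MNIST
scores = netD(fake)                # discriminator accepts the matching channel count directly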