Skip to content

Commit

Permalink
Removed img.convert('RGB') for MNIST. Generator and Discriminator use…
Browse files Browse the repository at this point in the history
… num_channels in image (pytorch#401)
  • Loading branch information
anmolsjoshi authored and vfdev-5 committed Jan 16, 2019
1 parent 5d72980 commit 6c8f7c9
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions examples/gan/dcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Generator(Net):
nf (int): Number of filters in the second-to-last deconv layer
"""

def __init__(self, z_dim, nf):
def __init__(self, z_dim, nf, nc):
super(Generator, self).__init__()

self.net = nn.Sequential(
Expand All @@ -88,7 +88,7 @@ def __init__(self, z_dim, nf):
nn.ReLU(inplace=True),

# state size. (nf) x 32 x 32
nn.ConvTranspose2d(in_channels=nf, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
nn.ConvTranspose2d(in_channels=nf, out_channels=nc, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()

# state size. (nc) x 64 x 64
Expand All @@ -107,13 +107,13 @@ class Discriminator(Net):
nf (int): Number of filters in the first conv layer.
"""

def __init__(self, nf):
def __init__(self, nc, nf):
super(Discriminator, self).__init__()

self.net = nn.Sequential(

# input is (nc) x 64 x 64
nn.Conv2d(in_channels=3, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
nn.Conv2d(in_channels=nc, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),

# state size. (nf) x 32 x 32
Expand Down Expand Up @@ -166,7 +166,6 @@ def check_dataset(dataset, dataroot):
dataset (data.Dataset): torchvision Dataset object
"""
to_rgb = transforms.Lambda(lambda img: img.convert('RGB'))
resize = transforms.Resize(64)
crop = transforms.CenterCrop(64)
to_tensor = transforms.ToTensor()
Expand All @@ -177,31 +176,35 @@ def check_dataset(dataset, dataroot):
crop,
to_tensor,
normalize]))
nc = 3

elif dataset == 'lsun':
dataset = dset.LSUN(root=dataroot, classes=['bedroom_train'], transform=transforms.Compose([resize,
crop,
to_tensor,
normalize]))
nc = 3

elif dataset == 'cifar10':
dataset = dset.CIFAR10(root=dataroot, download=True, transform=transforms.Compose([resize,
to_tensor,
normalize]))
nc = 3

elif dataset == 'mnist':
dataset = dset.MNIST(root=dataroot, download=True, transform=transforms.Compose([to_rgb,
resize,
dataset = dset.MNIST(root=dataroot, download=True, transform=transforms.Compose([resize,
to_tensor,
normalize]))
nc = 1

elif dataset == 'fake':
dataset = dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)
nc = 3

else:
raise RuntimeError("Invalid dataset name: {}".format(dataset))

return dataset
return dataset, nc


def main(dataset, dataroot,
Expand All @@ -216,9 +219,13 @@ def main(dataset, dataroot,
# seed
check_manual_seed(seed)

# data
dataset, num_channels = check_dataset(dataset, dataroot)
loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)

# networks
netG = Generator(z_dim, g_filters).to(device)
netD = Discriminator(d_filters).to(device)
netG = Generator(z_dim, g_filters, num_channels).to(device)
netD = Discriminator(num_channels, d_filters).to(device)

# criterion
bce = nn.BCELoss()
Expand All @@ -227,10 +234,6 @@ def main(dataset, dataroot,
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))

# data
dataset = check_dataset(dataset, dataroot)
loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)

# load pre-trained models
if saved_G:
netG.load_state_dict(torch.load(saved_G))
Expand Down

0 comments on commit 6c8f7c9

Please sign in to comment.