Generative Adversarial Networks (GANs) From Scratch

Generative Adversarial Networks (GANs) are incredibly powerful in the world of machine learning. They excel at creating synthetic data that’s so close to real data, it’s almost uncanny. Today, we’re diving into how these networks work and even trying our hand at building a simple GAN framework using PyTorch.

But why the term “Adversarial”?

Well, it’s because GANs work on a fascinating dynamic between two networks: the Generator and the Discriminator. The Generator takes in random noise and tries to generate an output. At first, this output is essentially random noise itself, but the magic happens as it gets refined through gradient-based learning. Think of it like sculpting: shaping that noise into something meaningful.

Now, here’s where the adversarial aspect kicks in. We need a way to guide the Generator in the right direction. Since we don’t have a specific desired output, we have to think broadly. One clever approach is to classify the output as either real or fake and then provide feedback based on that.

Discriminator

Enter the Discriminator. Its job is to differentiate between real and fake data. We throw both the Generator’s output (the fake sample) and a real sample at it. The Discriminator then does its thing, ideally giving a 1 for real input and a 0 for fake input.
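
In code, this convention maps directly onto binary cross-entropy: target 1 for real, target 0 for fake. Here is a tiny illustration (the prediction values below are made up):

import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs (probabilities of "real")
pred_real = torch.tensor([[0.9]])  # on a real sample
pred_fake = torch.tensor([[0.1]])  # on a fake sample

# Loss is low when the prediction matches the target (1 = real, 0 = fake)
loss_real = criterion(pred_real, torch.ones_like(pred_real))
loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
print(loss_real.item(), loss_fake.item())  # both ~0.105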

Before we give any feedback to the Generator, there’s another round of training involved. This part can be a bit confusing at first, but bear with me. We train the Discriminator first, teaching it to distinguish between real and fake. Once it’s got that down, it’s the Generator’s turn. Now, the Generator’s goal is to fool the Discriminator—making it classify the fake data as real.

To achieve this, we feed the Generator’s output into the Discriminator, calculate the loss, and then adjust the Generator’s weights accordingly. Crucially, we leave the Discriminator untouched since it’s already been trained.

In essence, the Generator and Discriminator work together as a sort of tag team, but with the Discriminator’s weights frozen during the Generator’s training phase.
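
To make the freezing concrete, here is a minimal, self-contained sketch of one alternating update, using toy stand-in networks (the shapes and models here are illustrative assumptions, not the ones we build later):

import torch
import torch.nn as nn
import torch.optim as optim

# Toy stand-ins: 4-D noise -> 2-D "image" (illustrative shapes only)
G = nn.Linear(4, 2)
D = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
opt_G = optim.SGD(G.parameters(), lr=0.1)  # holds ONLY G's weights
opt_D = optim.SGD(D.parameters(), lr=0.1)  # holds ONLY D's weights
criterion = nn.BCELoss()

real = torch.randn(8, 2)  # pretend real batch
z = torch.randn(8, 4)     # noise batch

# Discriminator step: detach() stops gradients from flowing back into G
fake = G(z).detach()
d_loss = criterion(D(real), torch.ones(8, 1)) + criterion(D(fake), torch.zeros(8, 1))
opt_D.zero_grad()
d_loss.backward()
opt_D.step()

# Generator step: gradients flow *through* D, but only opt_G is stepped,
# so D's weights stay frozen during this phase
g_loss = criterion(D(G(z)), torch.ones(8, 1))  # target "real" to fool D
opt_G.zero_grad()
g_loss.backward()
opt_G.step()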

Training

The original paper uses the following training algorithm:

Algorithm 1: Minibatch Stochastic Gradient Descent Training of Generative Adversarial Nets

Set the number of discriminator update steps, \( k \) (a hyperparameter, typically set to 1 in the paper’s experiments).
For each training iteration, repeat the following \( k \) times:
Sample a minibatch of \( m \) noise samples \( \{z^{(1)}, \ldots, z^{(m)}\} \) from the noise prior \( p_g(z) \).
Sample a minibatch of \( m \) examples \( \{x^{(1)}, \ldots, x^{(m)}\} \) from the data-generating distribution \( p_{\text{data}}(x) \).
Update the discriminator by ascending its stochastic gradient:
\[
\text{Gradient}_d = \frac{1}{m} \sum_{i=1}^{m} \left[ \log D\big(x^{(i)}\big) + \log\Big(1 - D\big(G(z^{(i)})\big)\Big) \right]
\]

After completing the \( k \) discriminator steps, sample another minibatch of \( m \) noise samples \( \{z^{(1)}, \ldots, z^{(m)}\} \) from the noise prior \( p_g(z) \), and update the generator by descending its stochastic gradient:

\[
\text{Gradient}_g = \frac{1}{m} \sum_{i=1}^{m} \log\Big(1 - D\big(G(z^{(i)})\big)\Big)
\]

The gradient-based updates can use any standard gradient-based learning rule; the original experiments used momentum.

The noise prior here just means the random noise we feed to the model, and the data-generating distribution simply means the real data.
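
In PyTorch terms, sampling from the noise prior is just drawing from a standard normal. A quick sketch (the shapes match the hyperparameters we use later):

import torch

batch_size, latent_dim = 64, 100
z = torch.randn(batch_size, latent_dim)  # z ~ N(0, I), one noise vector per row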

Implementation

We will implement a basic model that works on images, since image generation is the most popular use of GANs. We will work with the anime face images found here.

Basic Imports
# imports
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from skimage import io
Generator and Discriminator Classes

The Generator object takes the image shape as input, which consists of batch_size, channels, height, and width. Along with that, it takes the latent dimension, which is the dimension of the noise vector fed to the model. Why is it called the latent dimension? The term comes from latent space, a compressed representation of an image that captures its essential details. Since this lower-dimensional noise input will be upscaled into a new image, we call its size the latent dimension.

The Discriminator object takes the image shape as input to build its model, and outputs a binary classification: whether the image is real or fake.

class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, img_shape[1]*img_shape[2]*img_shape[3]),  # Adjusted output size to match img_shape
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        # Reshape using the actual batch size so an incomplete final batch still works
        img = img.view(img.size(0), *self.img_shape[1:])
        return img


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.img_shape = img_shape
        self.flatten_dim = img_shape[1] * img_shape[2] * img_shape[3]  # Compute the flatten dimension
        self.model = nn.Sequential(
            nn.Linear(self.flatten_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # Flatten the input to (batch_size, flatten_dim); works whether the image
        # arrives as (N, C, H, W) or already flattened
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
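
Before moving on, it helps to sanity-check the shapes. A quick sketch, assuming the img_shape convention (batch_size, channels, height, width) that the training code below uses:

latent_dim = 100
img_shape = (64, 3, 100, 100)  # (batch_size, channels, height, width)

generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)

z = torch.randn(64, latent_dim)
fake = generator(z)
print(fake.shape)                 # torch.Size([64, 3, 100, 100])
print(discriminator(fake).shape)  # torch.Size([64, 1])
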
Defining Dataset Loaders

A basic custom dataset class that we will use later during training.

class AnimeDataset(Dataset):
    """Anime faces dataset."""

    def __init__(self, root_dir, transform=None):
        """
        Arguments:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        # Keep only the face images (their filenames contain 'seed')
        self.images = [f for f in os.listdir(root_dir) if 'seed' in f]
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.images[idx])
        image = io.imread(img_name)
        if self.transform:
            image = self.transform(image)
        return image
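
As a quick sanity check, you can instantiate the dataset and inspect one sample. The root_dir below is the Kaggle path used in the training code; substitute wherever your copy of the dataset lives:

check_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((100, 100)),
    transforms.ToTensor(),
])

dataset = AnimeDataset(root_dir='/kaggle/input/gananime-lite/out2/', transform=check_transform)
print(len(dataset))      # number of images found
print(dataset[0].shape)  # torch.Size([3, 100, 100])
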
Training

The training process is simple. For every training iteration, we first train the Discriminator: we feed it a fake image and a real image, combine the two losses, compute the gradients, and backpropagate to update the Discriminator’s weights.

Once that’s done, we repeat the process with a newly generated fake image: it is passed through the Discriminator and its loss is calculated and backpropagated, but this time we only update the weights of the Generator (recall that we don’t touch the Discriminator here). As in the paper’s algorithm, the Generator is trained to fool the Discriminator.

# training 

# Set device (falls back to CPU if no GPU is available)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters
lr = 0.0002
batch_size = 64
latent_dim = 100
img_shape = (batch_size, 3, 100, 100)  # (batch_size, channels, height, width)
epochs = 20

# Initialize networks
generator = Generator(latent_dim, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)

# Initialize optimizers and loss function
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
criterion = nn.BCELoss()

# Prepare data
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((100, 100)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # scale pixels to [-1, 1] to match Tanh output
])
dataset = AnimeDataset(root_dir='/kaggle/input/gananime-lite/out2/', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(epochs):
    for i, real_imgs in enumerate(dataloader):
        try:
            # Train Discriminator
            optimizer_D.zero_grad()

            # Reshape real images to match discriminator input shape
            real_imgs = real_imgs.to(device)
            real_imgs = real_imgs.view(real_imgs.size(0), -1)
            batch_size = real_imgs.size(0)
            real_labels = torch.ones(batch_size, 1).to(device)

            # Generate fake images
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = generator(z)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # Discriminator loss for real images
            real_pred = discriminator(real_imgs)
            d_loss_real = criterion(real_pred, real_labels)

            # Discriminator loss for fake images
            fake_pred = discriminator(fake_imgs.detach())
            d_loss_fake = criterion(fake_pred, fake_labels)

            # Total discriminator loss
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()

            # Generate fake images
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = generator(z)

            # Discriminator loss for fake images
            fake_pred = discriminator(fake_imgs)
            g_loss = criterion(fake_pred, real_labels)

            # Update Generator
            g_loss.backward()
            optimizer_G.step()

            if i % 100 == 0:
                print(
                    f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]"
                )
        except Exception as e:
            # Surface the actual error instead of failing silently
            print(f"error occurred: {e}")
            continue

After 20 epochs of training, here’s what you get:

Not bad, but it could be done better with a different configuration and more training.
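
If you want to inspect the output yourself, here is one possible way to sample from the trained Generator and tile the results into an image grid (the filename is arbitrary):

from torchvision.utils import save_image

generator.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim).to(device)
    samples = generator(z)
# Undo the Normalize((0.5,), (0.5,)): map [-1, 1] back to [0, 1] for saving
samples = samples * 0.5 + 0.5
save_image(samples, "samples.png", nrow=8)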

So that sums up our article on Generative Adversarial Networks and implementing them from scratch. Stay tuned for more informative articles.
