How To Train A GAN On 128 GPUs Using PyTorch

William Falcon
Towards Data Science
4 min read · Aug 14, 2019


If you’re into GANs, you know it can take a reaaaaaally long time to generate nice-looking outputs. With distributed training we can cut down that time dramatically.

In a different tutorial, I cover 9 things you can do to speed up your PyTorch models. In this tutorial we’ll implement a GAN and train it on 32 machines (each with 4 GPUs) using DistributedDataParallel.

Generator

First, we need to define a generator. This network takes random noise as input and generates an image from the point in latent space indexed by that noise.

The generator also gets its own optimizer:

opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
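For a concrete picture, here is a minimal sketch of such a generator as a small MLP. The image size (1×28×28, MNIST-like), the 100-dimensional latent vector, the layer widths, and the Adam values lr=0.0002, b1=0.5, b2=0.999 are all assumptions for illustration, not values taken from the article:

```python
import math

import torch
import torch.nn as nn

# Assumed sizes: MNIST-like 1x28x28 images and a 100-dim latent vector.
LATENT_DIM = 100
IMG_SHAPE = (1, 28, 28)


class Generator(nn.Module):
    """Maps a batch of noise vectors z to a batch of images in [-1, 1]."""

    def __init__(self, latent_dim: int = LATENT_DIM, img_shape=IMG_SHAPE):
        super().__init__()
        self.img_shape = img_shape
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, math.prod(img_shape)),
            nn.Tanh(),  # squash pixel values into [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Reshape the flat output back into (batch, channels, height, width).
        return self.model(z).view(z.size(0), *self.img_shape)


# Hypothetical Adam hyperparameters (common choices for GANs, not from the article).
lr, b1, b2 = 0.0002, 0.5, 0.999
generator = Generator()
opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))

z = torch.randn(16, LATENT_DIM)  # 16 points in latent space
fake_imgs = generator(z)         # shape: (16, 1, 28, 28)
```

Inside a LightningModule the optimizer line reads `self.generator.parameters()`, exactly as in the snippet above; here the generator is a plain module-level variable so the sketch runs on its own.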

Discriminator

Real or fake pirate?

The discriminator is just a classifier. It takes an image as input and decides whether it is real or fake.

The discriminator also gets its own optimizer.

opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
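A matching discriminator can be sketched as a binary classifier that flattens each image and outputs the probability that it is real. As with the generator sketch, the image size, layer widths, and Adam settings are assumptions:

```python
import math

import torch
import torch.nn as nn

IMG_SHAPE = (1, 28, 28)  # assumed image size, matching a MNIST-like dataset


class Discriminator(nn.Module):
    """Binary classifier: outputs the probability that each image is real."""

    def __init__(self, img_shape=IMG_SHAPE):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(math.prod(img_shape), 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # probability in (0, 1)
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        flat = img.view(img.size(0), -1)  # flatten each image to a vector
        return self.model(flat)


lr, b1, b2 = 0.0002, 0.5, 0.999  # hypothetical Adam settings
discriminator = Discriminator()
opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

imgs = torch.rand(8, *IMG_SHAPE) * 2 - 1  # stand-in batch scaled to [-1, 1]
p_real = discriminator(imgs)              # shape: (8, 1)
```

The sigmoid output pairs with a binary cross-entropy loss (`nn.BCELoss`) in the training steps; `nn.BCEWithLogitsLoss` on raw logits is a numerically safer alternative if you drop the final sigmoid.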

GAN training

There are many variations of GAN training, but this is the simplest one, chosen just to illustrate the major ideas.

You can think about it as basically alternating training between the generator and discriminator. So, batch 0 trains the generator, batch 1 trains the discriminator, etc…
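The alternating schedule just described can be sketched as a plain loop. `generator_step` and `discriminator_step` are hypothetical callables standing in for the two updates; many GAN implementations instead run both updates on every batch, so treat this as an illustration of the idea rather than the only valid schedule:

```python
from typing import Callable, Iterable, List


def alternate_training(
    batches: Iterable,
    generator_step: Callable[[object], None],
    discriminator_step: Callable[[object], None],
) -> List[str]:
    """Even-numbered batches train the generator, odd ones the discriminator.

    Returns the name of the network each batch trained, for inspection.
    """
    schedule = []
    for batch_idx, batch in enumerate(batches):
        if batch_idx % 2 == 0:
            generator_step(batch)
            schedule.append("generator")
        else:
            discriminator_step(batch)
            schedule.append("discriminator")
    return schedule


# Usage with placeholder steps that only record which batch they saw.
seen = []
order = alternate_training(
    range(4),
    generator_step=lambda b: seen.append(("G", b)),
    discriminator_step=lambda b: seen.append(("D", b)),
)
# order == ["generator", "discriminator", "generator", "discriminator"]
```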

To train the generator we do the following:

Generator Training

Here we sample some normal noise and give it to the generator, then pass the generated images to the discriminator. Although these images are fake, we label them as real: the generator’s loss is low only when it fools the discriminator into classifying its output as real.

We backpropagate the above and move on to the next batch which trains the discriminator.
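The generator update can be sketched in plain PyTorch as follows. The tiny networks, sizes, and learning rate are assumptions chosen so the example runs instantly on a CPU; the key details are that the fakes are scored against a target of ones ("real") and that only `opt_g` takes a step:

```python
import torch
import torch.nn as nn

LATENT_DIM = 8  # small assumed sizes so the sketch runs instantly
IMG_DIM = 16

# Tiny stand-in networks; any generator/discriminator pair works here.
generator = nn.Sequential(nn.Linear(LATENT_DIM, IMG_DIM), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(IMG_DIM, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
adversarial_loss = nn.BCELoss()


def generator_step(batch_size: int) -> torch.Tensor:
    """One generator update: make fakes and push D to call them real."""
    z = torch.randn(batch_size, LATENT_DIM)  # sample normal noise
    generated_imgs = generator(z)            # noise -> fake images
    valid = torch.ones(batch_size, 1)        # target label "real" (1)
    # The loss is low only when D believes the fakes are real.
    g_loss = adversarial_loss(discriminator(generated_imgs), valid)
    opt_g.zero_grad()
    # Gradients flow back through D into G; D's grads are also filled in,
    # so the discriminator step must zero them before its own backward.
    g_loss.backward()
    opt_g.step()  # only the generator's weights are updated
    return g_loss.detach()


loss = generator_step(batch_size=32)
```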

To train the discriminator we do the following:

Discriminator Training

The discriminator calculates 2 losses. First, how well can it detect real examples? Second, how well can it detect fake examples? The full loss is just the average of these two.
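A minimal sketch of that discriminator update, again with assumed tiny networks and random stand-in "real" data. The fake images are `detach()`ed so this step cannot change the generator:

```python
import torch
import torch.nn as nn

LATENT_DIM = 8  # small assumed sizes so the sketch runs instantly
IMG_DIM = 16

generator = nn.Sequential(nn.Linear(LATENT_DIM, IMG_DIM), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(IMG_DIM, 1), nn.Sigmoid())
opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
adversarial_loss = nn.BCELoss()


def discriminator_step(real_imgs: torch.Tensor) -> torch.Tensor:
    """One discriminator update: average of the real-image and fake-image losses."""
    batch_size = real_imgs.size(0)
    valid = torch.ones(batch_size, 1)  # label 1 = real
    fake = torch.zeros(batch_size, 1)  # label 0 = fake

    # How well does D detect real examples?
    real_loss = adversarial_loss(discriminator(real_imgs), valid)

    # How well does D detect fake examples? detach() keeps G out of this update.
    z = torch.randn(batch_size, LATENT_DIM)
    fake_imgs = generator(z).detach()
    fake_loss = adversarial_loss(discriminator(fake_imgs), fake)

    d_loss = (real_loss + fake_loss) / 2  # the full loss is the average
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()
    return d_loss.detach()


real_batch = torch.rand(32, IMG_DIM) * 2 - 1  # stand-in "real" data in [-1, 1]
d_loss = discriminator_step(real_batch)
```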

LightningModule

PyTorch Lightning

To train this system on 128 GPUs we’re going to use a lightweight wrapper on top of PyTorch called PyTorch Lightning, which automates everything else we haven’t discussed here (training loop, validation, etc…).

The beauty of this library is that the only thing you need to define is your system, written as a LightningModule, and you get GPU and cluster support for free.

Here’s the full system as a LightningModule

This abstraction is very simple. The meat of the training always happens in the training_step, so you can look at any repository in the world and know what happens where!

Let’s walk through it one piece at a time.

First, we define the models this system will use in the __init__ method.

Next, we define what we want the output of the system to be in the .forward method. This means that if I run this model from a server or API, I always give it noise and get back a sampled image.

Now comes the meat of the system. We define the complex logic of any system in the training_step method, in this case the GAN training.

We also cache the images we generate for the discriminator and log examples with every batch.

Finally, we configure the optimizers and data we want.

That’s it! In summary, we have to specify the things we care about:

  1. Data
  2. Models involved (init)
  3. Optimizers involved
  4. The core training logic for the FULL system (training_step)
  5. There’s an optional validation_step for other systems that might need to calculate accuracy or use different datasets.

Training on 128 GPUs

This part is actually trivial now. With the GAN system defined, we can simply pass it into a Trainer object and tell it to train on 32 nodes with 4 GPUs each.
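A minimal sketch of that Trainer call, assuming a recent PyTorch Lightning release. The 2019 API used different argument names (this article’s `nb_gpu_nodes` corresponds to today’s `num_nodes`), so check the docs for your installed version. `CLUSTER_TRAINER_KWARGS` and `train_on_cluster` are hypothetical names; `devices` and `num_nodes` need to match the `--ntasks-per-node` and `--nodes` flags of the SLURM job:

```python
# Assumed cluster layout: 32 nodes x 4 GPUs = 128 DDP processes.
CLUSTER_TRAINER_KWARGS = dict(
    accelerator="gpu",
    devices=4,       # GPUs per node (= --ntasks-per-node)
    num_nodes=32,    # machines in the job (= --nodes)
    strategy="ddp",  # DistributedDataParallel
)
WORLD_SIZE = CLUSTER_TRAINER_KWARGS["devices"] * CLUSTER_TRAINER_KWARGS["num_nodes"]


def train_on_cluster(model):
    """Fit `model` (e.g. the GAN LightningModule) on every GPU in the job."""
    # Imported here so the settings above can be inspected without Lightning installed.
    import pytorch_lightning as pl

    trainer = pl.Trainer(**CLUSTER_TRAINER_KWARGS)
    trainer.fit(model)
    return trainer
```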

Now we submit a job to SLURM that has these flags:

#!/bin/bash
# SLURM SUBMIT SCRIPT
#SBATCH --gres=gpu:4
#SBATCH --nodes=32
#SBATCH --ntasks-per-node=4
#SBATCH --mem=0
#SBATCH --time=02:00:00
# activate conda env
conda activate my_env
# run script from above (srun starts one process per task: 32 x 4 = 128)
srun python gan.py

and our model will train using all 128 GPUs!

In the background, Lightning will use DistributedDataParallel and configure everything to work correctly for you. DistributedDataParallel is explained in depth in this tutorial.

At a high level, DistributedDataParallel gives each GPU a portion of the dataset, initializes a copy of the model on that GPU, and only synchronizes (averages) gradients between the copies during training. So, it’s more like “distributed dataset + gradient syncing”.
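To make “distributed dataset + gradient syncing” concrete, here is a single-process simulation, not real multi-GPU code. PyTorch’s `DistributedSampler` splits the dataset indices across four pretend ranks, and averaging the per-rank gradients stands in for the all-reduce DDP performs during `backward()`. The world size, data, and linear model are assumptions:

```python
import copy

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

torch.manual_seed(0)
WORLD_SIZE = 4  # pretend GPUs; a 128-GPU job works the same way
X = torch.randn(32, 3)
y = torch.randn(32, 1)
dataset = TensorDataset(X, y)

# 1) "Distributed dataset": each rank iterates over its own slice of indices.
shards = [
    list(DistributedSampler(dataset, num_replicas=WORLD_SIZE, rank=r, shuffle=False))
    for r in range(WORLD_SIZE)
]

# 2) "Gradient syncing": every rank starts from an identical copy of the model...
model = nn.Linear(3, 1)
replicas = [copy.deepcopy(model) for _ in range(WORLD_SIZE)]
loss_fn = nn.MSELoss()

local_grads = []
for replica, idx in zip(replicas, shards):
    loss = loss_fn(replica(X[idx]), y[idx])  # ...computes a loss on its shard only
    loss.backward()
    local_grads.append(replica.weight.grad.clone())

# ...and the gradients are averaged so every replica takes the same step.
synced_grad = torch.stack(local_grads).mean(dim=0)

# With equal-size shards, this matches the gradient of the full-batch loss.
full_loss = loss_fn(model(X), y)
full_loss.backward()
```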

But I Don’t Have 128 GPUs

No worries! No changes are needed to your model. Just remove the nb_gpu_nodes parameter from the trainer to use all 4 GPUs on your machine.

Then run the script on the machine with the 4 GPUs.
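In recent PyTorch Lightning releases the equivalent change is dropping `num_nodes` (the current name for `nb_gpu_nodes`), which defaults to 1. A minimal sketch, assuming the cluster job was configured with the hypothetical Trainer arguments below:

```python
# Assumed settings for the 128-GPU cluster run (recent PyTorch Lightning names).
CLUSTER_TRAINER_KWARGS = dict(accelerator="gpu", devices=4, num_nodes=32, strategy="ddp")

# Single machine: drop the node count. num_nodes defaults to 1,
# so the Trainer uses only the 4 local GPUs.
SINGLE_NODE_TRAINER_KWARGS = {
    k: v for k, v in CLUSTER_TRAINER_KWARGS.items() if k != "num_nodes"
}
# Then construct the Trainer the same way, e.g. pl.Trainer(**SINGLE_NODE_TRAINER_KWARGS).
```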

The full code is available here.

This GAN code was adapted from this awesome GAN repository and refactored to use PyTorch Lightning.


⚡️PyTorch Lightning Creator • PhD Student, AI (NYU, Facebook AI research).