Tharindu Hasthika
Tharindu Hasthika's Blog


PyTorch Lightning Basics



PyTorch Lightning is a framework built on top of the PyTorch deep learning framework to make it easier to use; think of it as a Keras-like API for PyTorch. I plan to write this series of articles from my own experience of using it in my research. These articles assume that you have a good grasp of deep learning and PyTorch.


To install PyTorch Lightning, run pip install pytorch-lightning.


First, we will go over the main concepts in PyTorch Lightning so that they are easier to work with later. The framework covers most of what a typical deep learning project needs: defining the model, loading the data, and running training, validation and testing. At the end of this article we will walk through a mock dataset to show the full framework in action.


A model is the neural network that we want to learn a particular task. For this, PyTorch Lightning provides pytorch_lightning.LightningModule, which extends PyTorch's own nn.Module, so everything you know about nn.Module still applies.

The scaffold for a basic model is as follows.

import torch
import pytorch_lightning as pl

class MyModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        ## define the layers of the network here

    def forward(self, x):
        ## the forward pass
        ...

    def configure_optimizers(self):
        ## configure the optimizer that is used by the model
        optimizer = torch.optim.Adam(self.parameters())
        return optimizer

    def training_step(self, batch, batch_idx):
        ## the training step: compute and return the loss for this batch
        ...

    def validation_step(self, batch, batch_idx):
        ## the validation step
        ...

    def test_step(self, batch, batch_idx):
        ## the test step
        ...

The forward method works the same way as in plain PyTorch: it defines the forward pass and is called whenever input is fed through the network, for example with self(x).

Likewise, training_step, validation_step and test_step are called by the Trainer once per batch during the training, validation and test phases respectively.


To load data into the model, we create a class that extends the PyTorch Dataset class. PyTorch Lightning does have its own LightningDataModule class, but it only organises the data-loading code; the data itself still comes from PyTorch Dataset (and DataLoader) objects.

Below is one way to structure a dataset for use with PyTorch Lightning.

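from torch.utils.data import Dataset
import pandas as pd

class MyDataset(Dataset):

    def __init__(self, dataset_type="train"):

        if dataset_type == "train":
            ## load the train dataset (e.g. with pd.read_csv)
            self.df = ...
        elif dataset_type == "validation":
            ## load the validation dataset
            self.df = ...
        elif dataset_type == "test":
            ## load the test dataset
            self.df = ...
        else:
            raise ValueError(f"unknown dataset_type: {dataset_type}")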

    def get_features(self, index):
        ## extract the needed features
        X = ...
        return X

    def get_label(self, index):
        ## extract the needed label data
        y = ...
        return y

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        X = self.get_features(index)
        y = self.get_label(index)
        return (X, y)

The __getitem__ method is the key part here: it returns a single sample, so this is where you shape the data into the form the model expects. __len__ tells the DataLoader how many samples there are.

You can load the data any way you like, but the flow will be similar. Here we use dataset_type to choose between the train, validation and test data, but you can use whatever approach best suits your needs.
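As one possible filled-in version of this pattern, the sketch below keeps each split as a small in-memory pandas DataFrame. The FRAMES dictionary, the column names x1, x2 and label, and the class name FrameDataset are hypothetical stand-ins for however you actually load your data (for example from CSV files).

```python
import pandas as pd
import torch
from torch.utils.data import Dataset

# Hypothetical in-memory splits standing in for files you would normally load.
FRAMES = {
    "train": pd.DataFrame({"x1": [0.0, 1.0, 2.0], "x2": [1.0, 0.0, 1.0], "label": [0, 1, 0]}),
    "validation": pd.DataFrame({"x1": [3.0], "x2": [0.5], "label": [1]}),
}


class FrameDataset(Dataset):
    def __init__(self, dataset_type="train", feature_cols=("x1", "x2"), label_col="label"):
        if dataset_type not in FRAMES:
            raise ValueError(f"unknown dataset_type: {dataset_type!r}")
        self.df = FRAMES[dataset_type].reset_index(drop=True)
        self.feature_cols = list(feature_cols)
        self.label_col = label_col

    def get_features(self, index):
        # one row of the feature columns as a float32 vector
        row = self.df.loc[index, self.feature_cols]
        return torch.tensor(row.to_numpy(dtype="float32"))

    def get_label(self, index):
        # class index as int64, the dtype F.cross_entropy expects for targets
        return torch.tensor(int(self.df.loc[index, self.label_col]), dtype=torch.long)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        return self.get_features(index), self.get_label(index)


x, y = FrameDataset("train")[1]
print(x, y)  # tensor([1., 0.]) tensor(1)
```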


To load the data efficiently, PyTorch provides the DataLoader class, which wraps a Dataset, groups its samples into batches, and can use several worker processes to load them in parallel.

from torch.utils.data import DataLoader

dataset = MyDataset()
dataloader = DataLoader(
    dataset,
    batch_size=32, # number of samples to load at a time
    num_workers=4 # number of worker subprocesses that load batches in parallel
)
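The following sketch shows how a DataLoader groups individual samples into batches. It uses torch.utils.data.TensorDataset on a hypothetical toy tensor of 10 samples instead of MyDataset, and num_workers=0 so that loading happens in the main process. With batch_size=4, the 10 samples come out as batches of 4, 4 and 2.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 toy samples with 2 features each, plus an integer label per sample
features = torch.arange(20, dtype=torch.float32).reshape(10, 2)
labels = torch.arange(10) % 4
dataset = TensorDataset(features, labels)

# num_workers=0 loads in the main process; raise it to use worker subprocesses
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

batch_shapes = [tuple(x.shape) for x, y in loader]
print(batch_shapes)  # [(4, 2), (4, 2), (2, 2)]
```

Passing drop_last=True to the DataLoader would discard the final, incomplete batch of 2.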


The Trainer, as the name implies, is the class responsible for training and evaluating the models that you create. It has a large number of options, which are described in the official documentation. In this article we will go through the subset of options that is essential for working with it.

import pytorch_lightning as pl
from torch.utils.data import DataLoader

# create the model
model = MyModel()

# create the datasets
train_dataset = MyDataset(dataset_type="train")
validation_dataset = MyDataset(dataset_type="validation")

# create the dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=32, num_workers=4)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, num_workers=4)

# create the trainer
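trainer = pl.Trainer(
    gpus=1, # number of gpus to use, -1 to use all
    max_epochs=10 # maximum number of epochs the trainer will execute
)

# train the model, evaluating it on the validation dataloader
trainer.fit(model, train_dataloader, validation_dataloader)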

Trainer Constructor Arguments

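  • gpus - Specifies how many GPUs to use for training; by default none are used and training runs on the CPU. This argument was removed in Lightning 2.0, where the equivalent is accelerator="gpu", devices=1.
  • max_epochs - Specifies the maximum number of epochs (full passes over the training dataset).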

fit Method Arguments

  • train_dataloader - The dataloader that supplies the training batches (the argument is named train_dataloaders in newer releases).
  • val_dataloaders - Either a single dataloader or a list of dataloaders, which the trainer uses to evaluate the model during training.

A Simple Example

To put these pieces together, we will build a simple model that classifies 2D data points into 4 classes. The data points are generated with scikit-learn's make_blobs.

1. Creating the Dataset

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

classes = 4
n_samples = 1000

(X, y) = make_blobs(n_samples=n_samples, n_features=2, centers=classes, cluster_std=2.5, center_box=(-10, 10), random_state=42)

## Splitting the datasets: 20% for testing, then 20% of the remainder for validation
(X_train, X_test, y_train, y_test) = train_test_split(X, y, test_size=0.2, random_state=42)
(X_train, X_validation, y_train, y_validation) = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

## plot each class in its own colour
colors = ['red', 'green', 'blue', 'black', 'purple']
cdict = {g: colors[g] for g in range(classes)}
fig, ax = plt.subplots()
for g in np.unique(y):
    ix = np.where(y == g)
    ax.scatter(X[ix, 0], X[ix, 1], c=cdict[g], label=g)
ax.legend()
plt.show()
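The two successive 20% splits leave 64% of the samples for training, 16% for validation and 20% for testing. The sketch below repeats the generation and splitting code (without the plot) and checks those sizes; make_blobs also divides the 1000 samples evenly between the 4 centres.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

X, y = make_blobs(n_samples=1000, n_features=2, centers=4, cluster_std=2.5,
                  center_box=(-10, 10), random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_validation, y_train, y_validation = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42)

print(np.bincount(y))                                # [250 250 250 250]
print(len(X_train), len(X_validation), len(X_test))  # 640 160 200
```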

Below is the clustering of the points that we are trying to fit a model to.

Random Data Points

2. Create a Custom Dataset

import torch
from torch.utils.data import Dataset, DataLoader

class MyCustomDataset(Dataset):

    def __init__(self, X, y):
        self.X = X
        self.y = y

        self.count = X.shape[0]

    def __len__(self):
        return self.count

    def __getitem__(self, index):
        X = self.X[index]
        y = self.y[index]
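        ## float32 features and int64 (long) labels, the dtypes F.cross_entropy expects
        return (torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long))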

ds_train = MyCustomDataset(X_train, y_train)
ds_validation = MyCustomDataset(X_validation, y_validation)
ds_test = MyCustomDataset(X_test, y_test)

dl_train = DataLoader(ds_train, batch_size=16, num_workers=2)
dl_validation = DataLoader(ds_validation, batch_size=16, num_workers=2)
dl_test = DataLoader(ds_test, batch_size=16, num_workers=2)
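To check what the model will actually receive, one can pull a single batch from a DataLoader built on MyCustomDataset. The sketch below redefines the dataset class (returning long labels) so that it runs on its own, and uses hypothetical random stand-in data (X_demo, y_demo) instead of the blobs.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class MyCustomDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.count = X.shape[0]

    def __len__(self):
        return self.count

    def __getitem__(self, index):
        # float32 features, int64 class index
        return (torch.tensor(self.X[index], dtype=torch.float32),
                torch.tensor(self.y[index], dtype=torch.long))


# hypothetical stand-in data: 40 points with 2 features and labels in {0, 1, 2, 3}
X_demo = np.random.default_rng(0).normal(size=(40, 2))
y_demo = np.arange(40) % 4

dl_demo = DataLoader(MyCustomDataset(X_demo, y_demo), batch_size=16)
x_batch, y_batch = next(iter(dl_demo))
print(x_batch.shape, x_batch.dtype, y_batch.shape, y_batch.dtype)
# torch.Size([16, 2]) torch.float32 torch.Size([16]) torch.int64
```

The default collate function stacks the per-sample tensors, so a batch of features has shape (batch_size, 2) and a batch of labels has shape (batch_size,).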

3. Create the Custom Model

## create the model class

import pytorch_lightning as pl

import torch
from torch import nn
from torch.nn import functional as F

class MyModel(pl.LightningModule):

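    def __init__(self):
        super().__init__()  # required before any layer is assigned

        ## make the model
        ## note: with no activation between the two Linear layers,
        ## this stack is equivalent to a single linear layer
        self.classifier = nn.Sequential(
            nn.Linear(in_features=2, out_features=4),
            nn.Linear(in_features=4, out_features=4)
        )

        ## use cross entropy loss for categorical problems (it expects raw logits)
        self.loss = F.cross_entropy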

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        return optimizer

    def forward(self, x):
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch

        y_logit = self(x)
        loss = self.loss(y_logit, y)
        pred = F.softmax(y_logit, dim=1)

        self.log('train/loss', loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        y_logit = self(x)
        loss = self.loss(y_logit, y)
        pred = F.softmax(y_logit, dim=1)

        self.log("val/loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch

        y_logit = self(x)
        loss = self.loss(y_logit, y)
        pred = F.softmax(y_logit, dim=1)

        self.log("test/loss", loss)
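The model returns raw logits, and F.cross_entropy applies log-softmax internally, which is why the softmax output pred is not needed for the loss. The sketch below, on two hypothetical hand-written rows of logits, checks that cross-entropy equals the mean negative log-softmax of the true classes and shows one way to turn predictions into an accuracy value (for example, to log it alongside the loss).

```python
import torch
from torch.nn import functional as F

# two hypothetical samples, four classes each
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0],
                       [0.1, 0.2, 3.0, -0.5]])
targets = torch.tensor([0, 2])

loss = F.cross_entropy(logits, targets)
# the same value by hand: mean negative log-softmax of each sample's true class
manual = -F.log_softmax(logits, dim=1)[torch.arange(2), targets].mean()

probs = F.softmax(logits, dim=1)
# softmax is monotonic, so argmax over probabilities equals argmax over logits
preds = probs.argmax(dim=1)
accuracy = (preds == targets).float().mean()
print(preds.tolist(), accuracy.item())  # [0, 2] 1.0
```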

4. Train the Model

trainer = pl.Trainer(max_epochs=10)  # max_epochs=10 is an illustrative value

model = MyModel()
trainer.fit(model, dl_train, dl_validation)

The code above runs the model's training_step on every training batch and its validation_step on the validation batches, so the model is both trained and validated.

5. Test the Model

After training the model, we can evaluate it on the test set to check its performance on unseen data.

trainer.test(model, dl_test)


The output of training and testing the model is shown below.

Training Output

6. Save the Checkpoint

We will save the model manually for now, but the Trainer also has more advanced options, such as the ModelCheckpoint callback, that automate the saving of models. See the official documentation for more details.



In this article we have gone through each of the main steps needed to use the PyTorch Lightning framework. The framework is a really wonderful addition on top of the PyTorch framework. I will be posting more on PyTorch Lightning and on deep learning in general.
