PyTorch Lightning Basics
PyTorch Lightning is a framework built on top of PyTorch to make it easier to use; think of it as a Keras-like API for PyTorch. I plan to write this series of articles based on my own experience using it for research. These articles assume that you have a good grasp of deep learning and PyTorch.
To install PyTorch Lightning, run
pip install pytorch-lightning
First, we will go over some of the important concepts in PyTorch Lightning so that they are easier to work with later. The framework covers most of the common requirements of people who build deep learning models. At the end of this article, we will work through a mock dataset to show the full framework in action.
A model is the neural network that learns a particular task. For that we have pytorch_lightning.LightningModule, which is similar to a regular PyTorch module (it is a subclass of torch.nn.Module).
The scaffold for a basic model is as follows.
import torch
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        ## the forward pass
        ...

    def configure_optimizers(self):
        ## configure the optimizer that is used by the model
        # optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        # return optimizer
        ...

    def training_step(self, batch, batch_idx):
        ## the training step
        ...

    def validation_step(self, batch, batch_idx):
        ## the validation step
        ...

    def test_step(self, batch, batch_idx):
        ## the test step
        ...
The forward method is similar to the one in PyTorch: it is called whenever an input needs to be fed through the network for a forward pass.
Likewise, the training_step, validation_step, and test_step methods are called when the model is in the training, validation, and test phases, respectively.
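To make the calling order concrete, here is a deliberately simplified, pure-Python sketch of how a trainer might drive these step methods. It is not Lightning's actual implementation: ToyModule and toy_fit are hypothetical names made up for this illustration, and the real Trainer does much more (optimizer updates, device placement, logging).

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule that only records which step ran."""

    def __init__(self):
        self.calls = []

    def training_step(self, batch, batch_idx):
        self.calls.append(("train", batch_idx))

    def validation_step(self, batch, batch_idx):
        self.calls.append(("val", batch_idx))


def toy_fit(module, train_batches, val_batches, max_epochs):
    """Per epoch: run training_step on every training batch, then validation_step on every validation batch."""
    for _ in range(max_epochs):
        for batch_idx, batch in enumerate(train_batches):
            module.training_step(batch, batch_idx)
        for batch_idx, batch in enumerate(val_batches):
            module.validation_step(batch, batch_idx)


module = ToyModule()
toy_fit(module, train_batches=["a", "b"], val_batches=["c"], max_epochs=2)
print(module.calls)
```

The recorded calls show two training steps followed by one validation step in each of the two epochs, which is the rhythm to keep in mind when filling in the real methods.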
In order to load data into the model, we have to create a class that extends the PyTorch Dataset class. Even though the PyTorch Lightning framework has its own LightningDataModule class, it in turn depends on the PyTorch Dataset class.
Below is a way to handle data pipelining in PyTorch Lightning.
from torch.utils.data import Dataset
import pandas as pd


class MyDataset(Dataset):
    def __init__(self, dataset_type="train"):
        if dataset_type == "train":
            ## load the train dataset
            self.df = ...
        elif dataset_type == "validation":
            self.df = ...
        elif dataset_type == "test":
            self.df = ...

    def get_features(self, index):
        ## extract the needed features
        X = ...
        return X

    def get_label(self, index):
        ## extract the needed label data
        y = ...
        return y

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        X = self.get_features(index)
        y = self.get_label(index)
        return (X, y)
The __getitem__ method is important here because it lets us mould the data into whatever form we want to present to the model.
Here, you can load the data any way you like, but the flow would be similar. In this case we have used dataset_type to differentiate between the types of data that we need, but you can use whichever method best fits your particular need.
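As a small concrete instance of this pattern, the sketch below backs a dataset with an in-memory list instead of a file. ToyDataset and its rows are hypothetical and exist only to show what __len__ and __getitem__ return; a real dataset would load its own data source.

```python
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: each row holds two features and an integer label."""

    def __init__(self, rows):
        self.rows = rows  # list of (feature_1, feature_2, label) tuples

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        f1, f2, label = self.rows[index]
        X = torch.tensor([f1, f2], dtype=torch.float32)  # features as a float tensor
        y = torch.tensor(label, dtype=torch.long)  # label as an integer tensor
        return (X, y)


toy = ToyDataset([(0.5, 1.5, 0), (2.0, -1.0, 1), (3.0, 0.0, 1)])
X0, y0 = toy[1]  # indexing calls __getitem__
print(len(toy), X0, y0)
```

Converting to tensors inside __getitem__ keeps the conversion logic in one place, so the model always receives data in the same shape and dtype.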
To load the data efficiently, PyTorch provides the DataLoader class, which loads the data in batches and can use multiple worker processes to speed up loading.
from torch.utils.data import DataLoader

dataset = MyDataset()
dataloader = DataLoader(
    dataset,
    batch_size=32,  # number of samples to load at a time
    num_workers=4,  # number of worker subprocesses used to load the data
)
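To see what batching looks like in practice, here is a minimal sketch that iterates over a DataLoader built on made-up random data. TensorDataset is used as a convenient stand-in for a custom dataset, and num_workers=0 (loading in the main process) is chosen only to keep the example portable.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 made-up samples with 2 features each, plus integer labels
features = torch.randn(100, 2)
labels = torch.randint(0, 4, (100,))
toy_dataset = TensorDataset(features, labels)

# num_workers=0 loads the data in the main process
toy_loader = DataLoader(toy_dataset, batch_size=32, num_workers=0)

batch_sizes = [x.shape[0] for x, y in toy_loader]
print(len(toy_loader), batch_sizes)  # 4 batches: 32, 32, 32 and a final partial batch of 4
```

Because 100 is not a multiple of 32, the last batch is smaller than the rest; any code in the step methods should therefore not assume a fixed batch size.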
Trainer, as the name implies, is the class responsible for training and evaluating the models that you create. It has a myriad of options that you can explore in the official documentation. For this article we will go through a subset of these options that are critical for operating it.
import pytorch_lightning as pl
from torch.utils.data import DataLoader

# create the model
model = MyModel()

# create the datasets
train_dataset = MyDataset(dataset_type="train")
validation_dataset = MyDataset(dataset_type="validation")

# create the dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=32, num_workers=4)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, num_workers=4)

# create the trainer
trainer = pl.Trainer(
    accelerator="gpu",  # train on GPU; older releases used gpus=1 instead
    devices=1,  # number of GPUs to use, -1 to use all
    max_epochs=10,  # maximum number of epochs the trainer will execute
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,  # older releases named this train_dataloader
    val_dataloaders=validation_dataloader,
)
Trainer Constructor Arguments
accelerator, devices- Specify the type of hardware and how many devices to train on; accelerator="gpu" with devices=1 trains on a single GPU. Older releases used the gpus argument for this, which defaulted to using no GPU.
max_epochs- Specifies the maximum number of epochs (how many times the full dataset is shown to the model).
fit Method Arguments
train_dataloaders- Specifies the dataloader that the trainer uses for training. Older releases named this argument train_dataloader.
val_dataloaders- This can be either a single dataloader or a list of dataloaders, which the trainer then uses to evaluate the model.
A Simple Example
To put what we have covered into practice, we will build a simple model that can classify data points belonging to 4 classes. These data points will be created using scikit-learn's make_blobs function.
1. Creating the Dataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

classes = 4
n_samples = 1000

(X, y) = make_blobs(
    n_samples=n_samples,
    n_features=2,
    centers=classes,
    cluster_std=2.5,
    center_box=(-10, 10),
    random_state=42,
)

## Splitting the datasets
(X_train, X_test, y_train, y_test) = train_test_split(X, y, test_size=0.2, random_state=42)
(X_train, X_validation, y_train, y_validation) = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

## Plot each class in its own color
colors = ['red', 'green', 'blue', 'black', 'purple']
cdict = {label: colors[label] for label in range(classes)}

fig, ax = plt.subplots()
for g in np.unique(y):
    mask = y == g  # boolean mask selecting the points of class g
    ax.scatter(X[mask, 0], X[mask, 1], c=cdict[g], label=g)
ax.legend()
plt.show()
Below is the clustering of points that we are trying to fit a model to.
2. Create a Custom Dataset
import torch
from torch.utils.data import Dataset, DataLoader


class MyCustomDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.count = X.shape[0]  # number of samples (rows), not the full shape tuple

    def __len__(self):
        return self.count

    def __getitem__(self, index):
        X = self.X[index]
        y = self.y[index]
        # cross entropy expects float features and integer (long) class labels
        return (torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long))


ds_train = MyCustomDataset(X_train, y_train)
ds_validation = MyCustomDataset(X_validation, y_validation)
ds_test = MyCustomDataset(X_test, y_test)

dl_train = DataLoader(ds_train, batch_size=16, num_workers=2)
dl_validation = DataLoader(ds_validation, batch_size=16, num_workers=2)
dl_test = DataLoader(ds_test, batch_size=16, num_workers=2)
3. Create the Custom Model
## create the model class
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        ## make the model
        self.classifier = nn.Sequential(
            nn.Linear(in_features=2, out_features=4),
            nn.ReLU(),
            nn.Linear(in_features=4, out_features=4),
        )
        ## use cross entropy loss for categorical problems
        self.loss = F.cross_entropy

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        return optimizer

    def forward(self, x):
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_logit = self(x)
        loss = self.loss(y_logit, y)
        pred = F.softmax(y_logit, dim=1)  # class probabilities (not used further here)
        self.log('train/loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_logit = self(x)
        loss = self.loss(y_logit, y)
        pred = F.softmax(y_logit, dim=1)
        self.log("val/loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_logit = self(x)
        loss = self.loss(y_logit, y)
        pred = F.softmax(y_logit, dim=1)
        self.log("test/loss", loss)
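The step methods compute pred, the softmax probabilities, without using it further. One possible use, sketched here with made-up logits and labels rather than real model output, is to turn the probabilities into class predictions and an accuracy value. Such a value could then be logged with self.log in the same way as the loss; that extension is an assumption and not part of the model above.

```python
import torch
from torch.nn import functional as F

# made-up logits for 3 samples and 4 classes, with their true labels
y_logit = torch.tensor([
    [2.0, 0.1, 0.1, 0.1],
    [0.1, 0.2, 3.0, 0.1],
    [0.5, 1.5, 0.2, 0.1],
])
y = torch.tensor([0, 2, 3])

pred = F.softmax(y_logit, dim=1)  # each row becomes probabilities that sum to 1
predicted_class = pred.argmax(dim=1)  # most probable class per sample
accuracy = (predicted_class == y).float().mean()  # fraction of correct predictions
print(predicted_class.tolist(), accuracy.item())
```

Here two of the three samples are classified correctly. Note that the loss is still computed from the raw logits, since F.cross_entropy applies the softmax internally.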
4. Train the Model
trainer = pl.Trainer(
    max_epochs=10,
)

model = MyModel()

trainer.fit(
    model,
    train_dataloaders=dl_train,  # older releases named this train_dataloader
    val_dataloaders=dl_validation,
)
With the above code we execute the training_step and validation_step of the model to train and also validate the model.
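As a rough sanity check on how much work this run does, the arithmetic below counts batches from the numbers used in this example: 1000 samples, two successive 20% splits, a batch size of 16, and 10 epochs. It assumes training runs for all epochs and that no batches are skipped.

```python
import math

n_samples = 1000
n_test = n_samples // 5  # test_size=0.2 of 1000 -> 200
n_train_full = n_samples - n_test  # 800 left after the first split
n_validation = n_train_full // 5  # test_size=0.2 of 800 -> 160
n_train = n_train_full - n_validation  # 640 used for training

batch_size = 16
max_epochs = 10

train_batches_per_epoch = math.ceil(n_train / batch_size)
val_batches_per_pass = math.ceil(n_validation / batch_size)
total_training_steps = train_batches_per_epoch * max_epochs

print(train_batches_per_epoch, val_batches_per_pass, total_training_steps)
```

Under these assumptions, training_step runs 40 times per epoch and 400 times in total, while each validation pass calls validation_step 10 times.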
5. Test the Model
After training the model, we can use the test set to check the model performance with unseen data.
trainer.test(
    model,
    dataloaders=dl_test,  # older releases named this test_dataloaders
)
The output of training and testing of the model is as follows.
6. Save the Checkpoint
We will manually save the model for now, but the Trainer has more advanced options that allow us to automate the saving of models. You can check the documentation for more details.
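One way to save a model by hand, sketched here with plain PyTorch serialization rather than Lightning's own checkpoint format, is to write the module's state_dict with torch.save and restore it with load_state_dict. This works because a LightningModule is also a torch.nn.Module. The small nn.Sequential below is only a stand-in for the trained model, and the file name is hypothetical. The Trainer also offers a save_checkpoint method that stores the full training state.

```python
import os
import tempfile

import torch
from torch import nn

# stand-in for the trained classifier from step 3
net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4))

# hypothetical file name inside a temporary directory
checkpoint_path = os.path.join(tempfile.mkdtemp(), "my_model.pt")
torch.save(net.state_dict(), checkpoint_path)

# restore the weights into a freshly constructed network of the same shape
restored = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4))
restored.load_state_dict(torch.load(checkpoint_path))

x = torch.randn(3, 2)
same_output = torch.equal(net(x), restored(x))
print(same_output)
```

Saving only the state_dict keeps the file independent of the training setup, at the cost of having to rebuild the model class before loading.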
In this article we have gone through each of the main steps necessary for using the PyTorch Lightning framework. The framework is a really wonderful addition on top of PyTorch. I will be posting more on PyTorch Lightning and deep learning in general.