A Guide to the DataLoader Class and Abstractions in PyTorch

When working with Neural Networks, especially in large-scale deep learning projects, efficiently managing and preprocessing data can be just as critical as designing the model architecture itself. A common challenge faced by developers and researchers is feeding data into the model in a way that supports high-performance training—this involves batching, shuffling, and potentially applying transformations to data on the fly. Without a streamlined solution, developers are often left writing extensive boilerplate code to handle these operations manually, which can be error-prone, hard to debug, and inefficient.

This is where PyTorch excels by providing powerful abstractions for data handling, with the Dataset and DataLoader classes forming the core components of its data pipeline. These tools help manage everything from loading images from disk to applying real-time data augmentations and managing device transfers, all while keeping training pipelines clean and scalable.

Take, for instance, a basic image classification task using the MNIST dataset—a simple scenario on the surface, yet one that still requires significant effort to load, normalize, batch, and shuffle images and labels effectively. Without abstractions like DataLoader, such seemingly straightforward tasks can quickly become cumbersome as datasets grow or as experiments become more complex.

This guide explores the role and functionality of the DataLoader class in PyTorch, why it’s essential for modern deep learning workflows, and how to use it effectively in your own projects. Whether you’re working on standard datasets like MNIST or custom image, text, or tabular data, understanding how to leverage DataLoader will help you build faster, more reliable training pipelines.

Prerequisites

To follow along with this tutorial, you will need a sufficiently powerful NVIDIA GPU with at least 8GB of VRAM. A basic understanding of Python classes and objects will also be crucial for understanding the full discussion.

Working on Datasets

If you are working on a real-world project involving Deep Learning, it’s common for most of your time to go into handling data rather than the neural network that you would build. This is because data is like fuel for your network: the more appropriate it is, the faster and the more accurate the results are! One of the main reasons a neural network underperforms is bad or poorly understood data. Hence, it is important to understand, preprocess, and load your data into the network in a more intuitive way.

In many cases, we train neural networks on default or well-known datasets like MNIST or CIFAR. While working on these, we can easily achieve accuracy greater than 90% for prediction- and classification-type problems. The reason is that these datasets are neatly organized and easy to preprocess. But when you are working on your own dataset, it’s quite tricky and challenging to achieve high accuracy. We’ll learn about working on custom datasets in the next sections. Before that, we’ll have a quick look at the datasets included in the PyTorch library.

PyTorch comes with several built-in datasets, which are available through the torchvision.datasets module. The MNIST examples later in this guide use this same module to download the images.

What’s in the package torch and torchvision?

The package torch contains all the essential classes and methods needed to implement neural networks. In contrast, torchvision is a supplementary package that includes popular datasets, model architectures, and common image transformations specifically for computer vision tasks. Additionally, there is a package called torchtext, which provides fundamental utilities for Natural Language Processing (NLP) with PyTorch. This package includes datasets related to text processing.

Here’s a quick overview of the datasets that are included in the torchvision and torchtext packages.

Datasets in Torchvision

MNIST: MNIST is a dataset of handwritten digit images that are size-normalized and centered. It has 60,000 training images and 10,000 test images. This is one of the most-used datasets for learning and experimenting purposes. To load and use the dataset, you can use the class below once the torchvision package is installed.

  • torchvision.datasets.MNIST()

Fashion MNIST: This dataset is similar to MNIST, but instead of handwritten digits, this dataset includes clothing items like T-shirts, trousers, bags, etc. The number of training and testing samples is 60,000 and 10,000, respectively.

  • torchvision.datasets.FashionMNIST()

CIFAR: The CIFAR dataset has two versions, CIFAR10 and CIFAR100. CIFAR10 consists of images of 10 different labels, while CIFAR100 has 100 different classes. These include common images like trucks, frogs, boats, cars, deer, etc. This dataset is recommended for building CNNs.

  • torchvision.datasets.CIFAR10()
  • torchvision.datasets.CIFAR100()

COCO: This dataset contains over 100,000 images of everyday objects such as people, bottles, stationery, and books. It is widely used for object detection and image captioning applications. Below is a class through which the COCO captions can be loaded:

  • torchvision.datasets.CocoCaptions()

EMNIST: This dataset is an advanced version of the MNIST dataset. It consists of images including both numbers and letters. If you are working on a problem that is based on recognizing text from images, this is the right dataset to train with. Below is the class:

  • torchvision.datasets.EMNIST()

ImageNet: ImageNet is one of the flagship datasets used to train high-end neural networks. The version exposed by torchvision (ILSVRC 2012) consists of over 1.2 million training images spread across 1,000 classes. Because of its size, this dataset is usually processed on high-end hardware, as training on it with a CPU alone is impractically slow. Below is the class to load the ImageNet dataset:

  • torchvision.datasets.ImageNet()

These are a few datasets that are the most frequently used while building neural networks in PyTorch. A few others include KMNIST, QMNIST, LSUN, STL10, SVHN, PhotoTour, SBU, Cityscapes, SBD, USPS, Kinetics-400. You can learn more about these from the PyTorch official documentation.

Datasets in Torchtext

As discussed previously, torchtext is a supporting package that consists of all the basic utilities for Natural Language Processing. If you are new to NLP, it is a subfield of Artificial Intelligence that processes and analyzes large amounts of natural language data (mostly relating to text).

Now, let’s take a look at a few popular text datasets to experiment with and work with.

IMDB: This is a dataset for sentiment classification that contains a set of 25,000 highly polar movie reviews for training and another 25,000 for testing. We can load this data by using the following class from torchtext:

  • torchtext.datasets.IMDB()

WikiText2: This language modelling dataset is a collection of over 2 million tokens. It is extracted from Wikipedia and retains the punctuation and the actual letter case. It is widely used in applications that involve long-term dependencies. This data can be loaded from torchtext as follows:

  • torchtext.datasets.WikiText2()

Besides the above two popular datasets, there are still many more available in the torchtext library, such as SST, TREC, SNLI, MultiNLI, WikiText103, PennTreebank, Multi30k, etc.

So far, we’ve seen datasets that are based on a predefined set of images and text. But what if you have your own? How do you load it? For now, let’s learn the ImageFolder class, which you can use to load your own image datasets.

ImageFolder Class

ImageFolder is a generic dataset class in torchvision that helps you load your own image dataset. Imagine you are working on a classification problem, building a neural network to identify whether a given image is an apple or an orange. To do this in PyTorch, the first step is to arrange the images in a folder structure with one subfolder per class, as shown below:

root
├── orange
│   ├── orange_image1.png
│   └── orange_image2.png
└── apple
    ├── apple_image1.png
    ├── apple_image2.png
    └── apple_image3.png

After you arrange your dataset as shown, you can use the ImageFolder class to load all these images. Below is the code snippet you would use to do so:

torchvision.datasets.ImageFolder(root, transform)

In the next section, let’s see how to load data into our programs.

Data Loading in PyTorch

Data loading is one of the first steps in building a Deep Learning pipeline or training a model. This task becomes more challenging when the complexity of the data increases. In this section, we will learn about the DataLoader class in PyTorch that helps us to load and iterate over elements in a dataset. This class is available as DataLoader in the torch.utils.data module. DataLoader can be imported as follows:

from torch.utils.data import DataLoader

Let’s now discuss in detail the parameters the DataLoader class accepts, shown below.

from torch.utils.data import DataLoader

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
)

1. Dataset: The first parameter in the DataLoader class is the dataset, from which we load the data.

2. Batching the data: batch_size refers to the number of training samples used in one iteration. Usually, we split our data into training and testing sets, and we may have different batch sizes for each.

3. Shuffling the data: shuffle is another argument passed to the DataLoader class. The argument takes in a Boolean value (True/False). If shuffle is set to True, the samples are reshuffled at the start of every epoch before being grouped into batches. Otherwise, they are loaded in their original order.

4. Allowing multiprocessing: Deep learning involves training models on a lot of data, so loading it in a single process can take a lot of time. In PyTorch, you can load batches in parallel worker processes by setting the num_workers argument. A suitable value depends on the batch size and on the machine, but I wouldn’t set num_workers to the same number as the batch size, because each worker loads a whole batch and returns it only once it’s ready.

  • num_workers=0 means that the main process loads the data when needed.
  • num_workers=1 means you only have a single worker, so it might be slow.

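5. Merging samples into batches: The collate_fn argument is a function that merges a list of individual samples into a single batch. This argument is optional; when it is omitted, a default function stacks the samples into tensors. A custom collate_fn is mostly used when batches are loaded from map-style datasets whose samples need special handling, such as sequences of different lengths.

6. Speeding up transfers to CUDA devices: pin_memory is an optional parameter that takes in a Boolean value. If set to True, the DataLoader class copies Tensors into pinned (page-locked) host memory before returning them, which makes the subsequent copy to a CUDA device faster. The tensors are not moved to the GPU automatically.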
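The parameters above can be tried out without downloading anything. The following is a small sketch that assumes a toy TensorDataset of ten samples; the variable names and sizes are chosen purely for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten toy samples with 3 features each, plus integer labels 0..9.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

loader = DataLoader(
    dataset,
    batch_size=4,      # up to 4 samples per batch
    shuffle=True,      # reshuffle the sample order every epoch
    num_workers=0,     # load in the main process
    pin_memory=False,  # set to True when batches will be copied to a GPU
)

batch_sizes = []
for batch_features, batch_labels in loader:
    batch_sizes.append(len(batch_labels))
    print(batch_features.shape, batch_labels.tolist())
```

With ten samples and a batch size of 4, the loader yields two full batches followed by a final batch of two samples.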

Let’s take a look at an example to better understand the usual data loading pipeline.

Looking at the MNIST Dataset in-Depth

PyTorch’s torchvision repository hosts a handful of standard datasets, MNIST being one of the most popular. Now, we’ll see how PyTorch loads the MNIST dataset from the pytorch/vision repository. Let’s first download the dataset and load it into a variable named data_train. Then, we’ll print a sample image.

# Import MNIST
from torchvision.datasets import MNIST

# Download and Save MNIST 
data_train = MNIST('~/mnist_data', train=True, download=True)

# Print Data
print(data_train)
print(data_train[12])

Output:

Dataset MNIST
    Number of datapoints: 60000
    Root location: /Users/viharkurama/mnist_data
    Split: Train
(<PIL.Image.Image image mode=L size=28x28 at 0x11164A100>, 3)

Let’s now try extracting the tuple wherein the first value would correspond to the image, and the second value would correspond to its respective label. Below is the code snippet:

import matplotlib.pyplot as plt

# Each sample is an (image, label) tuple
random_image = data_train[0][0]
random_image_label = data_train[0][1]

# Display the image using Matplotlib
plt.imshow(random_image, cmap='gray')
plt.show()
print("The label of the image is:", random_image_label)

Most of the time, you wouldn’t access images by index but would instead feed batches of image tensors to your model. This is where DataLoader comes in handy, since it prepares the batches (and can shuffle the data before every epoch). Now let’s see how this works in practice. Let’s use the DataLoader class to load the dataset, as shown below.

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

data_train = torch.utils.data.DataLoader(
    MNIST(
        '~/mnist_data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor()
        ])),
    batch_size=64,
    shuffle=True
)

# Each iteration yields a batch of images and the matching labels
for batch_idx, (images, labels) in enumerate(data_train):
    print(batch_idx, images.shape, labels.shape)

This is how we load a simple dataset using DataLoader. Real projects, however, often involve much larger datasets, or irregular ones containing images of varying resolutions. Loading such data efficiently matters even more once training moves to a GPU.

Loading the Data on GPUs

We can use GPUs to train our models more quickly. Let’s look at how to configure CUDA (GPU support for PyTorch) when loading data. Here is an example code snippet:

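import torch
import torchvision
from torchvision import transforms

# Example batch sizes (assumed values); tune them for your hardware
batch_size_train, batch_size_test = 64, 1000

device = "cuda" if torch.cuda.is_available() else "cpu"
kwargs = {'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {}

# ToTensor converts the PIL images to tensors so that they can be batched
to_tensor = transforms.ToTensor()

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('files/', train=True, download=True, transform=to_tensor),
    batch_size=batch_size_train, **kwargs)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('files/', train=False, download=True, transform=to_tensor),
    batch_size=batch_size_test, **kwargs)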

In the above, we declared a new variable named device. Next, we write a simple if condition that checks the current hardware configuration. If it supports GPU, it would set the device to cuda, else it would set it to cpu. The variable num_workers denotes the number of processes that generate batches in parallel. For data loading, passing pin_memory=True to the DataLoader class will automatically put the fetched data tensors in pinned memory and thus enables faster data transfer to CUDA-enabled GPUs.
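Pinned memory only prepares the batches on the host side; each batch still has to be moved to the device inside the training loop. A minimal sketch of that step follows, assuming a toy dataset of random tensors in place of MNIST so that it runs without a download.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy stand-in for an image dataset: 8 single-channel 28x28 "images".
dataset = TensorDataset(torch.rand(8, 1, 28, 28), torch.randint(0, 10, (8,)))
loader = DataLoader(dataset, batch_size=4, pin_memory=(device == "cuda"))

devices_seen = set()
for images, labels in loader:
    # non_blocking=True lets the copy overlap with computation when the
    # source tensor sits in pinned memory; it has no effect on CPU.
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    devices_seen.add(images.device.type)

print(devices_seen)
```

The same loop works unchanged on a machine without a GPU, because device falls back to "cpu".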

In the next section, we’ll learn about Transforms, which define the preprocessing steps for loading the data.

Transforms and Rescaling the Data

PyTorch transforms are used to apply simple image transformation techniques that convert an entire dataset into a uniform format. For instance, if we have a dataset containing pictures of various cars in different resolutions, it’s important for all the images in the training dataset to have the same resolution size. Manually converting each image to the required size can be time-consuming, so we can utilize transforms instead. With just a few lines of PyTorch code, we can easily resize all the images in our dataset to the desired input size and resolution.

The transforms module offers several commonly used operations, including transforms.Resize() to resize images, transforms.CenterCrop() to crop the images from the center, and transforms.RandomResizedCrop() to crop a random region of each image and resize it to a given size. These tools help streamline the preprocessing of images, ensuring consistency and efficiency in your workflow.

Let’s now load CIFAR10 from torchvision.datasets and apply the following transforms:

  1. Resizing all the images to 32×32
  2. Applying a center crop transform to the images
  3. Converting the cropped images to tensors
  4. Normalizing the images

First, we import the necessary modules and transforms from the torchvision module, along with the NumPy and Matplotlib libraries, which we will use to visualize the dataset.

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

Next, we will define a variable called transform, where we will write all the preprocessing steps sequentially. We use the Compose class to chain together all the transformation operations.

transform = transforms.Compose([
    # resize
    transforms.Resize(32),
    # center-crop
    transforms.CenterCrop(32),
    # to-tensor
    transforms.ToTensor(),
    # normalize
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

  • resize: The Resize transform rescales images to the given size. When a single integer is passed, the smaller edge of the image is scaled to that length and the aspect ratio is preserved. CIFAR10 images are already 32×32, so passing 32 leaves them at 32×32.
  • center-crop: Next, we crop the images using the CenterCrop transform. The argument is the output size, so CenterCrop(32) keeps the central 32×32 region of each image. Since the images are already 32×32 after resizing, the crop leaves them unchanged here; on non-square images, it would guarantee a square 32×32 result.
  • to-tensor: We used the ToTensor() transform to convert the images to the Tensor datatype, with pixel values scaled to the range 0 to 1.
  • normalize: This subtracts the mean (0.5) from each channel and divides the result by the standard deviation (0.5), so the values in the tensor end up between -1 and 1.

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=False)

We fetched the CIFAR dataset from torchvision.datasets, setting the train and download arguments to True. Next, we set the transform argument to the defined transform variable. The DataLoader iterable was initialized, and we passed the trainset as an argument. The batch_size was set to 4, and shuffle was set to False. Next, we can visualize the images using the code snippet below.

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def imshow(img):
    # Undo the normalization: map values from [-1, 1] back to [0, 1]
    img = img / 2 + 0.5
    npimg = img.numpy()
    # Matplotlib expects (height, width, channels)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Fetch the first batch of images and labels
dataiter = iter(trainloader)
images, labels = next(dataiter)

imshow(torchvision.utils.make_grid(images))

print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

Besides Resize(), CenterCrop(), and RandomResizedCrop(), there are various other Transform classes available. Let’s look at the most-used ones.

Transform Classes

1. RandomCrop: This class in PyTorch crops the given image at a random location. The following are the arguments that RandomCrop accepts:

torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0)

  • size: The desired output size of the random crop. It can be an integer or a (height, width) sequence. For example, if the size is set to 32, the output will be a randomly cropped image of size 32×32.
  • padding: An optional integer or sequence that is set to None by default. If given, it adds a border to the image before cropping. For example, if the padding is set to 4, it pads the left, top, right, and bottom borders by 4 pixels each.
  • pad_if_needed: An optional parameter that takes a Boolean value. If it’s set to True, the image is padded whenever it is smaller than the requested crop size, which avoids an error being raised. By default, this parameter is set to False.
  • fill: The constant value used for all the padded pixels. The default fill value is 0.

2. RandomHorizontalFlip: Sometimes, to make the model more robust, we flip the training images randomly. The class RandomHorizontalFlip is used to achieve this. It has one optional argument, p, which indicates the probability of the image being flipped (between 0 and 1). The default value is 0.5.

torchvision.transforms.RandomHorizontalFlip(p=0.5)

3. Normalize: This normalizes a tensor image using the mean and standard deviation supplied for each channel. The class form is torchvision.transforms.Normalize(mean, std, inplace=False). The equivalent functional form, shown below, takes the tensor itself as an additional first argument:

torchvision.transforms.functional.normalize(tensor, mean, std, inplace=False)

  • The tensor argument takes in a Tensor of shape (C, H, W), where C, H, and W represent the number of channels, height, and width, respectively. Each channel is normalized as (value - mean) / std.
  • The mean and std arguments take in a sequence of means and standard deviations, one for each channel.
  • The inplace argument is a Boolean value. If set to True, the operation is computed in-place.

4. ToTensor: This class converts a PIL Image or a NumPy n-dimensional array to a tensor. For 8-bit images, the result has shape (C, H, W) and values scaled to the range 0 to 1. The class form is torchvision.transforms.ToTensor(), and the functional form is:

torchvision.transforms.functional.to_tensor(img)

Now, let’s understand the mechanisms behind loading a custom dataset rather than using the built-in datasets.

Creating Custom Datasets in PyTorch

So far, we’ve learned to load datasets along with various ways to preprocess the data. In this section, we’ll create a simple custom dataset consisting of numbers. We’ll talk about the Dataset object in PyTorch, which helps to handle numerical and text files, and how one could go about optimizing the pipeline for a certain task. The trick here is to override the __getitem__() and __len__() methods of the Dataset class.

  • The __getitem__() method returns the sample at a given index in the dataset.
  • The __len__() method returns the total size of the dataset. For example, if your dataset contains 100,000 samples, the __len__() method should return 100,000.

Note that the data is not yet loaded into memory at this point.

Below is an abstract view of the __getitem__() and __len__() methods:

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

Creating a custom dataset isn’t complex, but as an additional step to the typical procedure of loading data, it is necessary to build an interface to get a nice abstraction (a nice syntactic sugar, to say the least). Now, we’ll create a new dataset that has numbers and their squared values. Let us call our dataset SquareDataset. Its purpose is to return squares of values in the range [a,b]. Below is the relevant code:

import torch
from torch.utils.data import Dataset, DataLoader

class SquareDataset(Dataset):
    def __init__(self, a=0, b=1):
        super().__init__()
        assert a <= b
        self.a = a
        self.b = b

    def __len__(self):
        # Number of integers in the closed range [a, b]
        return self.b - self.a + 1

    def __getitem__(self, index):
        # DataLoader passes positions 0..len-1, so shift them into [a, b]
        if not 0 <= index < len(self):
            raise IndexError(index)
        value = self.a + index
        return value, value**2

data_train = SquareDataset(a=1, b=64)
data_train_loader = DataLoader(data_train, batch_size=64, shuffle=True)
print(len(data_train))

In the above code block, we created a Python class named SquareDataset that inherits from PyTorch’s Dataset class. Its __init__() constructor takes a and b, which default to 0 and 1, respectively. The super() call runs the initializer of the parent Dataset class. Next, we used the assert statement to check that a is less than or equal to b, as we want the values of the dataset to lie between a and b. The __len__() method returns how many values the range contains, and __getitem__() returns a number together with its square.

We then created a dataset using the SquareDataset class, where the data values lie in the range 1 to 64. We loaded this into a variable named data_train. Lastly, the DataLoader class created an iterable over data_train, which we stored in data_train_loader, with batch_size set to 64 and shuffle set to True.
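To see what such a loader actually yields, here is a compact, self-contained sketch. SquareRangeDataset is a hypothetical name used only for this illustration, and the small range and batch size are chosen so that the output stays readable. The default collate function turns a batch of (number, square) pairs into two tensors.

```python
from torch.utils.data import DataLoader, Dataset


class SquareRangeDataset(Dataset):
    """Hypothetical variant: yields (n, n**2) for every integer n in [a, b]."""

    def __init__(self, a=0, b=1):
        assert a <= b
        self.a = a
        self.b = b

    def __len__(self):
        return self.b - self.a + 1

    def __getitem__(self, index):
        # DataLoader asks for positions 0 .. len-1; shift them into [a, b].
        if not 0 <= index < len(self):
            raise IndexError(index)
        value = self.a + index
        return value, value ** 2


loader = DataLoader(SquareRangeDataset(a=1, b=10), batch_size=4, shuffle=False)

# Each batch arrives as two tensors: the numbers and their squares.
batches = [(numbers.tolist(), squares.tolist()) for numbers, squares in loader]
for numbers, squares in batches:
    print(numbers, squares)
```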

Data loaders build on Python’s object-oriented features to keep data handling modular and reusable. A good exercise would be to go through a variety of data loaders with a number of popular datasets, including CelebA, PIMA, COCO, ImageNet, CIFAR-10/100, etc.

FAQ

1. What is the DataLoader class used for in PyTorch?
DataLoader is used to efficiently load data in mini-batches, shuffle it, and feed it to your model during training or evaluation. It handles parallel data loading and prefetching to speed up training.

2. How do Dataset and DataLoader work together in PyTorch?
A Dataset provides access to individual samples, while DataLoader wraps the dataset and enables batch loading, shuffling, and multiprocessing.

3. What are the key parameters of DataLoader?
Important parameters include:

  • batch_size: Number of samples per batch
  • shuffle: Whether to shuffle data each epoch
  • num_workers: Number of subprocesses to use for data loading
  • pin_memory: Whether to use pinned memory for faster GPU transfer
  • drop_last: Drop the last batch if it’s smaller than batch_size
  • collate_fn: Custom function to combine samples into a batch

4. How does num_workers affect performance in DataLoader?
Higher num_workers values enable parallel data loading, reducing data loading bottlenecks. However, setting it too high may lead to CPU overload or memory issues.

5. When should I use pin_memory=True in a DataLoader?
Use pin_memory=True when loading data to a GPU. It allows faster transfer from host to device memory, especially helpful for training on CUDA.

6. What is the purpose of collate_fn in DataLoader?
collate_fn lets you customize how a list of dataset items is merged into a batch. It’s useful for handling variable-length inputs like text or complex nested data structures.
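As an illustration of the variable-length case, the sketch below pads toy token sequences to a common length within each batch. The function name pad_collate and the sample data are assumptions made for this example; pad_sequence comes from torch.nn.utils.rnn.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Toy "sentences" of different lengths, encoded as token ids, with a label each.
samples = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6]), 0),
    (torch.tensor([7, 8, 9, 10]), 1),
]


def pad_collate(batch):
    """Pad every sequence in the batch to the length of the longest one."""
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(seq) for seq in sequences])
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    return padded, lengths, torch.tensor(labels)


# A plain list works as a map-style dataset: it has __getitem__ and __len__.
loader = DataLoader(samples, batch_size=2, collate_fn=pad_collate)

for padded, lengths, labels in loader:
    print(padded.tolist(), lengths.tolist(), labels.tolist())
```

Returning the original lengths alongside the padded tensor lets downstream code ignore the padding positions.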

7. Why is my DataLoader so slow, and how can I optimize it?
Common reasons include:

  • Too few num_workers
  • Slow I/O (disk or network)
  • Expensive transforms running in the main process

Optimize by increasing num_workers, using cached datasets, optimizing transforms, or using pin_memory=True.

8. How do I handle imbalanced datasets with DataLoader?
Use WeightedRandomSampler to oversample minority classes or under-sample majority ones while keeping the DataLoader interface.
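One common recipe, sketched below under assumed toy data, weights each sample by the inverse frequency of its class so that both classes are drawn roughly equally often. The 90/10 split and the feature size are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Imbalanced toy labels: 90 samples of class 0 and 10 samples of class 1.
labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
features = torch.randn(100, 4)
dataset = TensorDataset(features, labels)

# Weight each sample by the inverse frequency of its class.
class_counts = torch.bincount(labels)
sample_weights = (1.0 / class_counts.float())[labels]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(dataset),
    replacement=True,  # minority samples must be drawn more than once
)

# shuffle must stay at its default (False) when a sampler is supplied.
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

drawn = torch.cat([batch_labels for _, batch_labels in loader])
print(torch.bincount(drawn, minlength=2).tolist())
```

The printed counts show how many times each class was drawn in one pass; with these weights the minority class appears far more often than its 10% share of the data.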

9. What’s the difference between drop_last=True and False?

  • True: Drops the last batch if it has fewer than batch_size samples.
  • False: Keeps the last batch even if it’s smaller than batch_size.
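The difference is easy to see on a dataset whose size is not a multiple of the batch size. This short sketch assumes ten toy samples and a batch size of 4.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))  # 10 samples

keep_last = DataLoader(dataset, batch_size=4, drop_last=False)
drop_last = DataLoader(dataset, batch_size=4, drop_last=True)

keep_sizes = [len(batch[0]) for batch in keep_last]
drop_sizes = [len(batch[0]) for batch in drop_last]

print(keep_sizes)  # [4, 4, 2]
print(drop_sizes)  # [4, 4]
```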

10. Can I use DataLoader with custom data sources like databases or APIs?
Yes. Create a custom Dataset class that queries your data source (e.g., SQL, REST API) in __getitem__. DataLoader can then wrap it normally.
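As a rough sketch of this pattern, the example below backs a Dataset with an in-memory SQLite database from the Python standard library. The class name SQLiteDataset, the table name, and the column names are all hypothetical. It keeps num_workers=0 because a single connection object is not meant to be shared across worker processes; with workers, each one would need to open its own connection.

```python
import sqlite3

import torch
from torch.utils.data import DataLoader, Dataset


class SQLiteDataset(Dataset):
    """Hypothetical dataset that reads one row per sample from a SQLite table."""

    def __init__(self, connection):
        self.connection = connection
        self.length = connection.execute("SELECT COUNT(*) FROM samples").fetchone()[0]

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Row ids in this toy table start at 1, while dataset indices start at 0.
        row = self.connection.execute(
            "SELECT x, y FROM samples WHERE id = ?", (index + 1,)
        ).fetchone()
        if row is None:
            raise IndexError(index)
        x = torch.tensor(row[0], dtype=torch.float32)
        y = torch.tensor(row[1], dtype=torch.float32)
        return x, y


# An in-memory database stands in for a real data source.
connection = sqlite3.connect(":memory:")
connection.execute("CREATE TABLE samples (id INTEGER PRIMARY KEY, x REAL, y REAL)")
connection.executemany(
    "INSERT INTO samples (x, y) VALUES (?, ?)",
    [(float(i), float(2 * i)) for i in range(6)],
)

# num_workers=0 keeps the single connection in the main process.
loader = DataLoader(SQLiteDataset(connection), batch_size=3, num_workers=0)
batches = [(x.tolist(), y.tolist()) for x, y in loader]
print(batches)
```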

Summary

In this post, we took a hands-on journey through PyTorch’s data loading ecosystem. We started by exploring built-in datasets from libraries like torchvision and torchtext, diving into popular examples to understand how data is structured and accessed. Then, we introduced the DataLoader class—PyTorch’s flexible tool for batching, shuffling, and managing datasets efficiently—and saw how it simplifies data handling during model training.

We spent some time with the MNIST dataset, examining different ways to load and preprocess it. Along the way, we discussed the role of Transforms, showcasing useful techniques like RandomCrop, RandomHorizontalFlip, Normalize, and ToTensor to augment and prepare data effectively.

We also looked at how to detect a CUDA-capable GPU and configure the DataLoader to feed it efficiently. If you’re experimenting with models or training at scale, using GPU-backed environments like DigitalOcean’s GPU Droplets can make a huge difference in speed and productivity.

Finally, we wrapped things up by showing how easy it is to create your own custom dataset using PyTorch’s Dataset class—a task that sounds complicated but is actually quite approachable with just a few lines of code.
