Эффективные методы уменьшения размера набора данных в PyTorch Dataloader

При работе с большими наборами данных в PyTorch часто необходимо уменьшить размер набора данных по разным причинам, например из-за ограничений памяти или более быстрого обучения. В этой статье мы рассмотрим несколько методов с примерами кода, позволяющих эффективно уменьшить размер набора данных в Dataloader PyTorch, обеспечивая более плавную и эффективную обработку данных.

  1. Случайная выборка.
    Один из самых простых методов уменьшения размера набора данных — применение случайной выборки. Этот метод случайным образом выбирает подмножество исходного набора данных, сохраняя общее распределение данных.
import torch
from torch.utils.data import SubsetRandomSampler
# Create original dataset
dataset = YourDataset()
# Define the percentage of data to keep
keep_percentage = 0.8
# Calculate the number of samples to keep
num_samples = int(len(dataset) * keep_percentage)
# Create a random sampler
sampler = SubsetRandomSampler(range(num_samples))
# Create dataloader with reduced dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
  1. Стратифицированная выборка.
    Стратифицированная выборка гарантирует, что сокращенный набор данных сохраняет то же распределение классов, что и исходный набор данных. Этот метод особенно полезен при работе с несбалансированными наборами данных.
import torch
from torch.utils.data import SubsetRandomSampler
from sklearn.model_selection import StratifiedShuffleSplit
# Create original dataset
dataset = YourDataset()
labels = dataset.get_labels()  # Get the labels of the dataset
# Define the percentage of data to keep
keep_percentage = 0.8
# Create stratified shuffle split
sss = StratifiedShuffleSplit(n_splits=1, train_size=keep_percentage)
# Generate the indices for train and validation sets
train_indices, _ = next(sss.split(range(len(dataset)), labels))
# Create a random sampler
sampler = SubsetRandomSampler(train_indices)
# Create dataloader with reduced dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
  1. Фильтрация с использованием пользовательских критериев.
    Вы можете уменьшить набор данных, применив пользовательские критерии фильтрации для исключения определенных образцов. Этот метод позволяет выборочно удалять образцы в зависимости от конкретных условий.
import torch
from torch.utils.data import Subset
# Create original dataset
dataset = YourDataset()
# Define a custom filtering function
def custom_filter(sample):
    # Apply your filtering condition here
    return sample['label'] == desired_label
# Apply the custom filter to create a reduced dataset
reduced_dataset = Subset(dataset, [idx for idx, sample in enumerate(dataset) if custom_filter(sample)])
# Create dataloader with reduced dataset
dataloader = torch.utils.data.DataLoader(reduced_dataset, batch_size=32, shuffle=True)

Уменьшение размера набора данных в Dataloader PyTorch имеет решающее значение для эффективного обучения и обработки больших наборов данных. В этой статье мы рассмотрели различные методы, такие как случайная выборка, стратифицированная выборка и пользовательская фильтрация. Реализуя эти методы, вы можете эффективно уменьшить размер набора данных, сохраняя при этом целостность данных и оптимизируя производительность.