import torch
import copy
import os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Root directory handed to the torchvision MNIST datasets below.
# The default is an absolute path, so it does not depend on the current
# working directory or on where this notebook/script lives. Set the
# MNIST_DIR environment variable to point at a different data location;
# an unset or empty MNIST_DIR falls back to the default.
mnist_directory = os.path.abspath(
    os.environ.get("MNIST_DIR") or "/usr/data1/vision/data/"
)
# Preprocessing applied to every MNIST sample:
#   1. ToTensor converts the image into a float tensor.
#   2. Normalize computes (x - mean) / std per channel; with mean = std = 0.5
#      this rescales values from [0, 1] to [-1, 1] (assuming ToTensor yields
#      [0, 1], as torchvision documents for uint8 images).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Load both MNIST splits (train first, then test) from mnist_directory.
# download=False: the files are expected to already be present on disk;
# nothing is fetched from the network. Both splits share the same transform.
train_dataset, test_dataset = (
    datasets.MNIST(
        root=mnist_directory,
        train=is_training_split,
        download=False,
        transform=transform,
    )
    for is_training_split in (True, False)
)
# Mini-batch iterators over the two splits; both use the same batch size.
batch_size = 512
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,  # reshuffle training samples each epoch
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,  # fixed order so evaluation is reproducible
)