Training a Simple Neural Network, with PyTorch Data Loading
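Copyright 2018 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://apache.ac.cn/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.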
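Let's combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST, using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs to build our model.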
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
Hyperparameters
Let's get a few bookkeeping items out of the way.
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
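As a quick sanity check on the initialization, the sketch below repeats the two helpers so that it runs on its own. It then uses `jax.tree_util.tree_map` and `jax.tree_util.tree_leaves`, which treat the list of `(w, b)` tuples as a pytree, to list the shape of every weight and bias and to count the trainable scalars. Each weight matrix is stored as `(n, m)`, that is (outputs, inputs). This is why the prediction function multiplies `jnp.dot(w, activations)`.

```python
import jax
from jax import random

def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
params = init_network_params(layer_sizes, random.key(0))

# Replace every leaf (weight matrix or bias vector) with its shape.
param_shapes = jax.tree_util.tree_map(lambda a: a.shape, params)
print(param_shapes)  # [((512, 784), (512,)), ((512, 512), (512,)), ((10, 512), (10,))]

# Total number of trainable scalars across all layers.
num_params = sum(a.size for a in jax.tree_util.tree_leaves(params))
print(num_params)  # 669706
```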
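Auto-batching predictions
Let's first define our prediction function. Note that we're defining it for a single image example. We're going to use JAX's vmap function to handle mini-batches automatically. vmap vectorizes the per-example function, so the batched version runs as array operations rather than as a Python loop over examples.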
from jax.scipy.special import logsumexp
def relu(x):
    return jnp.maximum(0, x)
def predict(params, image):
    # per-example predictions
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    # Subtracting logsumexp turns the logits into log-probabilities (log-softmax)
    return logits - logsumexp(logits)
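The last line of `predict` computes a log-softmax. The small check below shows two things. First, subtracting `logsumexp` yields log-probabilities whose exponentials sum to 1, matching `jax.nn.log_softmax`. Second, it shows why `logsumexp` is used instead of a naive `log(sum(exp(x)))`. The logit values are arbitrary examples.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([2.0, 1.0, 0.1])
# Subtracting logsumexp normalizes raw scores into log-probabilities.
log_probs = logits - logsumexp(logits)
total = jnp.exp(log_probs).sum()
print(total)  # 1.0 (up to float32 rounding)
print(jnp.allclose(log_probs, jax.nn.log_softmax(logits)))  # True

# With large logits, exp overflows float32, so the naive formula gives inf,
# while logsumexp stays finite because it factors out the maximum first.
big = jnp.array([1000.0, 1001.0])
naive = jnp.log(jnp.sum(jnp.exp(big)))
stable = logsumexp(big)
print(naive)   # inf
print(stable)  # finite, about 1001.31
```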
Let's check that our prediction function only works on single images.
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
    preds = predict(params, random_flattened_images)
except TypeError:
    print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`
# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))
# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)
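To see what `vmap(predict, in_axes=(None, 0))` does, here is a toy version. It uses a hypothetical per-example function `affine` in place of `predict`. `in_axes=(None, 0)` reuses the first argument (the parameters) unchanged for every example and maps over axis 0 of the second argument. The vectorized result matches an explicit Python loop over the batch.

```python
import jax
import jax.numpy as jnp

def affine(w, x):
    # Per-example function: w has shape (3, 4), x has shape (4,).
    return jnp.dot(w, x)

w = jnp.arange(12.0).reshape(3, 4)
xs = jnp.arange(20.0).reshape(5, 4)  # a batch of 5 examples

# Broadcast w to every call and map over the leading axis of xs.
batched = jax.vmap(affine, in_axes=(None, 0))(w, xs)

# Reference result: loop over the examples and stack the outputs.
looped = jnp.stack([affine(w, x) for x in xs])

print(batched.shape)                  # (5, 3)
print(jnp.allclose(batched, looped))  # True
```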
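At this point, we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of predict, which we should be able to use in a loss function. We should be able to use grad to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use jit to speed up everything.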
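The `update` function in the next section combines `grad` and `jit`. The toy sketch below uses the hypothetical names `toy_loss`, `toy_params`, `sgd_step`, and a made-up `step_size`. It shows that `grad` returns gradients with the same nested list-of-tuples structure as the parameters. It also shows that an SGD step over that structure can equivalently be written with `jax.tree_util.tree_map` instead of the explicit list comprehension used in `update`.

```python
import jax
import jax.numpy as jnp

step_size = 0.1  # hypothetical learning rate for this toy example

# Toy parameters with the same list-of-(w, b)-tuples layout as `params`.
toy_params = [(jnp.ones((2, 3)), jnp.zeros(2))]
x = jnp.array([1.0, 2.0, 3.0])

def toy_loss(params, x):
    w, b = params[0]
    return jnp.sum(jnp.dot(w, x) + b)

# grad differentiates with respect to the first argument: the whole pytree.
grads = jax.grad(toy_loss)(toy_params, x)
dw, db = grads[0]  # grads is a list holding one (dw, db) tuple
print(dw)  # each row equals x
print(db)  # all ones

@jax.jit
def sgd_step(params, x):
    grads = jax.grad(toy_loss)(params, x)
    # tree_map pairs each parameter leaf with its gradient leaf.
    return jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)

new_w, new_b = sgd_step(toy_params, x)[0]
```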
Utility and loss functions
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
return jnp.mean(predicted_class == target_class)
def loss(params, images, targets):
preds = batched_predict(params, images)
return -jnp.mean(preds * targets)
@jit
def update(params, x, y):
grads = grad(loss)(params, x, y)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
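`loss` averages `preds * targets` over every entry of the batch, that is, over both the examples and the `n_targets` classes. Because each row of `targets` is one-hot, only the correct-class log-probability survives in each row. The value therefore equals the usual mean cross-entropy divided by the number of classes. That constant factor only rescales the gradient, much like a smaller `step_size`. The check below uses made-up probabilities for three examples and three classes.

```python
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k (same as the helper above)."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

labels = jnp.array([0, 2, 1])
targets = one_hot(labels, 3)  # rows: [1, 0, 0], [0, 0, 1], [0, 1, 0]

# Made-up predicted probabilities, converted to log-probabilities.
log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                               [0.1, 0.1, 0.8],
                               [0.3, 0.4, 0.3]]))

# The tutorial's loss: mean over all N * K entries.
tutorial_loss = -jnp.mean(log_probs * targets)
# Standard cross-entropy: -log p(true class), averaged over the N examples.
cross_entropy = -jnp.mean(log_probs[jnp.arange(3), labels])

print(jnp.isclose(tutorial_loss, cross_entropy / 3))  # True
```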
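Data loading with PyTorch
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader and add a tiny shim to make it work with NumPy arrays.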
!pip install torch torchvision
import numpy as np
from jax.tree_util import tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision.datasets import MNIST
def numpy_collate(batch):
    """
    Collate function specifies how to combine a list of data samples into a batch.
    default_collate creates pytorch tensors, then tree_map converts them into numpy arrays.
    """
    return tree_map(np.asarray, default_collate(batch))
def flatten_and_cast(pic):
    """Convert PIL image to flat (1-dimensional) numpy array."""
    return np.ravel(np.array(pic, dtype=jnp.float32))
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=flatten_and_cast)
# Create pytorch data loader with custom collate function
training_generator = DataLoader(mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate)
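To see what the shim produces without downloading MNIST, the sketch below feeds a hypothetical ten-sample list dataset (`toy_dataset`, standing in for `mnist_dataset`) through the same `numpy_collate`. Each batch arrives as a pair of NumPy arrays rather than torch tensors. Because `DataLoader` defaults to `drop_last=False`, the final batch holds the leftover samples and is smaller than the others. Running it requires both `torch` and `jax`.

```python
import numpy as np
from jax.tree_util import tree_map
from torch.utils.data import DataLoader, default_collate

def numpy_collate(batch):
    # default_collate stacks samples into torch tensors; tree_map turns them into NumPy arrays.
    return tree_map(np.asarray, default_collate(batch))

# Hypothetical stand-in for MNIST: 10 samples of (flat float32 "image", int label).
toy_dataset = [(np.full(4, i, dtype=np.float32), i % 3) for i in range(10)]

loader = DataLoader(toy_dataset, batch_size=4, collate_fn=numpy_collate)
batches = list(loader)

for x, y in batches:
    # x: (batch, 4) float32 array, y: (batch,) int64 array; the last batch has 2 samples.
    print(type(x).__name__, x.shape, x.dtype, y.shape, y.dtype)
```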
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)
train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)
# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)
Training loop
import time
for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in training_generator:
        y = one_hot(y, n_targets)
        params = update(params, x, y)
    epoch_time = time.time() - start_time
    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
Epoch 0 in 5.53 sec
Training set accuracy 0.9156666994094849
Test set accuracy 0.9199000000953674
Epoch 1 in 1.13 sec
Training set accuracy 0.9370499849319458
Test set accuracy 0.9383999705314636
Epoch 2 in 1.12 sec
Training set accuracy 0.9490833282470703
Test set accuracy 0.9467999935150146
Epoch 3 in 1.21 sec
Training set accuracy 0.9568833708763123
Test set accuracy 0.9532999992370605
Epoch 4 in 1.17 sec
Training set accuracy 0.9631666541099548
Test set accuracy 0.9574999809265137
Epoch 5 in 1.17 sec
Training set accuracy 0.9675000309944153
Test set accuracy 0.9615999460220337
Epoch 6 in 1.11 sec
Training set accuracy 0.9709500074386597
Test set accuracy 0.9652999639511108
Epoch 7 in 1.17 sec
Training set accuracy 0.9736999869346619
Test set accuracy 0.967199981212616
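Epoch 0 takes about 5.5 s, while later epochs take a little over 1 s. A plausible contributor, not measured here, is jit compilation of `update` during the first epoch. JAX compiles a jitted function once per distinct input shape. MNIST has 60,000 training images and `batch_size = 128`, so the loader yields 468 full batches plus a final batch of 96 (`drop_last` defaults to `False`). That means `update` would be compiled twice, both times in epoch 0. The sketch below uses a hypothetical `step` function with a Python-side counter to show when retracing happens. The counter works because Python side effects run only while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def step(x):
    global trace_count
    # This line runs only during tracing, so it counts compilations.
    trace_count += 1
    return jnp.sum(x ** 2)

full = jnp.ones((128, 784))
last = jnp.ones((96, 784))  # 60000 % 128 == 96 leftover examples

for _ in range(3):
    step(full)  # traced on the first call, then served from the cache
step(last)      # new shape: traced again
step(last)      # cached

print(trace_count)  # 2
```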
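We've now used the three core transformations of the JAX API: grad for derivatives, jit for speedups, and vmap for auto-vectorization. We expressed all of our computation in NumPy-style code, borrowed PyTorch's excellent data loader, and ran the whole thing on whatever accelerator JAX found (for example, a GPU), with JAX falling back to the CPU when none is available.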