PyTorch实战-利用卷积神经网络完成手写数字识别
发表于:2024-09-01 | 分类: AI
字数统计: 14.5k | 阅读时长: 81分钟 | 阅读量:

前言

卷积神经网络(Convolutional Neural Networks, CNNs)是一种特殊类型的神经网络,在图像和视频识别、推荐系统、图像分类、医学图像分析、自然语言处理等领域有着广泛的应用。它们能够自动从原始图像中提取特征,并通过多层网络结构学习这些特征的高级表示。本文通过手写数字识别项目带大家学习卷积神经网络。

卷积神经网络基本概念

CNN结构

输入层(Input Layer)–> {卷积层(Convolutional Layer) –> 池化层(Pooling Layer)–> 卷积层(Convolutional Layer) –> 池化层(Pooling Layer)}(重复) –> 全连接层(Fully Connected Layer)

卷积层

在深度学习和计算机视觉中,尤其是在处理卷积神经网络(CNN)时,计算输出尺寸(Output Size)的公式非常重要。您给出的公式是卷积层输出尺寸计算的一个基本公式,但通常我们会用更直观的符号来表示它。下面是该公式转换为常用表示方法的形式:

$$ O = \left\lfloor \frac{I + 2P - K}{S} \right\rfloor + 1 $$

其中:

  • $O$ 代表输出尺寸(Output Size),通常是输出特征图的高度或宽度(假设它们是相等的,即正方形特征图)。
  • $I$ 代表输入尺寸(Input Size),即输入特征图的高度或宽度。
  • $P$ 代表填充(Padding)的大小,即在输入特征图的边界上添加的零的层数。注意,这里的 $2P$ 表示在输入特征图的两侧(或上下两侧,取决于维度)都添加了 $P$ 层的零。
  • $K$ 代表卷积核(Kernel Size)的大小,即卷积核的高度或宽度(在正方形卷积核的情况下)。
  • $S$ 代表步长(Stride),即卷积核在输入特征图上移动的步数。
  • $\left\lfloor \cdot \right\rfloor$ 表示向下取整操作,因为像素数必须是整数。

这个公式适用于计算卷积层(包括标准卷积层和转置卷积层,但转置卷积层有额外的参数和复杂性)后的输出特征图尺寸。在实际应用中,了解如何根据这些参数调整网络结构以得到期望的输出尺寸是非常重要的。

激活函数(Activation Function)

激活函数用于在卷积层(以及其他类型的神经网络层)之后引入非线性。常见的激活函数包括ReLU(Rectified Linear Unit,修正线性单元)、sigmoid和tanh等。ReLU因其简单性和减少梯度消失问题的能力而在CNN中广泛使用。

池化层

池化层(Pooling Layer)在卷积神经网络(CNN)中扮演着重要的角色,主要用于特征融合和降维,以减少计算量和控制过拟合。池化层的输入尺寸和输出尺寸计算公式可以根据不同的参数设置而有所不同,但基本思路是相似的。

池化层的输出尺寸计算公式

池化层的输出尺寸计算公式可以表示为:

$$ O = \left\lfloor \frac{I + 2P - F}{S} + 1 \right\rfloor $$

其中:

  • $O$ 代表输出尺寸(Output Size),即池化层输出的特征图的高度和宽度。
  • $I$ 代表输入尺寸(Input Size),即输入特征图的高度和宽度。
  • $F$ 是池化窗口(Pooling Window)的大小,即池化操作覆盖的输入特征图的区域大小。
  • $S$ 是步长(Stride),即池化窗口在输入特征图上移动的步数。
  • $P$ 是填充(Padding),即在输入特征图的边界上添加的零的层数,用于控制输出尺寸。
  • $\left\lfloor \cdot \right\rfloor$ 表示向下取整操作,因为像素数必须是整数。

注意事项

  • 池化层通常不涉及权重和偏置参数,因此它们不会影响模型的学习能力,但对于减少计算量和控制过拟合非常有帮助。
  • 在实际应用中,池化窗口的大小$F$和步长$S$通常设置为相同的值,如2或3,这样可以更有效地降低特征图的维度。
  • 填充$P$的值可以是0(无填充),也可以是其他正整数(有填充),具体取决于需要保持输出特征图尺寸与输入特征图尺寸的比例关系。

全连接层(Fully Connected Layer, FC Layer)

在CNN的末端,通常会有几个全连接层。这些层中的每个神经元都与前一层的所有神经元相连接。全连接层的作用是将前面层学到的“分布式特征表示”映射到样本标记空间。在分类任务中,全连接层的输出可以传递给softmax函数来生成最终的类别概率分布。

参数共享(Parameter Sharing)

在CNN中,卷积核的参数是在整个输入数据上共享的。这意味着无论数据中的哪个位置,卷积核都使用相同的权重和偏置参数进行卷积操作。这种参数共享机制减少了模型的参数量,并有助于模型学习到输入数据的空间层次结构。

局部连接(Local Connectivity)

在卷积层中,每个神经元仅与输入数据的一个局部区域(即感受野)相连接,而不是与整个输入数据相连接。这种局部连接机制使得CNN能够学习到数据的局部特征,这与人类视觉系统的处理机制相似。

反向传播(Backpropagation)

反向传播算法是训练CNN(以及其他类型的神经网络)的关键算法。在训练过程中,通过计算损失函数关于网络参数的梯度,并利用梯度下降(或其变体)来更新网络参数,以最小化损失函数。反向传播算法通过链式法则在网络的每一层中传播梯度信息。

MNIST数据集介绍

在探索机器学习领域的广阔天地时,手写数字识别作为一个经典且基础的任务,始终占据着重要的地位。而MNIST(Modified National Institute of Standards and Technology)数据集,正是这一任务中最常用、最经典的数据集之一。本文将首先介绍MNIST数据集,为后续的手写数字识别模型训练与测试打下坚实的基础。

MNIST数据集概述

MNIST数据集由Yann LeCun等人搜集整理,是一个大型的手写体数字数据库。该数据集最初来源于NIST(National Institute of Standards and Technology)的两个数据库:专用数据库1(Special Database 1)和特殊数据库3(Special Database 3)。通过精心筛选与预处理,MNIST最终成为了一个包含大量手写数字图像的标准数据集,广泛应用于各种图像处理系统和机器学习算法的训练与测试中。

数据集的构成

MNIST数据集由60,000个训练样本和10,000个测试样本组成,每个样本都是一张28x28像素的灰度图像,表示一个手写数字(0-9)。这些图像均已被归一化,像素值范围在0到255之间,其中0代表黑色,255代表白色。数据集的图像由来自不同人群的手写体构成,包括高中生和美国人口普查局的工作人员,确保了数据的多样性和代表性。

数据集的特点

  1. 简单性:虽然MNIST数据集包含的手写数字种类繁多,但由于其图像尺寸小(28x28像素)、像素深度低(灰度图像),使得处理起来相对简单。这使其成为机器学习初学者练习图像识别、模式识别等任务的理想选择。

  2. 代表性:MNIST数据集中的手写数字覆盖了各种书写风格和变体,使得训练出的模型能够较好地泛化到未知的手写数字上。因此,该数据集在评估机器学习算法性能时具有很高的参考价值。

  3. 广泛应用:由于其简单性和代表性,MNIST数据集在机器学习领域得到了广泛应用。从简单的神经网络到复杂的深度学习模型,几乎所有的图像识别算法都会使用MNIST数据集进行训练和测试。

数据集的下载与使用

MNIST数据集可以通过多种途径下载,其中最常用的方式是通过互联网直接下载。用户可以从Yann LeCun的官方网站(http://yann.lecun.com/exdb/mnist/)或其他数据共享平台获取该数据集。下载后的数据集通常包含四个文件:训练集图像、训练集标签、测试集图像和测试集标签。这些文件均为压缩格式,用户需要解压后才能使用。

在使用MNIST数据集时,用户需要根据自己的需求进行预处理和加载操作。例如,可以使用Python的NumPy库或Pandas库来读取和处理数据集中的图像和标签信息;也可以使用深度学习框架(如TensorFlow、PyTorch等)中提供的数据加载工具来简化这一过程。

下载并导入数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

# 数据预处理
# 使用Compose组合多个变换,这里将数据转换为张量并进行标准化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# 下载MNIST数据集并划分为训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

查看训练集属性:

1
2
3
4
5
6
7
8
9
10
11
# 查看整个训练集的样本数量及单个样本的形状
print(f"训练集大小: {len(train_dataset)}")
# 查看第一个样本的数据和标签
first_sample, first_label = train_dataset[0]
print(f"首个样本数据形状: {first_sample.shape}")
print(f"首个样本标签: {first_label}")

# 如果想查看前几个样本的具体数据内容,可以通过循环实现
for i in range(5):
data, label = train_dataset[i]
print(f"样本 {i+1} 的数据:\n{data}\n标签: {label}\n")
训练集大小: 60000
首个样本数据形状: torch.Size([1, 28, 28])
首个样本标签: 5
样本 1 的数据:
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.9765, -0.8588,
          -0.8588, -0.8588, -0.0118,  0.0667,  0.3725, -0.7961,  0.3020,
           1.0000,  0.9373, -0.0039, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.7647, -0.7176, -0.2627,  0.2078,  0.3333,  0.9843,
           0.9843,  0.9843,  0.9843,  0.9843,  0.7647,  0.3490,  0.9843,
           0.8980,  0.5294, -0.4980, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.6157,  0.8667,  0.9843,  0.9843,  0.9843,  0.9843,  0.9843,
           0.9843,  0.9843,  0.9843,  0.9686, -0.2706, -0.3569, -0.3569,
          -0.5608, -0.6941, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.8588,  0.7176,  0.9843,  0.9843,  0.9843,  0.9843,  0.9843,
           0.5529,  0.4275,  0.9373,  0.8902, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.3725,  0.2235, -0.1608,  0.9843,  0.9843,  0.6078,
          -0.9137, -1.0000, -0.6627,  0.2078, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -0.8902, -0.9922,  0.2078,  0.9843, -0.2941,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000,  0.0902,  0.9843,  0.4902,
          -0.9843, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.9137,  0.4902,  0.9843,
          -0.4510, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.7255,  0.8902,
           0.7647,  0.2549, -0.1529, -0.9922, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3647,
           0.8824,  0.9843,  0.9843, -0.0667, -0.8039, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.6471,  0.4588,  0.9843,  0.9843,  0.1765, -0.7882, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.8745, -0.2706,  0.9765,  0.9843,  0.4667, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.9529,  0.9843,  0.9529, -0.4980,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.6392,  0.0196,  0.4353,  0.9843,  0.9843,  0.6235, -0.9843,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.6941,  0.1608,
           0.7961,  0.9843,  0.9843,  0.9843,  0.9608,  0.4275, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.8118, -0.1059,  0.7333,  0.9843,
           0.9843,  0.9843,  0.9843,  0.5765, -0.3882, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.8196, -0.4824,  0.6706,  0.9843,  0.9843,  0.9843,
           0.9843,  0.5529, -0.3647, -0.9843, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.8588,
           0.3412,  0.7176,  0.9843,  0.9843,  0.9843,  0.9843,  0.5294,
          -0.3725, -0.9294, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -0.5686,  0.3490,  0.7725,
           0.9843,  0.9843,  0.9843,  0.9843,  0.9137,  0.0431, -0.9137,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000,  0.0667,  0.9843,  0.9843,
           0.9843,  0.6627,  0.0588,  0.0353, -0.8745, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])
标签: 5

样本 2 的数据:
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.6000,  0.2471,  0.9843,  0.2471, -0.6078, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.6235,  0.8667,  0.9765,  0.9765,  0.9765,  0.8588, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.5765,
           0.7804,  0.9843,  0.9765,  0.8745,  0.8275,  0.9765, -0.5529,
          -0.9529, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.9216, -0.5294,  0.7569,
           0.9765,  0.9843,  0.9765,  0.5843, -0.3412,  0.9765,  0.9843,
          -0.0431, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000,  0.2784,  0.9765,  0.9765,
           0.9765,  0.9843,  0.9765,  0.9765, -0.2471,  0.4824,  0.9843,
           0.3098, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.6000,  0.8667,  0.9843,  0.9843,
           0.4902, -0.1059,  0.9843,  0.7882, -0.6314, -0.3804,  1.0000,
           0.3176, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -0.6235,  0.8667,  0.9765,  0.9765,  0.4039,
          -0.9059, -0.4118, -0.0510, -0.8353, -1.0000, -1.0000,  0.9843,
           0.9059, -0.6078, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.7020,  0.2941,  0.9843,  0.8275,  0.6314, -0.3412,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.9843,
           0.9765,  0.2941, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.9451,  0.3961,  0.9765,  0.8824, -0.4431, -0.8510, -0.7804,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.9843,
           0.9765,  0.5294, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.5529,  0.9765,  0.9765, -0.5059, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.9843,
           0.9765,  0.5294, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           0.5529,  0.9843,  0.4902, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  1.0000,
           0.9843,  0.5373, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.4039,
           0.9294,  0.9765, -0.1216, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.9843,
           0.9765,  0.1608, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3333,
           0.9765,  0.8039, -0.8039, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.9451,  0.0588,  0.9843,
           0.4588, -0.9059, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3333,
           0.9765,  0.7490, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.9451,  0.0275,  0.9765,  0.7647,
          -0.4431, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3333,
           0.9765,  0.1373, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -0.6235,  0.2941,  0.9765,  0.3569, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3255,
           0.9843,  0.7647, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.1059,  0.8667,  0.9843,  0.2706, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3333,
           0.9765,  0.9529,  0.1451, -0.6235, -0.7725, -0.3333,  0.3961,
           0.7647,  0.9843,  0.7490,  0.3098, -0.5608, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3333,
           0.9765,  0.9765,  0.9765,  0.7961,  0.6863,  0.9765,  0.9765,
           0.9765,  0.5373,  0.0196, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.7804,
           0.5608,  0.9765,  0.9765,  0.9843,  0.9765,  0.9765,  0.8275,
           0.1373, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.8039,  0.0039,  0.9765,  0.9843,  0.9765,  0.1059, -0.7098,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])
标签: 0

样本 3 的数据:
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.4745,
           0.8196, -0.6941, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -0.5137, -0.3647, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.0588,
           0.4118, -0.6941, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -0.0118,  0.2784, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.9843,  0.2000,
           0.6471, -0.6863, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000,  0.7255,  0.2784, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.7882,  0.9922,
           0.2706, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000,  0.7412,  0.2784, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.4353,  0.9922,
          -0.0196, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -0.6392,  0.9216,  0.2784, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.5529,  0.9922,
          -0.5608, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -0.0588,  0.9922,  0.2784, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.8196,  0.8118,  0.9922,
          -0.7725, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  0.2471,  0.9922, -0.0588, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000,  0.2784,  0.9922,  0.6941,
          -0.8745, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  0.2471,  0.9922, -0.4745, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.8902, -0.3255,  0.3961,  0.9451,  0.9922, -0.2863,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  0.2471,  0.9922, -0.3333, -1.0000,
          -1.0000, -1.0000, -0.6314, -0.6157, -0.0902,  0.1294,  0.1765,
           0.8902,  0.9059,  0.8353,  0.4039,  0.8902,  0.9765, -0.6863,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  0.1765,  0.9843,  0.8588,  0.6235,
           0.6235,  0.6235,  0.9843,  0.9922,  0.9608,  0.8824,  0.5529,
           0.1216, -0.2863, -0.7804, -0.9608,  0.8275,  0.9608, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -0.0667,  0.3882,  0.3882,
           0.3882,  0.3882,  0.3882, -0.2314, -0.5608, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.2000,  0.9922,  0.7255, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.3255,  0.9922,  0.0745, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.3255,  0.9922, -0.5529, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.3255,  0.9922, -0.5529, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.3255,  1.0000, -0.2627, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.3255,  0.9922, -0.2471, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.3255,  0.9922,  0.2000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.3255,  1.0000,  0.2000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.2471,  0.9922,  0.2000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])
标签: 4

样本 4 的数据:
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.0275,  0.9843,  1.0000,
          -0.5059, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.2471,  0.9137,  0.9686,  0.9843,
          -0.5137, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.0039,  0.9686,  0.9686,  0.9843,
          -0.5137, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -0.4667,  0.8510,  0.9686,  0.6549, -0.7569,
          -0.9373, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.5294,  0.7882,  0.9686,  0.9686, -0.2627, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000,  0.2157,  0.9843,  0.9843,  0.4824, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.8431,  0.9843,  0.9686,  0.8431, -0.4824, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.7490,
           0.6078,  0.9843,  0.9686, -0.0118, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.1843,
           0.9686,  0.9843,  0.4431, -0.8824, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3725,  0.8824,
           0.9686,  0.5137, -0.8196, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.7490,  0.9843,  0.9843,
           0.9843,  0.2471, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000,  0.1843,  0.9686,  0.9686,
           0.9686, -0.6941, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.6235,  0.7333,  0.9686,  0.9686,
           0.3490, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.8353,  0.9686,  0.9686,  0.5373,
          -0.9059, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.9843,  0.9686,  0.9686, -0.3020,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000,  0.2471,  1.0000,  0.9843,  0.9843, -0.7569,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.6235,  0.7882,  0.9843,  0.9373,  0.0980, -0.9373,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.4980,  0.9686,  0.9843,  0.7255, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.4980,  0.9686,  0.9843,  0.7255, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.8118,  0.5137,  0.9843,  0.7255, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])
标签: 1

样本 5 的数据:
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.5686,  0.1608,
           0.6471,  0.9843,  0.9843, -0.1137, -0.3176,  0.1608, -0.5686,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.3176,  0.8196,  0.9765,
           0.9843,  0.4824,  0.6471,  0.9765,  0.9765,  0.9843,  0.3176,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -0.9686, -0.5529,  0.8980,  0.9765,  0.4902,
          -0.4902, -0.9608, -0.9059,  0.4275,  0.9765,  0.9843, -0.0902,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -0.2471,  0.9765,  0.9765,  0.4353, -0.8902,
          -1.0000, -1.0000, -0.2784,  0.9765,  0.9765,  0.7647, -0.8353,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000,  0.0353,  0.9843,  0.9765,  0.1451, -0.8902, -1.0000,
          -1.0000, -1.0000,  0.6863,  0.9765,  0.9765, -0.3804, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.0118,  0.9843,  0.9373,  0.3804, -0.9294, -1.0000, -1.0000,
          -0.9373, -0.3882,  0.9216,  0.9843,  0.0118, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.8745,
           0.8196,  0.9765,  0.3804, -1.0000, -1.0000, -1.0000, -0.7176,
           0.5765,  0.9765,  0.9765,  0.3255, -0.9137, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.8275,
           0.9765,  0.9765, -0.7647, -0.8275, -0.0667,  0.5451,  0.8902,
           0.9843,  0.9765,  0.9686, -0.3961, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.8745,
           0.8118,  0.9765,  0.9843,  0.9765,  0.9765,  0.9765,  0.7725,
           0.7804,  0.9765,  0.8118, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.5686,  0.8431,  0.9843,  0.7020,  0.0824, -0.6706, -0.8118,
           0.5059,  0.9765,  0.1216, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.5137,
           1.0000,  0.9843, -0.1451, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.4431,
           0.9843,  0.9765, -0.8353, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           0.9843,  0.9765, -0.8353, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.4431,
           0.9843,  0.9765, -0.8353, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.1686,
           0.9843,  0.9765, -0.8353, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.6471,
           1.0000,  0.9843, -0.8353, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           0.7098,  0.9765, -0.5608, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.2471,  0.9765,  0.4824, -0.6706, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.8902,  0.4431,  0.9765,  0.3333, -0.9137, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.8902,  0.1529,  0.9765, -0.6706, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])
标签: 9

展示训练集中的第一幅图片:

1
2
3
4
5
6
7
8
9
10
11
12
13
import matplotlib.pyplot as plt

# 获取第一个样本的数据和标签
first_image, first_label = train_dataset[0]

# 将图像数据从形状 (1, 28, 28) 转换为 (28, 28),以便显示
first_image = first_image.squeeze().numpy() # 去除通道维度,并转换为numpy数组

# 显示图像
plt.imshow(first_image, cmap='gray') # 使用灰度色阶显示图像
plt.title(f'Label: {first_label}') # 显示图像标签作为标题
plt.axis('off') # 不显示坐标轴
plt.show()

png

1
2
3
4
5
6
7
8
9
10
# 获取第一个样本的数据
first_image, _ = train_dataset[0]

# 将张量转换为numpy数组以便打印
first_image_np = first_image.numpy()

# 打印原始的三维数组(包含通道维度)
print("原始图像数据(包含通道维度):")
print(first_image_np)

原始图像数据(包含通道维度):
[[[-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -0.9764706  -0.85882354 -0.85882354
   -0.85882354 -0.01176471  0.06666672  0.37254906 -0.79607844
    0.30196083  1.          0.9372549  -0.00392157 -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -0.7647059  -0.7176471
   -0.26274508  0.20784318  0.33333337  0.9843137   0.9843137
    0.9843137   0.9843137   0.9843137   0.7647059   0.34901965
    0.9843137   0.8980392   0.5294118  -0.4980392  -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -0.6156863   0.8666667   0.9843137
    0.9843137   0.9843137   0.9843137   0.9843137   0.9843137
    0.9843137   0.9843137   0.96862745 -0.27058822 -0.35686272
   -0.35686272 -0.56078434 -0.69411767 -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -0.85882354  0.7176471   0.9843137
    0.9843137   0.9843137   0.9843137   0.9843137   0.5529412
    0.427451    0.9372549   0.8901961  -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -0.372549    0.22352946
   -0.1607843   0.9843137   0.9843137   0.60784316 -0.9137255
   -1.         -0.6627451   0.20784318 -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -0.8901961
   -0.99215686  0.20784318  0.9843137  -0.29411763 -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.          0.09019613  0.9843137   0.4901961  -0.9843137
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -0.9137255   0.4901961   0.9843137  -0.45098037
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -0.7254902   0.8901961   0.7647059
    0.254902   -0.15294117 -0.99215686 -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -0.36470586  0.88235295
    0.9843137   0.9843137  -0.06666666 -0.8039216  -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -0.64705884
    0.45882356  0.9843137   0.9843137   0.17647064 -0.7882353
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -0.8745098  -0.27058822  0.9764706   0.9843137   0.4666667
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.          0.9529412   0.9843137   0.9529412
   -0.4980392  -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -0.6392157
    0.0196079   0.43529415  0.9843137   0.9843137   0.62352943
   -0.9843137  -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -0.69411767  0.16078436  0.79607844
    0.9843137   0.9843137   0.9843137   0.9607843   0.427451
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -0.8117647  -0.10588235  0.73333335  0.9843137   0.9843137
    0.9843137   0.9843137   0.5764706  -0.38823527 -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -0.81960785 -0.4823529
    0.67058825  0.9843137   0.9843137   0.9843137   0.9843137
    0.5529412  -0.36470586 -0.9843137  -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -0.85882354  0.3411765   0.7176471   0.9843137
    0.9843137   0.9843137   0.9843137   0.5294118  -0.372549
   -0.92941177 -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -0.5686275
    0.34901965  0.77254903  0.9843137   0.9843137   0.9843137
    0.9843137   0.9137255   0.04313731 -0.9137255  -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.          0.06666672
    0.9843137   0.9843137   0.9843137   0.6627451   0.05882359
    0.03529418 -0.8745098  -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.         -1.         -1.
   -1.         -1.         -1.        ]]]
1
2
3
4
5

# 如果你想去掉通道维度,打印二维矩阵
first_image_2d = first_image_np.squeeze() # 去掉通道维度
print("\n去掉通道后的二维矩阵:")
print(first_image_2d)
去掉通道后的二维矩阵:
[[-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -0.9764706  -0.85882354 -0.85882354 -0.85882354 -0.01176471  0.06666672
   0.37254906 -0.79607844  0.30196083  1.          0.9372549  -0.00392157
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.7647059  -0.7176471  -0.26274508  0.20784318
   0.33333337  0.9843137   0.9843137   0.9843137   0.9843137   0.9843137
   0.7647059   0.34901965  0.9843137   0.8980392   0.5294118  -0.4980392
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -0.6156863   0.8666667   0.9843137   0.9843137   0.9843137
   0.9843137   0.9843137   0.9843137   0.9843137   0.9843137   0.96862745
  -0.27058822 -0.35686272 -0.35686272 -0.56078434 -0.69411767 -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -0.85882354  0.7176471   0.9843137   0.9843137   0.9843137
   0.9843137   0.9843137   0.5529412   0.427451    0.9372549   0.8901961
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.372549    0.22352946 -0.1607843   0.9843137
   0.9843137   0.60784316 -0.9137255  -1.         -0.6627451   0.20784318
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -0.8901961  -0.99215686  0.20784318
   0.9843137  -0.29411763 -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.          0.09019613
   0.9843137   0.4901961  -0.9843137  -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -0.9137255
   0.4901961   0.9843137  -0.45098037 -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -0.7254902   0.8901961   0.7647059   0.254902   -0.15294117 -0.99215686
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -0.36470586  0.88235295  0.9843137   0.9843137  -0.06666666
  -0.8039216  -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.64705884  0.45882356  0.9843137   0.9843137
   0.17647064 -0.7882353  -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -0.8745098  -0.27058822  0.9764706
   0.9843137   0.4666667  -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.          0.9529412
   0.9843137   0.9529412  -0.4980392  -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.6392157   0.0196079   0.43529415  0.9843137
   0.9843137   0.62352943 -0.9843137  -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -0.69411767  0.16078436  0.79607844  0.9843137   0.9843137   0.9843137
   0.9607843   0.427451   -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -0.8117647  -0.10588235
   0.73333335  0.9843137   0.9843137   0.9843137   0.9843137   0.5764706
  -0.38823527 -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.81960785 -0.4823529   0.67058825  0.9843137
   0.9843137   0.9843137   0.9843137   0.5529412  -0.36470586 -0.9843137
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -0.85882354  0.3411765   0.7176471   0.9843137   0.9843137   0.9843137
   0.9843137   0.5294118  -0.372549   -0.92941177 -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -0.5686275   0.34901965
   0.77254903  0.9843137   0.9843137   0.9843137   0.9843137   0.9137255
   0.04313731 -0.9137255  -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.          0.06666672  0.9843137
   0.9843137   0.9843137   0.6627451   0.05882359  0.03529418 -0.8745098
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]]

定义模型、优化器、损失函数

进行2次卷积和2次池化,得到64717,再进行2次全连接,得到10个输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# 定义批次大小
batch_size = 64
# 使用DataLoader加载数据,以便在训练过程中更方便地迭代数据
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 初始化卷积层、池化层、全连接层和dropout层
# 定义第一个卷积层,用于提取特征
# 输入通道数为1(适用于灰度图像),输出通道数为32,卷积核大小为5x5,步长为1,padding为2
# 第一次卷积后生成的特征图大小为32*28*28
self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2)

# 定义最大池化层,用于降低特征维度,减少计算量
# 池化窗口大小为2x2,步长为2,无padding
# 第一次池化后生成的特征图大小为32*14*14
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

# 定义第二个卷积层,进一步提取和整合特征
# 输入通道数为32,输出通道数为64,卷积核大小为5x5,步长为1,padding为2
# 第二次卷积后生成的特征图大小为64*14*14
self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)

# 定义第一个全连接层,用于分类前的特征转换
# 输入大小为64*7*7(这里的尺寸为64*7*7是因为在前向传播时对第二次卷积进行了池化操作),输出大小为1024
self.fc1 = nn.Linear(64 * 7 * 7, 1024)

# 定义第二个全连接层,用于最终的分类输出
# 输入大小为1024,输出大小为10(假设分类任务有10个类别)
self.fc2 = nn.Linear(1024, 10)

# 定义Dropout层,用于训练过程中的正则化,防止过拟合
# Dropout比例为0.5,即在训练过程中随机将50%的元素置为0
self.dropout = nn.Dropout(p=0.5)

def forward(self, x):
# 定义前向传播过程
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7)
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x

# 实例化模型
model = Net()

# 定义优化器
# 使用Adam优化器更新模型参数
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# 定义损失函数
# 使用交叉熵损失函数进行分类任务
criterion = nn.CrossEntropyLoss()

训练模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 将模型转移到GPU设备上(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 定义训练轮数
num_epochs = 10
# 开始训练过程
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# 将数据转移到GPU设备上(如果可用)
images, labels = images.to(device), labels.to(device)

# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)

# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 打印损失信息
if (i + 1) % 100 == 0:
print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# 保存模型参数到文件
torch.save(model.state_dict(), 'model.pth')
Epoch [1/10], Step [100/938], Loss: 0.0037
Epoch [1/10], Step [200/938], Loss: 0.0029
Epoch [1/10], Step [300/938], Loss: 0.0022
Epoch [1/10], Step [400/938], Loss: 0.0651
Epoch [1/10], Step [500/938], Loss: 0.0038
Epoch [1/10], Step [600/938], Loss: 0.0022
Epoch [1/10], Step [700/938], Loss: 0.0079
Epoch [1/10], Step [800/938], Loss: 0.0019
Epoch [1/10], Step [900/938], Loss: 0.0016
Epoch [2/10], Step [100/938], Loss: 0.0016
Epoch [2/10], Step [200/938], Loss: 0.0102
Epoch [2/10], Step [300/938], Loss: 0.0341
Epoch [2/10], Step [400/938], Loss: 0.0060
Epoch [2/10], Step [500/938], Loss: 0.0257
Epoch [2/10], Step [600/938], Loss: 0.0013
Epoch [2/10], Step [700/938], Loss: 0.0767
Epoch [2/10], Step [800/938], Loss: 0.0018
Epoch [2/10], Step [900/938], Loss: 0.0343
Epoch [3/10], Step [100/938], Loss: 0.0063
Epoch [3/10], Step [200/938], Loss: 0.0096
Epoch [3/10], Step [300/938], Loss: 0.0007
Epoch [3/10], Step [400/938], Loss: 0.0002
Epoch [3/10], Step [500/938], Loss: 0.0124
Epoch [3/10], Step [600/938], Loss: 0.0109
Epoch [3/10], Step [700/938], Loss: 0.0340
Epoch [3/10], Step [800/938], Loss: 0.0004
Epoch [3/10], Step [900/938], Loss: 0.0586
Epoch [4/10], Step [100/938], Loss: 0.0002
Epoch [4/10], Step [200/938], Loss: 0.0554
Epoch [4/10], Step [300/938], Loss: 0.0008
Epoch [4/10], Step [400/938], Loss: 0.0029
Epoch [4/10], Step [500/938], Loss: 0.0036
Epoch [4/10], Step [600/938], Loss: 0.0009
Epoch [4/10], Step [700/938], Loss: 0.0281
Epoch [4/10], Step [800/938], Loss: 0.0826
Epoch [4/10], Step [900/938], Loss: 0.0003
Epoch [5/10], Step [100/938], Loss: 0.0001
Epoch [5/10], Step [200/938], Loss: 0.0240
Epoch [5/10], Step [300/938], Loss: 0.0040
Epoch [5/10], Step [400/938], Loss: 0.0003
Epoch [5/10], Step [500/938], Loss: 0.0107
Epoch [5/10], Step [600/938], Loss: 0.0019
Epoch [5/10], Step [700/938], Loss: 0.0002
Epoch [5/10], Step [800/938], Loss: 0.0006
Epoch [5/10], Step [900/938], Loss: 0.0008
Epoch [6/10], Step [100/938], Loss: 0.0003
Epoch [6/10], Step [200/938], Loss: 0.0001
Epoch [6/10], Step [300/938], Loss: 0.0003
Epoch [6/10], Step [400/938], Loss: 0.0226
Epoch [6/10], Step [500/938], Loss: 0.0024
Epoch [6/10], Step [600/938], Loss: 0.0020
Epoch [6/10], Step [700/938], Loss: 0.0005
Epoch [6/10], Step [800/938], Loss: 0.0007
Epoch [6/10], Step [900/938], Loss: 0.0188
Epoch [7/10], Step [100/938], Loss: 0.0286
Epoch [7/10], Step [200/938], Loss: 0.0007
Epoch [7/10], Step [300/938], Loss: 0.0004
Epoch [7/10], Step [400/938], Loss: 0.0008
Epoch [7/10], Step [500/938], Loss: 0.0001
Epoch [7/10], Step [600/938], Loss: 0.0006
Epoch [7/10], Step [700/938], Loss: 0.0005
Epoch [7/10], Step [800/938], Loss: 0.0007
Epoch [7/10], Step [900/938], Loss: 0.0432
Epoch [8/10], Step [100/938], Loss: 0.0005
Epoch [8/10], Step [200/938], Loss: 0.0005
Epoch [8/10], Step [300/938], Loss: 0.0000
Epoch [8/10], Step [400/938], Loss: 0.0013
Epoch [8/10], Step [500/938], Loss: 0.0005
Epoch [8/10], Step [600/938], Loss: 0.0002
Epoch [8/10], Step [700/938], Loss: 0.0004
Epoch [8/10], Step [800/938], Loss: 0.0111
Epoch [8/10], Step [900/938], Loss: 0.0001
Epoch [9/10], Step [100/938], Loss: 0.0004
Epoch [9/10], Step [200/938], Loss: 0.0693
Epoch [9/10], Step [300/938], Loss: 0.0071
Epoch [9/10], Step [400/938], Loss: 0.0000
Epoch [9/10], Step [500/938], Loss: 0.0003
Epoch [9/10], Step [600/938], Loss: 0.0003
Epoch [9/10], Step [700/938], Loss: 0.0001
Epoch [9/10], Step [800/938], Loss: 0.0000
Epoch [9/10], Step [900/938], Loss: 0.0029
Epoch [10/10], Step [100/938], Loss: 0.0008
Epoch [10/10], Step [200/938], Loss: 0.0001
Epoch [10/10], Step [300/938], Loss: 0.0000
Epoch [10/10], Step [400/938], Loss: 0.0273
Epoch [10/10], Step [500/938], Loss: 0.0001
Epoch [10/10], Step [600/938], Loss: 0.0002
Epoch [10/10], Step [700/938], Loss: 0.0010
Epoch [10/10], Step [800/938], Loss: 0.0019
Epoch [10/10], Step [900/938], Loss: 0.0000

评估模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 实例化一个与原模型结构相同的模型
model = Net().to(device) # 确保模型被放置在正确的设备上

# 加载模型参数
model.load_state_dict(torch.load('model.pth', map_location=device, weights_only=True))

# 将模型设置为评估模式
model.eval()
# 禁用梯度计算以减少内存消耗
with torch.no_grad():
correct = 0
total = 0
# 在测试集上进行预测
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

# 打印测试集上的准确率
print(f'Test Accuracy of the model on the {total} test images: {100 * correct / total}%')
Test Accuracy of the model on the 10000 test images: 99.29%

代码获取

关注公众号“生信之巅”,聊天窗口回复“a7fe”获取完整版代码下载链接。

生信之巅微信公众号 生信之巅小程序码

敬告:使用文中脚本请引用本文网址,请尊重本人的劳动成果,谢谢!Notice: When you use the scripts in this article, please cite the link of this webpage. Thank you!

上一篇:
Scikit-learn机器学习实战-PCA
下一篇:
Scikit-learn机器学习实战-HumanResourcesAnalytics