CNN卷积网络实现MNIST数据集手写数字识别
创始人
2024-11-14 08:43:39
0

步骤一:加载MNIST数据集

train_data = MNIST(root='./data',train=True,download=False,transform=transforms.ToTensor()) train_loader = DataLoader(train_data,shuffle=True,batch_size=64) # 测试数据集 test_data = MNIST(root='./data',train=False,download=False,transform=transforms.ToTensor()) test_loader = DataLoader(test_data,shuffle=False,batch_size=64)

首先,通过MNIST类创建了train_data对象,指定了数据集的路径root='./data',并且将数据集标记为训练集train=Truedownload=False表示不自动从网络上下载数据集,而是使用已经下载好的数据集。我是之前自己已经下载过该数据集所以这里填的是False,如果之前没有下载的话就要填True。下面测试集也是一样。transforms.ToTensor()将数据转换为张量形式。

然后,通过DataLoader类创建了train_loader对象,指定了使用train_data作为数据源。shuffle=True表示在每个epoch开始时,将数据打乱顺序。batch_size=64表示每次抓取64个样本。

接下来,同样的步骤也被用来创建了测试集的数据加载器test_loader。不同的是,这里将数据集标记为测试集train=False,并且shuffle=False表示不需要打乱顺序。

加载完的数据集存在MNIST文件夹的raw文件夹下内容如下:

其中t10k-images-idx3-ubyte是测试集的图像,t10k-labels-idx3-ubyte是测试集的标签。train-images-idx3-ubyte是训练集的图像,train-labels-idx1-ubyte是训练集的标签。

存下来的这些数据集是二进制的形式,可以通过下面的代码(1.py)读取:

""" Created on Sat Jul 27 15:26:39 2024  @author: wangyiyuan """ # 导入包 import struct import numpy as np from PIL import Image  class MnistParser:    # 加载图像    def load_image(self, file_path):         # 读取二进制数据        binary = open(file_path,'rb').read()         # 读取头文件        fmt_head = '>iiii'        offset = 0         # 读取头文件        magic_number,images_number,rows_number,columns_number = struct.unpack_from(fmt_head,binary,offset)         # 打印头文件信息        print('图片数量:%d,图片行数:%d,图片列数:%d'%(images_number,rows_number,columns_number))         # 处理数据        image_size = rows_number * columns_number        fmt_data = '>'+str(image_size)+'B'        offset = offset + struct.calcsize(fmt_head)         # 读取数据        images = np.empty((images_number,rows_number,columns_number))        for i in range(images_number):            images[i] = np.array(struct.unpack_from(fmt_data, binary, offset)).reshape((rows_number, columns_number))            offset = offset + struct.calcsize(fmt_data)            # 每1万张打印一次信息            if (i+1) % 10000 == 0:                print('> 已读取:%d张图片'%(i+1))         # 返回数据        return images_number,rows_number,columns_number,images      # 加载标签    def load_labels(self, file_path):        # 读取数据        binary = open(file_path,'rb').read()         # 读取头文件        fmt_head = '>ii'        offset = 0         # 读取头文件        magic_number,items_number = struct.unpack_from(fmt_head,binary,offset)         # 打印头文件信息        print('标签数:%d'%(items_number))         # 处理数据        fmt_data = '>B'        offset = offset + struct.calcsize(fmt_head)         # 读取数据        labels = np.empty((items_number))        for i in range(items_number):            labels[i] = struct.unpack_from(fmt_data, binary, offset)[0]            offset = offset + struct.calcsize(fmt_data)            # 每1万张打印一次信息            if (i+1)%10000 == 0:                print('> 已读取:%d个标签'%(i+1))         # 返回数据        return items_number,labels      # 图片可视化    def visualaztion(self, images, labels, path):        d = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0, 9:0}        for i in range(images.__len__()):             im = Image.fromarray(np.uint8(images[i]))             im.save(path + "%d_%d.png"%(labels[i], d[labels[i]]))             d[labels[i]] += 1             # im.show()                          if (i+1)%10000 == 0:                print('> 已保存:%d个图片'%(i+1))                  # 保存为图片格式 def change_and_save():     mnist =  MnistParser()      trainImageFile = './train-images-idx3-ubyte'     _, _, _, images = mnist.load_image(trainImageFile)     trainLabelFile = './train-labels-idx1-ubyte'     _, labels = mnist.load_labels(trainLabelFile)     mnist.visualaztion(images, labels, "./images/train/")      testImageFile = './train-images-idx3-ubyte'     _, _, _, images = mnist.load_image(testImageFile)     testLabelFile = './train-labels-idx1-ubyte'     _, labels = mnist.load_labels(testLabelFile)     mnist.visualaztion(images, labels, "./images/test/")   # 测试 if __name__ == '__main__':     change_and_save()   

将这个1.py文件和下载好的数据集放在同一个文件夹下:

新建一个文件夹images,在文件夹images里面新建两个文件夹分别叫test和train。

运行完可以发现train和test里的内容如下:

步骤二:建立模型

class Model(nn.Module):     def __init__(self):         super(Model,self).__init__()         self.linear1 = nn.Linear(784,256)         self.linear2 = nn.Linear(256,64)         self.linear3 = nn.Linear(64,10) # 10个手写数字对应的10个输出      def forward(self,x):         x = x.view(-1,784) # 变形         x = torch.relu(self.linear1(x))         x = torch.relu(self.linear2(x))         # x = torch.relu(self.linear3(x))         return x 

这里是建立了一个神经网络模型类(Model)。这个模型有三个线性层(linear1、linear2、linear3)。输入维度为784(因为每一张图片的大小是28*28=784),输出维度为256、64、10(因为有十个类)。forward函数定义了模型的前向传播过程,其中x.view(-1, 784)将输入张量x变形为(batch_size, 784)的大小。然后经过三个线性层和relu激活函数进行运算,最后返回输出结果x。

步骤三:训练模型

model = Model() criterion = nn.CrossEntropyLoss() # 交叉熵损失,相当于Softmax+Log+NllLoss optimizer = torch.optim.SGD(model.parameters(),0.8) # 第一个参数是初始化参数值,第二个参数是学习率  # 模型训练 # def train(): for index,data in enumerate(train_loader):         input,target = data # input为输入数据,target为标签         optimizer.zero_grad() # 梯度清零         y_predict = model(input) # 模型预测         loss = criterion(y_predict,target) # 计算损失         loss.backward() # 反向传播         optimizer.step() # 更新参数         if index % 100 == 0: # 每一百次保存一次模型,打印损失             torch.save(model.state_dict(),"./model/model.pkl") # 保存模型             torch.save(optimizer.state_dict(),"./model/optimizer.pkl")             print("损失值为:%.2f" % loss.item()) 

首先创建了一个模型对象model,一个损失函数对象criterion和一个优化器对象optimizer。然后使用一个for循环遍历训练数据集train_loader,每次取出一个batch的数据。接着将优化器的梯度清零,然后使用模型前向传播得到预测结果y_predict,计算损失值loss,然后进行反向传播和参数更新。每训练100个batch,保存模型和优化器的参数,并打印当前的损失值。

步骤四:保存模型参数

if os.path.exists('./model/model.pkl'):     model.load_state_dict(torch.load("./model/model.pkl")) # 加载保存模型的参数

在当前文件夹下新建一个名叫model的文件夹。保存步骤三中训练完模型的参数。

步骤五:检验模型

     correct = 0 # 正确预测的个数     total = 0 # 总数     with torch.no_grad(): # 测试不用计算梯度         for data in test_loader:             input,target = data             output=model(input) # output输出10个预测取值,其中最大的即为预测的数             probability,predict=torch.max(output.data,dim=1) # 返回一个元组,第一个为最大概率值,第二个为最大值的下标             total += target.size(0) # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小             correct += (predict == target).sum().item() # predict和target均为(batch_size,1)的矩阵,sum()求出相等的个数         print("准确率为:%.2f" % (correct / total))   

参数说明:

  • correct:记录正确预测的个数
  • total:记录总样本数
  • test_loader:测试集的数据加载器
  • input:输入数据
  • target:目标标签
  • output:模型的输出结果
  • probability:最大概率值
  • predict:最大值的下标

过程:

  • 使用torch.no_grad()包装测试过程,表示不需要计算梯度
  • 遍历测试集中的每个数据,获取输入数据和目标标签
  • 将输入数据输入模型,得到模型的输出结果
  • 使用torch.max()函数返回预测结果中的最大概率值和最大值的下标
  • 更新总数和正确预测的个数
  • 最后计算并输出准确率。

步骤六:检测自己的手写数据

if __name__ == '__main__':     # 自定义测试     image = Image.open('C:/Users/wangyiyuan/Desktop/20201116160729670.jpg') # 读取自定义手写图片     image = image.resize((28,28)) # 裁剪尺寸为28*28     image = image.convert('L') # 转换为灰度图像     transform = transforms.ToTensor()     image = transform(image)     image = image.resize(1,1,28,28)     output = model(image)     probability,predict=torch.max(output.data,dim=1)     print("此手写图片值为:%d,其最大概率为:%.2f" % (predict[0],probability))     plt.title('此手写图片值为:{}'.format((int(predict))),fontname="SimHei")     plt.imshow(image.squeeze())     plt.show() 

这里的C:/Users/wangyiyuan/Desktop/20201116160729670.jpg是我自己从网上找的的手写图片。这段代码意思如下:

  1. 打开并读取一张手写图片,图片的路径为'C:/Users/wangyiyuan/Desktop/20201116160729670.jpg'。
  2. 调整图片尺寸为28x28。
  3. 将图片转换为灰度图像,以便后续处理。
  4. 使用transforms.ToTensor()将图片转换为PyTorch张量。
  5. 调整图片尺寸为(1, 1, 28, 28)以适应模型的输入要求。
  6. 将处理后的图片输入模型,获取预测输出。
  7. 通过torch.max函数获得输出中的最大值及其索引,即预测的数字和其概率。
  8. 打印预测的数字和概率。
  9. 在图像上显示预测结果和手写图片。
  10. 展示图像。

步骤七:结果展示

我的原图是:

测试得到的结果为:


损失值为:4.16
损失值为:0.93
损失值为:0.31
损失值为:0.19
损失值为:0.24
损失值为:0.15
损失值为:0.13
损失值为:0.11
损失值为:0.18
损失值为:0.02
此手写图片值为:2,其最大概率为:6.57

----------------------码字不易,请多多关注博主!-----------------------------------------------主程序可以关注博主后,私信秒发-------------------

相关内容

热门资讯

wepoke确实有挂(透视)W... wepoke确实有挂(透视)We辅poker助(详细辅助解密教程)确实是真的有挂(可靠计算辅助)1、...
5分钟普及!卡丁互娱辅助器&q... 5分钟普及!卡丁互娱辅助器"详细辅助玩家教程"(切实存在有挂)卡丁互娱辅助器辅助器中分为三种模型:卡...
wepoke辅助挂(透视)we... wepoke辅助挂(透视)wepoke靠谱吗(详细辅助解密教程)竟然存在有挂(黑科技插件)1、wep...
三分钟普及!微信十三张脚本&q... 三分钟普及!微信十三张脚本"详细辅助解密教程"(原来是真的有挂)运微信十三张脚本辅助工具,进入游戏界...
第2分钟普及!桂林八一字牌辅助... 第2分钟普及!桂林八一字牌辅助"详细辅助揭秘攻略"(总是是有挂);在进入桂林八一字牌辅助辅助挂后,参...
WePoKe透视挂(透视)we... WePoKe透视挂(透视)wepoke好友助力(详细辅助安装教程)切实是有挂(科普辅助技巧)1、让任...
4分钟普及!贰柒拾辅助软件是真... 4分钟普及!贰柒拾辅助软件是真的吗"详细辅助详细教程"(竟然是有挂)1、贰柒拾辅助软件是真的吗系统规...
wepoke计算辅助(透视)w... wepoke计算辅助(透视)wepoke有挂吗(详细辅助wepoke教程)本来真的是有挂(专业软件透...
第2分钟普及!爱玩辅助&quo... 第2分钟普及!爱玩辅助"详细辅助揭秘教程"(确实存在有挂);1、爱玩辅助ai辅助优化,爱玩辅助发牌逻...
wepoke是真的有挂(透视)... wepoke是真的有挂(透视)wepoke有规律吗(详细辅助可靠技巧)其实是有挂(大神模拟器)wep...