深度学习pytorch小实验

让我们开始进行一个简单的深度学习实验吧!我们将使用PyTorch来实现。在这个实验中,我们将训练一个基本的人工神经网络(Artificial Neural Network,ANN)来进行手写数字的识别。首先,我们需要导入相关的库和模块。请确保你已经安装了PyTorch和torchvision。

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

接下来,我们需要加载手写数字的数据集。PyTorch提供了许多常见的数据集,我们将使用MNIST数据集。MNIST数据集是一个包含有手写数字图像的数据集,它包含了60000个训练样本和10000个测试样本。我们可以使用torchvision中的datasetsDataLoader来加载这些数据。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

然后,我们定义一个简单的神经网络模型。在这个实验中,我们将使用一个简单的两层全连接神经网络。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

然后,我们实例化这个神经网络模型,并定义一个损失函数和一个优化器。

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

接下来,我们要进行训练。我们将进行多个epoch的训练,并在每个epoch之后测试模型的性能。

for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

在每个epoch之后,我们使用测试集来评估模型的性能。

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

这是一个简单的深度学习实验,我们使用PyTorch实现了一个用于手写数字识别的神经网络模型,并在MNIST数据集上进行了训练和测试。你可以尝试不同的网络结构、超参数和优化算法来改进模型的性能。祝你好运!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/569369.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

1W 1.5KVDC 3KVDC 隔离宽范围输入,单、双输出 DC/DC 电源模块——TP2L-1W 系列

TP2L-1W系列是一款高性能、超小型的电源模块,宽范围2:1,4:1输入,输出有稳压和连续短路保护功能,隔离电压为1.5KVDC,3KVDC工作温度范围为–40℃到85℃。特别适合对输出电压的精度有严格要求的地方,外部遥控功…

【Python】爬虫-基础入门

目录 一、什么是爬虫 二、爬虫的主要用途 三、学会爬虫需要掌握的技能 四、爬虫使用的语言 五、编写爬虫需要的库,以python为例 六、爬虫示例-python 示例一 示例二 示例三 一、什么是爬虫 爬虫,又称网络爬虫或网页爬虫,是一种用来自…

《智能前沿:应对ChatGPT算力挑战》

在全球人工智能热潮中,以 ChatGPT 为代表的 AIGC 技术引发了广泛关注。人工智能和机器学习等技术对数据规模及处理速度等提出了更高要求。在数据成为主要生产要素的当下和未来,如何跟上时代的发展步伐,构建适应 AI 需求的数据中心&#xff0c…

Keil和VSCode协同开发STM32程序

系列文章 STM32单片机系列专栏 C语言术语和结构总结专栏 文章目录 1. 配置环境 2. 测试打开工程 3. 测试编译工程 随着项目的复杂度上升,开发者不仅需要强大的硬件支持,还需要一个高效和灵活的开发环境。 vscode是一款集成大量可以便携开发插件的代码…

自动化软件测试策略

作为一名软件开发人员,我在不同的公司工作过,具有不同的软件测试流程。在大多数情况下,没有特定/记录的测试方法......因此该过程的内容/方式取决于各个开发人员。与大多数情况一样,当没有强制执行或至少记录在案的政策时&#xf…

齐护K210系列教程(七)_LCD显示数据

LCD显示数据 文章目录 LCD显示数据1,显示英文2,显示传感器的数值3,显示中文4,课程资源 联系我们 LCD的最大分辨率为320*240,所以当我们设置文字或图像坐标时,后面要记住这一点,当然,…

如何将web content项目导入idea并部署到tomcat

将Web Content项目导入IntelliJ IDEA并部署到Tomcat主要涉及以下几个步骤: 1. 导入Web Content项目 打开IntelliJ IDEA。选择“File” -> “New” -> “Project from Existing Sources…”。浏览到你的Web Content项目的文件夹,并选择它。Intell…

QA的成长之路——深入测试的奇妙之旅

引言 功能测试的小伙伴,你们是否遇到过这些问题: 1、工作中重复性很高:尽管尽可能地让一个 case 覆盖更多场景,但仍有许多重复性 case,耗费大量时间,让人感到枯燥疲惫; 2、覆盖度不全&#x…

Bitmap 原理简述

之前写过一篇 bitmap 应用场景的文章https://blog.csdn.net/maray/article/details/136923316 本文介绍 bitmap 的原理: 下面有三张表:user_info_base, user_prefer, user_device,我们希望查询“喜欢电子产品并且使用iPhone的女性用户”&…

食用油5G智能工厂数字孪生可视化平台,推进食品制造业数字化转型

食用油5G智能工厂数字孪生可视化平台,推进食品制造业数字化转型。在食用油产业中,数字化转型已成为提升生产效率、优化供应链管理、确保产品质量和满足消费者需求的关键。食用油5G智能工厂数字孪生可视化平台作为这一转型的重要工具,正在推动…

数据结构之顺序表(java版)

目录 一.线性表 1.1线性表的概念 二.顺序表 2.1顺序表的概念 2.2顺序表的实现 1.顺序表的接口 1.2顺序表的功能实现 1.顺序表初始化 2.新增元素功能: 3.清空顺序表是否为空&&获取顺序表长度&&打印顺序表: 4.判断是否包含某个…

关于开设YOLOv8专栏及更新内容的一些说明

​ 专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,助力高效涨点!!! 专栏介绍 ⭐后期更新包含模块、卷积、检测头、损失等改进,目前已有70!现在入手仅$ 69.9,早入早发论文!⭐ ⭐…

【前端技术】HTML基础入门篇

1.1 HTML简介 ​ HTML(HyperText Markup Language:超文本标记语言)是一种标识性的语言。它包括一系列标签.通过这些标签可以将网络上的文档格式统一,使分散的Internet资源连接为一个逻辑整体。HTML文本是由HTML命令组…

uView u-parse 在nvue页面中无作用踩坑

问题起因: 在uni-app开发的app nvue页面中有需要回显渲染字符串形式的富文本内容 但使用v-html和uniapp的rich-text组件都无法起到作用,就想到了使用uView中u-parse进行尝试。 uView我是使用uniApp插件市场导入的方式将插件导入项目的uni_modules中 …

2024年教你学浪视频抓取#小浪助手

在2024年,学浪平台已经成为学习者们追逐知识、获取学习资源的热门平台之一。然而,尽管学习平台提供了丰富多样的学习内容,但有时候我们还是希望能够将这些学习资源下载下来,以便随时随地进行学习。那么,如何学习学浪视…

【layoutlmv3推理】无法识别的pdf使用ocr识别代码demo实例

目录 前情提要一、安装依赖1、直接安装的依赖2、需要编译的依赖1)Leptonica2)icu3)Tesseract 3、需要自行配置的依赖 二、模型下载三、更改transformers源码四、加载光学字符识别语言包五、运行代码 前情提要 在做pdf转文本时,发…

用于割草机器人,商用服务型机器人的陀螺仪

介绍一款EPSON推出适用于割草机器人,商用服务型机器人的高精度陀螺仪模组GGPM61,具体型号为GGPM61-C01。模组GGPM61是一款基于QMEMS传感器的低成本航向角输出的传感器模组,它可以输出加速度、角速度及姿态角等信息,为控制机器人运…

航空业微服务架构中台的构建与实践

随着航空业的快速发展,航空公司需要面对更加复杂的业务环境和客户需求。在这样的背景下,构建一个稳健、高效的微服务架构中台成为了航空公司的当务之急。本文将探讨航空业微服务架构中台的设计理念、关键技术以及实践经验,帮助航空公司构建具…

「Java开发指南」如何利用MyEclipse启用Spring DSL?(二)

本教程将引导您通过启用Spring DSL和使用Service Spring DSL抽象来引导Spring和Spring代码生成项目,本教程中学习的技能也可以很容易地应用于其他抽象。在本教程中,您将学习如何: 为Spring DSL初始化一个项目创建一个模型包创建一个服务和操…

面向多源异质遥感影像地物分类的自监督预训练方法

源自:测绘学报 作者:薛志祥, 余旭初, 刘景正, 杨国鹏, 刘冰, 余岸竹, 周嘉男, 金上鸿 摘 要 近年来,深度学习改变了遥感图像处理的方法。由于标注高质量样本费时费力,标签样本数量不足的现实问题会严重影响深层神经网络模型的性能。为解决这一突出矛盾…
最新文章