Pytorch_学习_cnn_多尺度 cnn pytorch-程序员宅基地

技术标签: pytorch  

直接代码:

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt


# Hyper parameters
torch.manual_seed(1)
EPOCH = 1  #训练一批次
BATCH_SICE = 50  # 一次50个
LR = 0.001  # 每次的学习率
DOWNLOAD_MNSIT = False  # 需要下载为True, 下载好了可以设置为False

# download data
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNSIT
)

test_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=False
)

# train_loader
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SICE,
    shuffle=True
)

# ready_for_test 前20000个
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.Tensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# print(test_x.size())
# print(test_y.size())
# print(test_x[1])
# print(test_y[1])

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1  = nn.Sequential(     # (1, 28, 28)
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2
            ),             # output shape : N = (W - F + 2P)/S + 1
            # W :图片大小 W*W 28, F:filter大小 F*F,5, S:步长, 1  P:padding 2
            # N = (28 - 5 + 2*2)/1 + 1 = 28 因为有16个filter
            # (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
            # Maxpooling 2*2 向下采样取最大(想想如果4*4图片,采用2*2maxpooling,就变成了2*2图片)
            # 所以这里图片大小为 (16, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),  # (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2)  # (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

cnn = CNN()
print(cnn)



# training
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()


# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):   # 分配 batch data, normalize x when iterate train_loader
        output = cnn(b_x)               # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            print('train loss = ', loss.data.numpy())

 

 

结果

CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)
train loss =  2.3104544
train loss =  0.61840093
train loss =  0.12703253
train loss =  0.23725168
train loss =  0.4044273
train loss =  0.08451731
train loss =  0.19528298
train loss =  0.109065436
train loss =  0.123532385
train loss =  0.06730688
train loss =  0.2240683
train loss =  0.2114728
train loss =  0.024014007
train loss =  0.08469809
train loss =  0.21586336
train loss =  0.10181876
train loss =  0.043114547
train loss =  0.09106462
train loss =  0.055737924
train loss =  0.10089029
train loss =  0.032855053
train loss =  0.021929108
train loss =  0.025414439
train loss =  0.117736794

Process finished with exit code 0

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/sinat_15355869/article/details/86552617

智能推荐

cd如何省略空格 linux_如何用Linux 终端指令打开带有空格或特殊符号的目录-程序员宅基地

文章浏览阅读395次。在Linux下使用终端指定打开文件夹,提示目录名称中包含语法错误,此时可以有两个选择:1、按照Linux推荐的文件命名规范,对文件夹名进行修改;2、使用转义符 \;3、成对使用双引号 "";Linux文件命名规范简介Linux系统区分英文字符的大小写。命名目录和命名文件的规则是相同的。除非有特别的原因否则用户创建的文件和目录名要使用小写字符。大多数的Linux命令也使用小写字符。Linux系统下..._linux指令cd进到指定目录,目录有特殊字符

Calendar_import calendar-程序员宅基地

文章浏览阅读161次。Calendar代码:public class Demo01Calendar {public static void main(String[] args) { Calendar c=Calendar.getInstance();//多态。//获取日历信息 System.out.println(c); System.out.println(c.get(Calendar..._import calendar

【漫谈】人工智能那些事儿:还很弱小,但怀期待-程序员宅基地

文章浏览阅读186次。关注:决策智能与机器学习,一篇文章Get一个知识点回顾自从1956年达特茅斯人工智能夏季研讨会首次提出“人工智能”的概念以来,两起两落,到现在第三次浪潮,人工智能已经可以算是完成特定任...

tensorflow/stream_executor/cuda/cuda_dnn.cc:378] Loaded runtime CuDNN library: 7301--2019.5.12-程序员宅基地

文章浏览阅读7.9k次,点赞2次,收藏5次。安装的cudnn的版本是7.1.0.3,而要求的cudnn版本是7.3.0.0。将tensorflow版本从1.5换成1.8,顺利运行程序(升级tensorflow版本来解决)ll 命令查看 连接 /usr/local/cuda/lib64下 把对应的 libcudnn.so.7,3,1连到 libcudnn.so.7 在连到libcudnn.sosudo ln -sf li..._tensorflow/stream_executor/cuda/cuda_dnn.cc:378] loaded runtime cudnn librar

SQL——limit和offset的用法_sql limit-程序员宅基地

文章浏览阅读2k次。【代码】SQL——limit和offset的用法。_sql limit

你并不在意的 HTTPS 证书吊销机制,或许会给你造成灾难性安全问题!-程序员宅基地

文章浏览阅读1.5k次。缘起偶刷《长安十二时辰》,午睡时,梦到我穿越到了唐朝,在长安城中的靖安司,做了一天的靖安司司丞。当徐宾遇害消失的时候我不在司内,当时的情形我不得而知。后来徐宾醒了,据他描述说“通传陆三”..._enable_ocsp_must_staple

随便推点

使用NFS自动挂载文件系统_nfs设置账号密码挂载-程序员宅基地

文章浏览阅读186次。使用NFS自动挂载文件系统便利性大大增强,首先连接网线,传输数据,此外保证服务器和单板能够ping的通。单板上电(开发板已经烧录有u-boot和内核),任意键中止boot进程,q退出菜单,输入print,更改“bootargs=noinitrd root=/dev/mtdblock3 init=/linuxrc console=ttySAC0,115200”这一句。使用 set bootarg..._nfs设置账号密码挂载

使用JAVA获取KAFKA中指定TOPIC的OFFSET-程序员宅基地

文章浏览阅读1.4k次。2019独角兽企业重金招聘Python工程师标准>>> ..._java 通过kafka jmx mbeanserverconnection 获取offset

数论 - 容斥原理-程序员宅基地

文章浏览阅读291次,点赞2次,收藏2次。在计数时,必须注意没有重复,没有遗漏。为了使重叠部分不被重复计算,人们研究出一种新的计数方法,这种方法的基本思想是:先不考虑重叠的情况,把包含于某内容中的所有对象的数目先计算出来,然后再把计数时重复计算的数目排斥出去,使得计算的结果既无遗漏又无重复,这种计数的方法称为容斥原理

goto的应用举例及详解_goto使用-程序员宅基地

文章浏览阅读4.1k次,点赞16次,收藏12次。从理论上 goto语句是没有必要的,实践中没有goto语句也可以很容易的写出代码。但是某些场合下goto语句还是用得着的,最常见的用法就是终止程序在某些深度嵌套的结构的处理过 程。从上我们可以看出,goto语句真正适合的引用场景其实就是:当我们写了很多for循环时,我们需要写很多个break来跳出来for循环时,我们可以直接用goto语句来跳出for循环。综上,我们需要了解goto语句就行,但是goto语句的应用场景不是很多,提及的时候我们还是要会运用。下面我们来简单的写一个有趣的关机小程序。_goto使用

Win10系统下怎么将普通账户设置为管理员账户_更改账户类型为管理员灰色-程序员宅基地

文章浏览阅读9.8k次,点赞2次,收藏14次。在win10系统中,很多用户会新建用户来使用,但是会发现新建的用户只是普通用户,导致在安装软件的时候没有管理员账户权限无法安装,那么要怎么将普通账户设置为管理员账户呢?、然后选择其它人员,然后点击要要设置的账户,点击更改账户类型按钮;新帐户就是管理员权限了,注销之后登录即可。、设置为管理员,然后确定。以原来的管理员账户登录;、登录之后点击开始菜音,..._更改账户类型为管理员灰色

数据异常解决方法汇总_数据异常矫正的方法-程序员宅基地

文章浏览阅读1w次。文章目录Step1:积极与需求方沟通Step2:将问题进行树枝细化,直至最小单元Step 3. 基于最小单元,梳理相关因素,进行猜想验证Step 4. 测算每个因素对结果的“贡献度”碰到实在分析不出原因的数据异常怎么办?本文转载自公众号:数据分析师成长记录Step1:积极与需求方沟通数据异常很大一部分原因是自身对问题的理解与需求方意图不一致导致的,所以需要积极与需求方沟通,从以下方面依次进行排查问题:数据口径不一致等理解差异;数据源更新延迟等数仓侧原因;数据未上报/未采集等开发侧原因;St_数据异常矫正的方法

推荐文章

热门文章

相关标签