pytorch之数据:pack_padded_sequence()与pad_packed_sequence()-程序员宅基地

技术标签: pytorch  

前言

可以结合最下面的例子来理解

pack_padded_sequence()pad_packed_sequence()这两个函数属于torch.nn.utils.rnn,很明显,意义就是为了rnn包来处理数据的。前者pack用于压紧数据,处理经过填充(padded)后的数据;后面pad用于解压数据,把原来咋填充的给你咋释放回去。

一. 官方+理解

1. pack_padded_sequence
'官方函数'
torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False) → PackedSequence
↓'返回'
return   一个PackedSequence对象

功能:这里的pack,理解成压紧比较好,压紧被填充过的数据。 将一个 填充过的变长序列 压紧。返回PackedSequence对象。

重点

  • 输入的数据input可以是[Seq_max, batch_size, *]格式,且Seq_max是序列的最大长度。且input必须按序列长度的长短排序,长的在前面,短的在后面,第一个时间步的数据必须是最长的数据。
  • input:我们要压缩的数据。当batch_firstFalse时候,shape的输入格式是[B,S *],其中Bbatch_size,S是seq_len(该batch中最长序列的长度),*可以是任何维度。如果batch_firstTrue时候,相应的的数据格式必须是[S,B,*]。关于数据格式转置参考
  • lengths:输入数据的每个序列的长度。
  • batch_first:当为True,数据格式必须[B, S, *],反之,默认是False
2. pad_packed_sequence
torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False)tuple'返回'
return (sequence_pad , list) # 这个元祖包含被填充后的序列 , 和batch中序列的长度列表。

功能:上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来,默认按返回中list最大的数字填充。

重点

  • sequence:将要被填充的batch,是一个PackedSequence对象。
  • batch_first:作用同上pack_padded_sequence的batch_size,但是影响的是输出的数据格式。

二. 结合例子。

注意上述官方+理解,我们明白了输入的数据一定是要经过排序的

  1. 准备例子数据
import torch

batch_size = 3   # 这个batch有3个序列
max_len = 6       # 最长序列的长度是6
embedding_size = 8 # 嵌入向量大小8
hidden_size = 16   # 隐藏向量大小16
vocab_size = 20    # 词汇表大小20

input_seq = [[3, 5, 12, 7, 2, ], [4, 11, 14, ], [18, 7, 3, 8, 5, 4]]
lengths = [5, 3, 6]   # batch中每个seq的有效长度。
# embedding
embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx=0)
# GRU的RNN循环神经网络
gru = torch.nn.GRU(embedding_size, hidden_size)
  1. 排序数据
'由大到小排序'
input_seq = sorted(input_seq, key = lambda tp: len(tp), reverse=True)
lengths = sorted(lengths, key = lambda tp: tp, reverse=True)
'''
outputs:
input_seq: [[18, 7, 3, 8, 5, 4], [3, 5, 12, 7, 2], [4, 11, 14]]
lengths : [6, 5, 3]
'''
  1. 填充数据
PAD_token = 0 # 填充下标是0
def pad_seq(seq, seq_len, max_length):
	seq = seq
	seq += [PAD_token for _ in range(max_length - seq_len)]
	return seq

pad_seqs = []  # 填充后的数据
for i,j in zip(input_seq, lengths):
	pad_seqs.append(pad_seq(i, len_i, max_len))
'''
填充后数据
pad_seqs : [[18, 7, 3, 8, 5, 4], [3, 5, 12, 7, 2, 0, 0], [4, 11, 14, 0, 0, 0, 0, 0, 0]]
'''
  1. 使用pack和pad函数
pad_seqs = torch.tensor(pad_seqs)
embeded = embedding(pad_seqs)

# 压缩,设置batch_first为true
pack = torch.nn.utils.rnn.pack_padded_sequence(embeded, lengths, batch_first=True)
'这里如果不写batch_first,你的数据必须是[s,b,e],不然会报错lenghth错误'

# 利用gru循环神经网络测试结果
state = None
pade_outputs, _ = gru(pack, state)
# 设置batch_first为true;你可以不设置为true,为false时候只影响结构不影响结果
pade_outputs, others = torch.nn.utils.rnn.pad_packed_sequence(pade_outputs, batch_first=True)

# 查看输出的元祖
print(pade_outputs.shape) 'torch.Size([3, 6, 16])'
print(others) 'tensor([6, 5, 3])' 

到这里,大家基本知道如何使用这个函数了。

总结

这个函数就是用到填充tensor数据和解压数据的,进行循环神经网络的运算。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/xinjieyuan/article/details/108562360

智能推荐

Java知识总结-基础

** * 外部内、内部类 */ public class Outer { public static IAnimal getInnerInstance(String speak){ return new IAnimal(){ @Override public void speak(){ System.out.println(speak);当程序第一次引用该类的静态成员时,就会触发这个类的加载。

Apache Kylin Buid Cube详细流程_kylin buid cuboid的时候每一步都要等待2分钟-程序员宅基地

文章浏览阅读506次。Build Cube流程主要分为四个阶段:根据用户的cube信息计算出多个cuboid文件根据cuboid文件生成htable更新cube信息回收临时文件1.流程一:作业整体描述把构建Cube的来源表总行数写到指定的HDFS文件中2.流程二:生成中间临时数据这一步的操作是根据Cube设计中的定义生成原始数据,这里会新创建一个Hive外部表,然后再_kylin buid cuboid的时候每一步都要等待2分钟

团队管理视角-程序员宅基地

文章浏览阅读779次。一个管理者要带团队有三重视角:第一个视角是管理者,第二个视角是教练员,第三个视角是指挥员。管理者视角管理者视角,最常规的视角。比如效率和质量如何衡量?如何分解和分配任务?项目进展怎么样了?进度怎么估算?有没有瓶颈,瓶颈在哪?根因是什么?绩效怎么考核?而今天我们要讨论的是管理者的首要目标——求生存。如果一个团队在公司里没有价值了,那么整个团队都会被裁掉。所以,作为管理者最重要的是先能生存下来,证明自己是能胜任的,然后证明自己的团队是能胜任的。证明自己理解何为胜任,就是领导布置任.._管理视角

什么场景要使用策略模式,什么场景不能使用?_java 策略模式 什么情况不适合-程序员宅基地

文章浏览阅读916次。需完整版面试文档扫描左侧二维码拿!滴,老年卡;滴,学生卡;滴正常卡。我们在坐公交车的时候啊,这个场景每天都在上演。那如果,让你来设计这样一套刷卡的结算逻辑,你最先想到的是用什么设计模式呢?如果,让我来设计,我最先想到的就是策略模式。另外,我把往期面试题解析的配套文档我已经准备好,想获得的可以在我的煮叶简介中找到。那么什么场景要使用策略模式,什么场景又不应该使用策略模式呢?我们可以先来看官方对策略模式的定义。1、官方定义官方原文是:Define a family of algor_java 策略模式 什么情况不适合

【解决问题】:fatal error C1034: iostream: 不包括路径集-程序员宅基地

文章浏览阅读8.1k次,点赞13次,收藏20次。在VS2019的环境变量配置好cl.exe系统环境后报错fatal error C1034: iostream: 不包括路径集【解决方法】此电脑电脑右键->属性->高级系统设置->环境变量->系统变量->新建INCLUDE编辑环境变量五个用英文分号隔开如图保存之后可能报错:fatal error LNK1104: 无法打开文件“libcpmt.lib【解决方法】此电脑电脑右键->属性->高级系统设置->环境变量->系统变量->新_fatal error c1034: iostream: 不包括路径集

word 的使用 —— 快捷键(分节符 分页符 分栏符)_word分节快捷键-程序员宅基地

文章浏览阅读1.5w次。word 的使用 —— 快捷键(分节符 分页符 分栏符)_word分节快捷键

随便推点

信息安全风险评估---矩阵法计算风险_威胁程度计算-程序员宅基地

文章浏览阅读1.6w次,点赞16次,收藏29次。 矩阵法计算风险假设:有以下信息系统中资产面临威胁利用脆弱性的情况:共有两项重要财产:资产A1和资产A2;资产A1面临一个主要威胁T1;资产A2面临两个主要威胁T2,T3;威胁T1可以利用资产A1存在的两个..._威胁程度计算

《SoC设计方法与实现》(1)_soc设计方法与实现 epub-程序员宅基地

文章浏览阅读347次。SOC(System On Chip)即系统级芯片,又称片上系统,其将系统的主要功能综合到一块芯片中,本质上是在做一种复杂的IC设计。现在的SOC芯片上可整体实现CPU、DSP、数字电路、模拟电路、存储器、片上可编程逻辑阵列等多种电路,综合实现图像处理、语音处理、通信协议、通信机能、数据处理等功能。SOC的优势有:可以实现更为复杂的系统、具有较低的设计成本、具有更高的可靠性、缩短产品设计时间、减少产品反复的次数、可以满足更小尺寸的设计要求、可达到低功耗的设计要求。_soc设计方法与实现 epub

Linux内核模块动态添加方法_linux 将moudle动态加入内核-程序员宅基地

文章浏览阅读391次。Linux内核模块动态添加方法 今天下午通过一番折腾,终于琢磨除了Linux内核模块的动态加载方法,网上大部分教程基于旧版本做的,有很多地方不一样,走了很多弯路,不过最后终于成功了,方法如下:1、建立C++源文件,假设文件目录为path,文件名为hello.c源代码如下:#include #include

【Matlab】图像裁剪函数imcrop的原点、长度、宽度问题_im1.crop函数-程序员宅基地

文章浏览阅读2.9w次,点赞12次,收藏36次。【Matlab】图像裁剪函数imcrop的原点、长度、宽度问题[toc] Matlab中,函数imcrop用来裁剪图像,但有几个问题要探讨一下。 先说imcrop的简单用法:I2 = imcrop(I,RECT)I代表原图,RECT是裁剪区域。 RECT的形式是这样的:[XMIN YMIN WIDTH HEIGHT]问:区域RECT的原点是怎么定义的,或者说在哪?区域RECT的长度和宽度,我_im1.crop函数

[转]内嵌WORD/OFFICE的WINFORM程序——DSOFRAMER使用小结-程序员宅基地

文章浏览阅读408次。最近一直想用VC#2005做个内嵌WORD/OFFICE的WINFORM程序,目前主要有以下解决途径:1、直接通过API把WORD/OFFICE的窗口句柄给放到WINFORM中(感觉较为复杂);2、通过WEB BROWSER;3、利用DSOFRAMER。本人都测试了一下,觉得DSOFRAMER更符合自己的愿望,故决定使用DSOFRAMER来实现。操作步骤:WinForm..._vs dso framer control object

【Tensorflow】读取TFRecord文件时,Image和Label无法一一对应_tensorflow label和data没有对齐-程序员宅基地

文章浏览阅读1.0k次。问题Image和Label数据成对写入TFRecord文件,按理训练过程中读取的Image和Label也应该是一一对应的,但有的时候发现Image和Label并不能匹配。如:将以下数据写入TFrecord中:Image 1 —— Label 1Image 2 —— Label 2Image 3 —— Label 3读取TFRecord时,数据发生错位:Image 1 —— Labe..._tensorflow label和data没有对齐