faster rcnn中损失函数(三)——理解faster-rcnn中计算rpn_loss_cls&rpn_loss_box的过程_Snoopy_Dream的博客-程序员资料

技术标签: pytorch  

首先感想来源与pytorch的rpn.py。

我们都知道,rpn通过制作lable和targe_ shift来构造rpn loss的计算。那具体是怎么构造的呢?


首先rpn_loss_cls计算:

我们应该首先想到的是: rpn_loss_cls = F.cross_entropy(rpn_cls_score, rpn_label)

维度分析

cross_entropy要求输入是Variable,预测的是2D,label是1D。

所以可以根据默认规定的初始的数据格式b,2*9,h,w进行推导。

rpn_cls_score:  b,2*9,h,w -> b*9*h*w,2   #二分类

然后去除掉不感兴趣的区域:

rpn_cls_score: (b*9*h*w - 标签-1的 ,2)    #二分类

rpn_label:(b*9*h*w - 标签-1的,)

 #return outputs [label ,target ,inside-weight ,outside_weight]
rpn_data = self.RPN_anchor_target((rpn_cls_score.data, gt_boxes, im_info, num_boxes))
rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)# b 9h*w 2
rpn_label = rpn_data[0].view(batch_size, -1)#B 1 9*H W._>b,9*h*w

数据本身分析:

label包括:1 0 -1

首先需要做的是去除-1,即不感兴趣的目标

#~@@!通过的ne去除掉-1,返回非0的索引
rpn_keep = Variable(rpn_label.view(-1).ne(-1).nonzero().view(-1))#nonzero返回b*9h*w行1列,所以需要view变成一维
rpn_cls_score = torch.index_select(rpn_cls_score.view(-1,2), 0, rpn_keep)#从rpn_cls_score(b*9h*w,2)从第0轴按照rpn_keep索引找
rpn_label = torch.index_select(rpn_label.view(-1), 0, rpn_keep.data)#rpn_data上文就是tensor,不是Variable
rpn_label = Variable(rpn_label.long())#运算完后的输出再用Variable( Tensor.long())转换回来

注意:

rpn_score是 Variable 而rpn_label刚开始是tensor;

因为anchor_target_layer和prosal_layer.py不需要反向传播,了解他们的输入输出这一点很简单,他们本身就是生成rpn_label 等,做的事情是制定选出的规则,并没有对选出的东西进行计算,所以无需反向传播,所以里面的forwardinput都是Tensor输入的时候都需要 Variable.data, 运算完后的输出再用Variable( Tensor.long())转换回来

================================================

分析:

1. rpn_label.view(-1).ne(-1).nonzero().view(-1)

ne(-1)返回 是-1就返回0,不是-1,返回1

nonzero返回不是0的索引,n行1列的(n,1)

综上,返回不是-1的所有索引,列成1维数组(n,)

2. torch.index_select(rpn_cls_score.view(-1,2), 0, rpn_keep)#从 0维,按照rpn_keep索引,找rpn_cls_score.view(-1,2)

==================================================

得到rpn_loss_cls

self.rpn_loss_cls = F.cross_entropy(rpn_cls_score, rpn_label)#  (b*9*h*w,2)   (b*9*h*w,) 

至于rpn_loss_box的内容,可以具体差不多,思想理解就可以了,了解了rpn_loss_box的输入和输出就好了。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/e01528/article/details/84026513

智能推荐

IOS开发之----异常处理_沸腾的泪水05314的博客-程序员资料

转载自:http://blog.sina.com.cn/s/blog_71715bf8010166qf.html开篇大话: Object-C语言的异常处理符号和C++、JAVA相似。再加上使用NSException,NSError或者自定义的类,你可以在你的应用程序里添加强大的错误处理机制。异常处理机制是由这个四个关键字支持的:@try,@catch,@thorw,@finally。当代码

node相关报错问题_gyp info using [email protected]_lily1346891的博客-程序员资料

问题一:node-sass npm ERR! command failed解决:1、删除 npm uninstall node-sass2、安装 npm install node-sass问题二:npm ERR! gyp info it worked if it ends with oknpm ERR! gyp info using [email protected] ERR! gyp info using [email protected] | win32 | x64npm ERR! gyp ERR!

Android 连接USB默认选中MTP模式_Just_Paranoid的博客-程序员资料

Android 连接USB默认选中MTP模式需求分析Android系统默认连接USB会显示:正在通过USB为此设备充电,并且无法在电脑查看存储内容。需要实现的效果:Android 连接USB默认选中MTP模式,连接USB显示:正在通过USB传输文件,选择USB的使用方式的弹框下MTP模式为选中状态,并且可以在电脑端可以访问和写入存储空间。解决方案diff --git a/frameworks/base/services/usb/java/com/android/server/usb/UsbD

小机器人5岁了!细数Android甜点史_Ronys的博客-程序员资料

2012-11-05 作者: 出处:互联网 责编:联想Yoga分期付款月供279元五年前的11月5日,谷歌不仅宣布成立“开放手机联盟”(Open Handset Alliance),表示要帮助创建移动通信的开放标准,而且推出了Android平台——一个基于Linux的智能手机平台。以下是一篇简短的图文介绍,回顾了谷歌手机操作系统的发展。转

【C】函数指针_November's chopin的博客-程序员资料

案例环境代码#include <stdio.h>int max(int x, int y){ return x > y ? x : y;}int main(){ int (* p)(int ,int) = & max; int a, b, c, d; printf("请输入三个整数:\n"); scanf("%...

良好性能和高质量视觉效果_圣空老宅的博客-程序员资料

译者:赵菁菁(轩语轩缘)  审校:李笑达(DDBC4747)对于任何追求UE4性能最佳、同时又想保持极高质量视觉效果的人来说,本文有一些可遵循的一般性建议和原则。 局限性为了性能,你通常受CPU时间(通常和游戏设置相关)和GPU时间限制(渲染场景花费的时间)。CPU创建由GPU渲染的场景会耗费一些时间。通常情况下,当你发现游戏的运行速度不像你想要的那么快时,第一步是找出

随便推点

java面试(进阶四篇)解答_恐龙弟旺仔的博客-程序员资料

题目来自于网络,答案是笔者整理的。仅供参考,欢迎指正来源: https://mp.weixin.qq.com/s?__biz=MzI1NDQ3MjQxNA==&mid=2247485779&idx=1&sn=3b06b9923df7f40f887ead8b8a53e50e&chksm=e9c5f0e2deb279f47fbfc3a12a70896bf95fa8c...

mac中安装git并忽略.DS_Store_mac .gitignore 忽略ds__q2826621520的博客-程序员资料

一Homebrew安装git1.安装 Homebrew/usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"2.安装gitbrew install git二xcode安装git1.在mac终端中直接输入git.提示是否安装开发者...

执行main方法时出现java.lang.SecurityException异常_live_and_learn_CC的博客-程序员资料

1.执行main方法时弹出 Error: A JNI error has occurred, please check your installation and try againe2.执行后控制台报的错误3.进入ClassLoader.java中,4.原因:在开始执行main方法时就已经加载了以java开头的包路径,所有类加载器在加载文件时会抛出异常5.解决方法:改包...

【opencv】goodFeaturesToTrack源码分析-2-Shi-Tomasi角点检测_Denny#的博客-程序员资料

本文章是【opencv】goodFeaturesToTrack源码分析-1的后续,主要描述Shi-Tomasi角点检测算法原理及opencv实现。1、算法原理Shi-Tomasi算法是Harris算法的改进,在Harris算法中,是根据协方差矩阵M的两个特征值的组合来判断是否角点。而在Shi-Tomasi算法中,是根据较小的特征值是否大于阈值来判断是否角点。 这个判断依据是:较小的特征值表示在该特

tensorflow源码例子mnist源码——mnist.py_修炼打怪的小乌龟的博客-程序员资料

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# Y