EfficientDet 训练自己的数据集_efficientdet训练自己的数据-程序员宅基地

技术标签: python  深度学习  efficientDet  pytorch  json  

EfficientDet训练自己的数据集

项目安装

参考代码:https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch
安装及环境配置可参考作者介绍或者其他博客

数据准备

训练时需要将数据集转换为coco格式的数据集,本人使用的数据集为visdrone数据集,转换过程如下:txt->XML->coco.json

txt->XML

import os
from PIL import Image

# 把下面的路径改成你自己的路径即可
root_dir = "./VisDrone2019-DET-train/"
annotations_dir = root_dir+"annotations/"
image_dir = root_dir + "images/"
xml_dir = root_dir+"Annotations_XML/"
# 下面的类别也换成你自己数据类别,也可适用于其他的数据集转换
class_name = ['ignored regions','pedestrian','people','bicycle','car','van','truck','tricycle','awning-tricycle','bus','motor','others']

for filename in os.listdir(annotations_dir):
    fin = open(annotations_dir+filename, 'r')
    image_name = filename.split('.')[0]
    img = Image.open(image_dir+image_name+".jpg") # 若图像数据是“png”转换成“.png”即可
    xml_name = xml_dir+image_name+'.xml'
    with open(xml_name, 'w') as fout:
        fout.write('<annotation>'+'\n')
        
        fout.write('\t'+'<folder>VOC2007</folder>'+'\n')
        fout.write('\t'+'<filename>'+image_name+'.jpg'+'</filename>'+'\n')
        
        fout.write('\t'+'<source>'+'\n')
        fout.write('\t\t'+'<database>'+'VisDrone2018 Database'+'</database>'+'\n')
        fout.write('\t\t'+'<annotation>'+'VisDrone2018'+'</annotation>'+'\n')
        fout.write('\t\t'+'<image>'+'flickr'+'</image>'+'\n')
        fout.write('\t\t'+'<flickrid>'+'Unspecified'+'</flickrid>'+'\n')
        fout.write('\t'+'</source>'+'\n')
        
        fout.write('\t'+'<owner>'+'\n')
        fout.write('\t\t'+'<flickrid>'+'Haipeng Zhang'+'</flickrid>'+'\n')
        fout.write('\t\t'+'<name>'+'Haipeng Zhang'+'</name>'+'\n')
        fout.write('\t'+'</owner>'+'\n')
        
        fout.write('\t'+'<size>'+'\n')
        fout.write('\t\t'+'<width>'+str(img.size[0])+'</width>'+'\n')
        fout.write('\t\t'+'<height>'+str(img.size[1])+'</height>'+'\n')
        fout.write('\t\t'+'<depth>'+'3'+'</depth>'+'\n')
        fout.write('\t'+'</size>'+'\n')
        
        fout.write('\t'+'<segmented>'+'0'+'</segmented>'+'\n')

        for line in fin.readlines():
            line = line.split(',')
            fout.write('\t'+'<object>'+'\n')
            fout.write('\t\t'+'<name>'+class_name[int(line[5])]+'</name>'+'\n')
            fout.write('\t\t'+'<pose>'+'Unspecified'+'</pose>'+'\n')
            fout.write('\t\t'+'<truncated>'+line[6]+'</truncated>'+'\n')
            fout.write('\t\t'+'<difficult>'+str(int(line[7]))+'</difficult>'+'\n')
            fout.write('\t\t'+'<bndbox>'+'\n')
            fout.write('\t\t\t'+'<xmin>'+line[0]+'</xmin>'+'\n')
            fout.write('\t\t\t'+'<ymin>'+line[1]+'</ymin>'+'\n')
            # pay attention to this point!(0-based)
            fout.write('\t\t\t'+'<xmax>'+str(int(line[0])+int(line[2])-1)+'</xmax>'+'\n')
            fout.write('\t\t\t'+'<ymax>'+str(int(line[1])+int(line[3])-1)+'</ymax>'+'\n')
            fout.write('\t\t'+'</bndbox>'+'\n')
            fout.write('\t'+'</object>'+'\n')
             
        fin.close()
        fout.write('</annotation>')

XML->coco.json

    # coding=utf-8
import xml.etree.ElementTree as ET
import os
import json


voc_clses = ['aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor']


categories = []
for iind, cat in enumerate(voc_clses):
    cate = {
    }
    cate['supercategory'] = cat
    cate['name'] = cat
    cate['id'] = iind
    categories.append(cate)

def getimages(xmlname, id):
    sig_xml_box = []
    tree = ET.parse(xmlname)
    root = tree.getroot()
    images = {
    }
    for i in root:  # 遍历一级节点
        if i.tag == 'filename':
            file_name = i.text  # 0001.jpg
            # print('image name: ', file_name)
            images['file_name'] = file_name
        if i.tag == 'size':
            for j in i:
                if j.tag == 'width':
                    width = j.text
                    images['width'] = width
                if j.tag == 'height':
                    height = j.text
                    images['height'] = height
        if i.tag == 'object':
            for j in i:
                if j.tag == 'name':
                    cls_name = j.text
                cat_id = voc_clses.index(cls_name) + 1
                if j.tag == 'bndbox':
                    bbox = []
                    xmin = 0
                    ymin = 0
                    xmax = 0
                    ymax = 0
                    for r in j:
                        if r.tag == 'xmin':
                            xmin = eval(r.text)
                        if r.tag == 'ymin':
                            ymin = eval(r.text)
                        if r.tag == 'xmax':
                            xmax = eval(r.text)
                        if r.tag == 'ymax':
                            ymax = eval(r.text)
                    bbox.append(xmin)
                    bbox.append(ymin)
                    bbox.append(xmax - xmin)
                    bbox.append(ymax - ymin)
                    bbox.append(id)   # 保存当前box对应的image_id
                    bbox.append(cat_id)
                    # anno area
                    bbox.append((xmax - xmin) * (ymax - ymin) - 10.0)   # bbox的ares
                    # coco中的ares数值是 < w*h 的, 因为它其实是按segmentation的面积算的,所以我-10.0一下...
                    sig_xml_box.append(bbox)
                    # print('bbox', xmin, ymin, xmax - xmin, ymax - ymin, 'id', id, 'cls_id', cat_id)
    images['id'] = id
    # print ('sig_img_box', sig_xml_box)
    return images, sig_xml_box



def txt2list(txtfile):
    f = open(txtfile)
    l = []
    for line in f:
        l.append(line[:-1])
    return l


# voc2007xmls = 'anns'
voc2007xmls = '/data2/chenjia/data/VOCdevkit/VOC2007/Annotations'
# test_txt = 'voc2007/test.txt'
test_txt = '/data2/chenjia/data/VOCdevkit/VOC2007/ImageSets/Main/test.txt'
xml_names = txt2list(test_txt)
xmls = []
bboxes = []
ann_js = {
    }
for ind, xml_name in enumerate(xml_names):
    xmls.append(os.path.join(voc2007xmls, xml_name + '.xml'))
json_name = 'annotations/instances_voc2007val.json'
images = []
for i_index, xml_file in enumerate(xmls):
    image, sig_xml_bbox = getimages(xml_file, i_index)
    images.append(image)
    bboxes.extend(sig_xml_bbox)
ann_js['images'] = images
ann_js['categories'] = categories
annotations = []
for box_ind, box in enumerate(bboxes):
    anno = {
    }
    anno['image_id'] =  box[-3]
    anno['category_id'] = box[-2]
    anno['bbox'] = box[:-3]
    anno['id'] = box_ind
    anno['area'] = box[-1]
    anno['iscrowd'] = 0
    annotations.append(anno)
ann_js['annotations'] = annotations
json.dump(ann_js, open(json_name, 'w'), indent=4)  # indent=4 更加美观显示               

将生成的json及图片按照一下结构放置,注意修改json文件名称:

  • dadasets
    • visdrone2019
      • train2019
      • val2019
      • annotations
        • instances_train2019.json
        • instances_val2019.json

修改projects下coco.yml内容,按照自己的数据库情况修改

project_name: visdrone2019  # also the folder name of the dataset that under data_path folder
train_set: train2019
val_set: val2019
num_gpus: 1

# mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
mean: [0.373, 0.378, 0.364]
std: [0.191, 0.182, 0.194]

# this is coco anchors, change it if necessary
anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'

# must match your dataset's category_id.
# category_id is one_indexed,
# for example, index of 'car' here is 2, while category_id of is 3
obj_list: ["pedestrian","people","bicycle","car","van","truck","tricycle","awning-tricycle","bus","motor"]

训练

python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10
–load_weights /path/to/your/weights/efficientdet-d2.pth

提前下载model文件,放置在文件夹中,建议d0,d1,d2(大了显存会溢出),如出现显存溢出情况,调整batch_size大小。
在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/gu891221/article/details/105783719

智能推荐

ClickHouse 存储原理初窥-程序员宅基地

文章浏览阅读750次。更多内容关注微信公众号:fullstack888背景目前业务中有大量实时分析需求,随着数据量的增加,基于行存储的 OLTP 数据库已经不能满足性能的需求,我们对 ClickHouse 进行了基础调研与性能摸底,并最终决定引入 ClickHouse 作为新系统的 OLAP 方案。简介ClickHouse 是一个列式存储数据库管理系统(DBMS)。相比于其他传统行式数据库系统..._clickhouse列存储原理

MySQL模糊查询-程序员宅基地

文章浏览阅读544次,点赞3次,收藏7次。%通配符可以匹配任意字符,但是不能匹配NULL,也就是说SELECT * FROM blog where title_name like '%';--模糊匹配含有“xxx网xxx车xxx”的数据,如“滴滴网约车司机端,网络约车平台”2.“_”下划线通配符:表示只能匹配单个字符,不能多也不能少,就是一个字符。--查询前三个字符为xx网,后面任意匹配,如:“城通网盘、模具网平台”--模糊匹配含义“xx网x车xxx”的数据,如:“携程网约车客户端”--查询以“网”为结尾的,长度为三个字的数据,如:“链家网”

OpenWrt介绍及编译基础教程_openwrt编译-程序员宅基地

文章浏览阅读1.2w次,点赞11次,收藏96次。编译 Open­Wrt 的过程就像是复读机,除了选择系统组件外,几乎每次编译都是复制粘贴相同的命令。而理解每一条命令的作用、什么时候该去执行,这样才能更好的去解决编译中遇到的问题,更顺利的编译出固件。_openwrt编译

springcloud排除错误记录,Dependency org.springframework.cloud:spring-cloud-starter-netflix-eureka-cli_dependency 'org.springframework.cloud:spring-cloud-程序员宅基地

文章浏览阅读3.2k次,点赞4次,收藏3次。2.执行以下命令,删除未下载成功的依赖,即是存在.lastUpdated后缀的文件(3.使用maven的Reload Project刷新,重新下载删除的依赖。1.先到maven仓库所在的目录并且在该目录打开cmd命令;网络不好情况下,导致产生许多依赖未成功下载。_dependency 'org.springframework.cloud:spring-cloud-starter-netflix-eureka-cl

工具说明书 - 英语翻译软件对比和英语词典选择_百度网页英语词典插件安装-程序员宅基地

文章浏览阅读1.5k次。1,百度翻译桌面版:200+语言。界面简洁清爽,操作方便。可以选择领域更好的翻译术语。翻译结果可以语音播放。APP版功能更强大,拍照翻译、语音翻译、权威词典。网页版,功能和桌面版类似,里面还有个视频翻译功能。在官方介绍页面,还有个百度同传工具。2,google翻译有网页版,带了语音翻译和朗读。手机上可以下载APP。PC端还可以下载客户端:Client for Google Translatehttp://translateclient.com3,网易有道翻译.._百度网页英语词典插件安装

Ubuntu16双网卡上网设置_ubuntu 多网卡 设置默认路由-程序员宅基地

文章浏览阅读2.6k次,点赞2次,收藏9次。文章目录前言route方法一. 删网关方法二. 增大Metric注意微信公众号前言现在有两个网络:公司的局域网192.168.3.x, 通过有线连接, 可以访问公司内网.无线连接的192.168.1.x, 可以访问外网.想要无线主力, 只有访问内网东西时采用有线, 常见做法是用到内网插有线(会自动顶掉无线), 否则拔掉网线用无线. 有没有简单点的办法?route先连上无线, 再插..._ubuntu 多网卡 设置默认路由

随便推点

【再见,2020】 走向开源,拥抱RT-Thread_rt-thread现状-程序员宅基地

文章浏览阅读296次。过往云烟话说进入RT-Thread,是在2019年12月份。因为自己的十几年的瞎折腾,在嵌入式开发中,并没有大放异彩。人生的跌宕起伏,也验证了普通技术工作者的心路历程。 偶尔静下心来,发现,自己掌握的技术,还不够全面,不够深入,不够新颖,更重要的,是没有一家公司可以让你全面的享受技术开发,可以做很多精彩的事情。 我不是科学家,不是技术专家,非名校毕业,非海归,我只是普通人。但是,我们在求知、求职、工作、生活、家庭、社会等等场合,都折腾着,努力的活着,每天都期待着。让我们的梦想,对得起自己过往的、现在_rt-thread现状

20暨南大学计算机考研经验知乎,暨大应统经验转自知乎-程序员宅基地

文章浏览阅读554次。1. 准备考研脑中要有一个思想:“书贵精不贵多,一本书读三遍胜过三本书读一遍。”某些答案推荐了一大堆的书,看见那么多的书你怕不怕?直接吓的放弃考研。就算不怕,你那些书打算做几遍?每本书都做三遍那要花费多少时间?2. 关于对我\"考研从六月份开始准备就足够了\"言论的误读。考研以大三结束放暑假那个时间为节点(也就是六、七月份),前后分为兼职考研和全职考研。什么是兼职考研?就是指你平时该上课就好好上课..._暨南大学的计算机应用专业怎么样知乎

Python爬虫详解(一看就懂)-程序员宅基地

文章浏览阅读9.1w次,点赞249次,收藏1.7k次。爬虫简单的来说就是用程序获取网络上数据这个过程的一种名称。如果要获取网络上数据,我们要给爬虫一个网址(程序中通常叫URL),爬虫发送一个HTTP请求给目标网页的服务器,服务器返回数据给客户端(也就是我们的爬虫),爬虫再进行数据解析、保存等一系列操作。爬虫可以节省我们的时间,比如我要获取豆瓣电影 Top250 榜单,如果不用爬虫,我们要先在浏览器上输入豆瓣电影的 URL ,客户端(浏览器)通过解析查到豆瓣电影网页的服务器的 IP 地址,然后与它建立连接,浏览器再创造一个 HTTP 请求发送给豆瓣电影的服务器,_python爬虫

Demo project for Spring Boot 【1】spring-boot-starter【2】spring-boot-maven-plugin_win7 demo project for spring boot版本-程序员宅基地

文章浏览阅读474次。spring-boot-maven-plugin是一个Maven插件,用于简化Spring Boot应用程序的构建和部署过程。它提供了许多有用的功能,包括:打包可执行的JAR文件:该插件可以将Spring Boot应用程序打包为可执行的JAR文件,其中包含了所有的依赖和资源文件。自动重新加载:在开发过程中,该插件可以监视应用程序的源代码和资源文件的变化,并自动重新加载应用程序,以便快速查看修改的效果。自定义属性:通过该插件,可以方便地在构建过程中设置自定义的属性,例如应用程序的版本号、环境配置等。运行应用程_win7 demo project for spring boot版本

C语言笔记_学校进行长跑训练,规定学生第一天训练300米,第二天训练337.5米,第三天训练379.688-程序员宅基地

文章浏览阅读783次,点赞3次,收藏6次。C语言笔记程序与算法算法设计原则C程序的编写、编译和运行C程序的构成C程序编写、编译、链接程序与算法算法设计原则分而治之:复杂的程序可以分解成若干简单子程序模块化:常用程序模块是可以重复使用的C程序的编写、编译和运行C程序的构成预处理命令函数语句单词注释 //注意书写规范C程序编写、编译、链接c源程序的编写,扩展名为.c的源文件编译器编译,扩展名为.obj的目标程序链接器链接,扩展名.exe的可执行程序..._学校进行长跑训练,规定学生第一天训练300米,第二天训练337.5米,第三天训练379.688

论文浅尝 | 知识赋能的信息系统专题前言(软件学报 2023 年第 10 期)-程序员宅基地

文章浏览阅读517次。高 宏1, 陈华钧2, 赵 翔3, 李瑞轩41(浙江师范大学 计算机科学与技术学院, 浙江 金华 321019)2(浙江大学 计算机科学与技术学院, 浙江 杭州 310027)3(国防科技大学 大数据与决策实验室, 湖南 长沙 410073)4(华中科技大学 计算机科学与技术学院, 湖北 武汉 430074)通信作者: 高宏, E-mail: [email protected]; 陈..._高宏 浙江师范大学