博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
gluoncv 目标检测,训练自己的数据集
阅读量:6466 次
发布时间:2019-06-23

本文共 4222 字,大约阅读时间需要 14 分钟。

https://gluon-cv.mxnet.io/build/examples_datasets/detection_custom.html

官方提供两种方案,一种是lst文件,一种是xml文件(voc的格式);

voc 格式的标注有标注工具,但是你如果是json文件标注的信息,或者其他格式的,你就要转成voc格式的。

于是就选择第一种数据格式lst序列文件格式,格式很简单。

根据你自己的json或者其他格式文件转换一下。

import jsonimport osimport cv2import numpy as npdef write_line(img_path, im_shape, boxes, ids, idx):    h, w, c = im_shape    # for header, we use minimal length 2, plus width and height    # with A: 4, B: 5, C: width, D: height    A = 4    B = 5    C = w    D = h    # concat id and bboxes    labels = np.hstack((ids.reshape(-1, 1), boxes)).astype('float')    # normalized bboxes (recommanded)    labels[:, (1, 3)] /= float(w)    labels[:, (2, 4)] /= float(h)    # flatten    labels = labels.flatten().tolist()    str_idx = [str(idx)]    str_header = [str(x) for x in [A, B, C, D]]    str_labels = [str(x) for x in labels]    str_path = [img_path]    line = '\t'.join(str_idx + str_header + str_labels + str_path) + '\n'    return linefiles = os.listdir('train_front')json_url = []cnt = 0for file in files:    tmp = os.listdir('train_front/'+file)    for js in tmp:        if js.endswith('json'):            json_url.append('train_front/'+file+'/'+js)            cnt+=1print(cnt)fwtrain = open("train.lst","w")fwval = open("val.lst","w")first_flag = []flag = Truecnt = 0cnt1 = 0cnt2 = 0for json_url_index in json_url:    file = open(json_url_index,'r')    for line in file:        js = json.loads(line)        if 'person' in js:            boxes = []            ids = []            for i in range(len(js['person'])):                if js['person'][i]['attrs']['ignore'] == 'yes' or js['person'][i]['attrs']['occlusion']== 'heavily_occluded' or js['person'][i]['attrs']['occlusion']== 'invisible':                    continue                bbox = js['person'][i]['data']                url = '/mnt/hdfs-data-4/data/jian.yin/'+json_url_index[:-5]+'/'+js['image_key']                width = js['width']                height = js['height']                boxes.append(bbox)                ids.append(0)                print(url)                print(bbox)            if len(boxes) > 0:                if flag:                    flag = False                    first_flag = boxes                ids = np.array(ids)                if cnt < 27853//2:                    line = write_line(url,(height,width,3),boxes,ids,cnt1)                    fwtrain.write(line)                    cnt1+=1                if cnt >= 27853//2:                    line = write_line(url, (height, width, 3), boxes, ids, cnt2)                    fwval.write(line)                    cnt2+=1                cnt += 1fwtrain.close()fwval.close()print(first_flag)

lst文件就转换好了。

 

然后添加自己的数据集:

https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L73

这里不能直接套用前面的导入数据的过程。

按照教程给出的方式添加。投机取巧的验证方式,直接引用前面的。

或者不验证:https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L393 部分注释掉。

elif dataset.lower() == 'pedestrian':        lst_dataset = LstDetection('train_val.lst',root=os.path.expanduser('.'))        print(len(lst_dataset))        first_img = lst_dataset[0][0]        print(first_img.shape)        print(lst_dataset[0][1])                train_dataset = LstDetection('train.lst',root=os.path.expanduser('.'))        val_dataset = LstDetection('val.lst',root=os.path.expanduser('.'))        classs = ('pedestrian',)        val_metric = VOC07MApMetric(iou_thresh=0.5,class_names=classs)

训练参数:

https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L73

添加自己的训练参数或者直接套用。

if args.dataset == 'voc' or args.dataset == 'pedestrian':        args.epochs = int(args.epochs) if args.epochs else 20        args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '14,20'        args.lr = float(args.lr) if args.lr else 0.001        args.lr_warmup = args.lr_warmup if args.lr_warmup else -1        args.wd = float(args.wd) if args.wd else 5e-4

model_zoo.py添加自己的数据集映射方案。这里如果是pip install gluoncv ,就要到site-package里面改。

https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_zoo.py#L32

'faster_rcnn_resnet50_v1b_pedestrian': faster_rcnn_resnet50_v1b_voc,

 

转载于:https://www.cnblogs.com/TreeDream/p/10174899.html

你可能感兴趣的文章
完整版:《开源框架实战宝典电子书V1.0.0》内测版下载地址!
查看>>
OCA读书笔记(3) - 使用DBCA创建Oracle数据库
查看>>
CKEditor的使用-编辑文本
查看>>
HDU------checksum
查看>>
puppet来管理文件和软件包
查看>>
Python基础进阶之路(一)之运算符和输入输出
查看>>
阻塞非阻塞异步同步 io的关系
查看>>
ClickStat业务
查看>>
DMA32映射问题
查看>>
POJ 1269 Intersecting Lines(判断两直线位置关系)
查看>>
MSSQL数据库跨表和跨数据库查询方法简(转)
查看>>
spring3.0.7中各个jar包的作用总结
查看>>
Windows 10 /win10 上使用GIT慢的问题,或者命令行反应慢的问题
查看>>
梯度下降(Gradient descent)
查看>>
Windows平台分布式架构实践 - 负载均衡
查看>>
Android快速开发常用知识点系列目录
查看>>
EJB2的配置
查看>>
最容易理解的对卷积(convolution)的解释
查看>>
《机器学习实战》知识点笔记目录
查看>>
完美解决NC502手工sql的查询引擎排序及合计问题
查看>>