Source code: grasp_multiObject_multiGrasp
The original author's environment was tensorflow-gpu with Python 2, while mine is tensorflow-cpu with Python 3, and I don't have MATLAB installed.
The first step is to modify a few configuration files so the code can run in a CPU-only environment; see https://blog.csdn.net/m0_38024766/article/details/90712715 for details.
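For reference, the repo is derived from tf-faster-rcnn, whose main hard GPU dependency is the compiled CUDA NMS kernel. The edits in that post amount to roughly the following sketch, based on the upstream tf-faster-rcnn sources (verify the exact lines against your checkout; you also need to comment out the CUDA extension in lib/setup.py before running make):

# lib/model/config.py: disable the CUDA NMS kernel
__C.USE_GPU_NMS = False

# lib/model/nms_wrapper.py: drop the gpu_nms import and the GPU branch,
# keeping only the CPU implementation
from nms.cpu_nms import cpu_nms

def nms(dets, thresh, force_cpu=False):
    """Dispatch to the CPU NMS implementation."""
    if dets.shape[0] == 0:
        return []
    return cpu_nms(dets, thresh)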
After downloading the Cornell Grasp Dataset, the first step is to split and preprocess it. Only the PNG images and the positive-label files (pcd*cpos.txt) are used. Since I have no MATLAB, I wrote a Python script to stand in for ./data/scripts/dataPreprocessingTest_fasterrcnn_split.m:
import random
import cv2
import numpy as np
import os
import math

# Cornell Grasp Dataset image indices: pcd0100..pcd0949 and pcd1000..pcd1034
list_idx = list(range(100, 950)) + list(range(1000, 1035))
random.shuffle(list_idx)
# hold out 171 images for testing, train on the rest
train_list = list_idx[171:]
test_list = list_idx[:171]
cropSize = 227
pi = math.pi
# Python replacement for the augmentation in dataPreprocessingTest_fasterrcnn_split.m:
# randomly rotate/translate each image and transform the grasp rectangles to match
def dataPreprocessing_fasterrcnn(imagein, bbsIn_all, cropSize, translationShiftNumber, rotateAngleNumber):
    # crop a 349x349 window around the image center, then pad so that
    # rotation does not push content outside the frame
    imgCrop = imagein[65:414, 145:494]
    imgpadding = cv2.copyMakeBorder(imgCrop, 75, 75, 75, 75, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    rows, cols, channel = imgpadding.shape
    m, n = np.shape(bbsIn_all)
    bbsnum = m // 4          # each grasp rectangle is 4 consecutive (x, y) rows
    count = 0
    imgout = {}
    bbsout = {}
    dev = 0
    for i_rotate in range(rotateAngleNumber * translationShiftNumber * translationShiftNumber):
        # random rotation about the image center plus a random translation
        theta = random.randint(0, 360)
        dx = random.randint(0, 101) - 51
        dy = random.randint(0, 101) - 51
        M = cv2.getRotationMatrix2D((cols / 2, rows / 2), theta, 1)
        imgrotate = cv2.warpAffine(imgpadding, M, (cols, rows))
        imgcroprotate = imgrotate[int(rows / 2 - 160 - dy):int(rows / 2 + 160 - dy),
                                  int(cols / 2 - 160 - dx):int(cols / 2 + 160 - dx)]
        imgCropRotateresize = cv2.resize(imgcroprotate, (cropSize, cropSize))
        bbs = np.zeros((m, n))
        lst = []
        for idx in range(bbsnum):
            bbsin = bbsIn_all[idx * 4:idx * 4 + 4]
            # some label files contain NaN vertices; mark those rows for deletion
            if np.isnan(bbsin).any():
                dev = 1
                lst.extend(range(idx * 4, idx * 4 + 4))
                continue
            # shift to image-center coordinates, rotate by theta,
            # then shift back into the crop and rescale to cropSize
            bbsinshift = bbsin - [320, 240]
            R = [[math.cos(theta / 180 * pi), -math.sin(theta / 180 * pi)],
                 [math.sin(theta / 180 * pi), math.cos(theta / 180 * pi)]]
            bbsRotated = np.matmul(bbsinshift, R)
            bbsInShiftBack = (bbsRotated + [160, 160] + [dx, dy]) * cropSize / 320
            bbs[idx * 4:idx * 4 + 4] = bbsInShiftBack
        imgout[count] = imgCropRotateresize
        if dev:
            bbs = np.delete(bbs, lst, axis=0)
            dev = 0
        bbsout[count] = bbs
        count += 1
    return imgout, bbsout
# input: raw Cornell images and positive grasp labels
imgDataDir = '/home/sjy/grasp_multiObject_multiGrasp/data/grasps/image/'
posDataDir = '/home/sjy/grasp_multiObject_multiGrasp/data/grasps/pos_label/'
# negDataDir = '/home/sjy/grasp_multiObject_multiGrasp/data/grasps/neg_label/'
# output: preprocessed images, annotations, and image-set lists
imgDataOutDir = '/home/sjy/grasp_data/data/Images/'
annotationDataOutDir = '/home/sjy/grasp_data/data/Annotations/'
imgSetTrain = '/home/sjy/grasp_data/data/ImageSets/train1.txt'   # unfiltered train list
imgSetTrain1 = '/home/sjy/grasp_data/data/ImageSets/train.txt'   # final train list
imgSetTest = '/home/sjy/grasp_data/data/ImageSets/test.txt'
imgFiles = [x.name for x in os.scandir(imgDataDir)]
for x in imgFiles:
    print(x)
    # filenames look like pcd0100r.png, so x[3:7] is the image index
    if int(x[3:7]) in test_list:
        # test image: a single center crop, no augmentation
        f = open(imgSetTest, "a")
        img = cv2.imread(imgDataDir + x)
        img = img[80:399, 160:479]
        img = cv2.resize(img, (227, 227))
        cv2.imwrite(imgDataOutDir + x[0:8] + '_preprocessed_0.png', img)
        f.write(x[0:8] + '_preprocessed_0\n')
        f.close()
        bbsIn_all = np.loadtxt(posDataDir + "pcd{}cpos.txt".format(x[3:7]))
        m, n = np.shape(bbsIn_all)
        bbsnum = m // 4
        bbsOut = np.zeros((m, n))
        lst = []
        dev = 0
        for idx in range(bbsnum):
            bbsin = bbsIn_all[idx * 4:idx * 4 + 4]
            if np.isnan(bbsin).any():
                dev = 1
                lst.extend(range(idx * 4, idx * 4 + 4))
                continue
            # same shift/rescale as in the augmentation, but without rotation
            bbsinshift = bbsin - [320, 240]
            bbsInShiftBack = (bbsinshift + [160, 160]) * cropSize / 320
            bbsOut[idx * 4:idx * 4 + 4] = bbsInShiftBack
        if dev:
            bbsOut = np.delete(bbsOut, lst, axis=0)
        file_writeID = open(annotationDataOutDir + x[0:8] + "_preprocessed_0.txt", "w")
        for j in range(len(bbsOut) // 4):
            points = bbsOut[4 * j:4 * j + 4]
            # the rectangle's edge lengths give the box width/height
            width = np.sqrt(sum(np.square(bbsOut[j * 4] - bbsOut[j * 4 + 1])))
            height = np.sqrt(sum(np.square(bbsOut[j * 4 + 1] - bbsOut[j * 4 + 2])))
            # grasp angle of the first edge, kept within (-90, 90) degrees
            if bbsOut[j * 4][0] > bbsOut[j * 4 + 1][0]:
                theta = math.atan((bbsOut[j * 4 + 1][1] - bbsOut[j * 4][1]) / (bbsOut[j * 4][0] - bbsOut[j * 4 + 1][0]))
            else:
                theta = math.atan((bbsOut[j * 4][1] - bbsOut[j * 4 + 1][1]) / (bbsOut[j * 4 + 1][0] - bbsOut[j * 4][0]))
            # axis-aligned box around the rectangle center
            x_crt, y_crt = sum(points) / 4
            x_min = x_crt - width / 2
            x_max = x_crt + width / 2
            y_min = y_crt - height / 2
            y_max = y_crt + height / 2
            if x_min < 0 or y_min < 0 or x_max > 227 or y_max > 227:
                continue
            # bin the angle into one of 19 orientation classes
            orient = round((theta / pi * 180 + 90) / 10) + 1
            file_writeID.write('{} {} {} {} {}\n'.format(orient, x_min, y_min, x_max, y_max))
        file_writeID.close()
        continue
    # training image: generate augmented crops and transformed labels
    img = cv2.imread(imgDataDir + x)
    bbsIn_all = np.loadtxt(posDataDir + "pcd{}cpos.txt".format(x[3:7]))
    imageout, bbsOut = dataPreprocessing_fasterrcnn(img, bbsIn_all, 227, 1, 1)
    for i in range(len(imageout)):
        file_writeID = open(annotationDataOutDir + x[0:8] + "_preprocessed_{}.txt".format(i), "w")
        for j in range(len(bbsOut[i]) // 4):
            points = bbsOut[i][4 * j:4 * j + 4]
            width = np.sqrt(sum(np.square(bbsOut[i][j * 4] - bbsOut[i][j * 4 + 1])))
            height = np.sqrt(sum(np.square(bbsOut[i][j * 4 + 1] - bbsOut[i][j * 4 + 2])))
            if bbsOut[i][j * 4][0] > bbsOut[i][j * 4 + 1][0]:
                theta = math.atan((bbsOut[i][j * 4 + 1][1] - bbsOut[i][j * 4][1]) / (bbsOut[i][j * 4][0] - bbsOut[i][j * 4 + 1][0]))
            else:
                theta = math.atan((bbsOut[i][j * 4][1] - bbsOut[i][j * 4 + 1][1]) / (bbsOut[i][j * 4 + 1][0] - bbsOut[i][j * 4][0]))
            x_crt, y_crt = sum(points) / 4
            x_min = x_crt - width / 2
            x_max = x_crt + width / 2
            y_min = y_crt - height / 2
            y_max = y_crt + height / 2
            orient = round((theta / pi * 180 + 90) / 10) + 1
            # skip boxes that fall outside the 227x227 crop
            if x_min < 0 or y_min < 0 or x_max > 227 or y_max > 227:
                continue
            file_writeID.write('{} {} {} {} {}\n'.format(orient, x_min, y_min, x_max, y_max))
        file_writeID.close()
        cv2.imwrite(imgDataOutDir + x[0:8] + "_preprocessed_{}.png".format(i), imageout[i])
        file_writeID = open(imgSetTrain, "a")
        file_writeID.write(x[0:8] + "_preprocessed_{}\n".format(i))
        file_writeID.close()
# some augmented crops end up with no valid boxes, which would break training;
# keep only samples whose annotation file is non-empty in the final train.txt
file_writeID = open(imgSetTrain1, "a")
for x in open(imgSetTrain):
    size = os.path.getsize(annotationDataOutDir + '{}.txt'.format(x[:-1]))
    if size != 0:
        file_writeID.write(x[:-1] + '\n')
file_writeID.close()
Adjust the input and output paths to suit your own setup.
Note that the generated data must follow this storage layout:
INRIA
|-- data
    |-- Annotations
        |-- *.txt (Annotation files)
    |-- Images
        |-- *.png (Image files)
    |-- ImageSets
        |-- train.txt
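For reference, each line in an annotation file has the form `orient x_min y_min x_max y_max`, where orient bins the grasp angle (which the script keeps within (-90, 90) degrees) into one of 19 classes. A quick sanity check of the binning formula used above, with an illustrative value:

# theta in (-90, 90) degrees maps to classes 1..19; e.g. a horizontal grasp:
theta_deg = 0.0
orient = round((theta_deg + 90) / 10) + 1
print(orient)  # -> 10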
This layout is relied on later. Once the data is ready, the main file to modify is lib/datasets/factory.py.
On line 43, point the devkit path at your own data directory:
graspRGB_devkit_path='/home/sjy/grasp_data'
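For orientation, the surrounding code registers each split under a dataset name; the sketch below is my recollection of the tf-faster-rcnn-style registration and may differ slightly in your checkout, but only the path assignment needs editing:

# lib/datasets/factory.py, around line 43 -- only the path changes
graspRGB_devkit_path = '/home/sjy/grasp_data'
for split in ['train', 'test']:
    name = 'graspRGB_{}'.format(split)
    __sets[name] = (lambda split=split: graspRGB(split, graspRGB_devkit_path))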
Then start training:
./experiments/scripts/train_faster_rcnn.sh 0 graspRGB res50
Here I ran into an error:
TypeError: bottleneck() argument after ** must be a mapping, not tuple
The fix: open python3.6/site-packages/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py and add the following line after line 203:
unit={'depth':unit[0],'depth_bottleneck':unit[1],'stride':unit[2]}
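For context: the error comes from stack_blocks_dense(), which expands each unit with ** when calling block.unit_fn. The grasp repo defines its ResNet blocks with (depth, depth_bottleneck, stride) tuples, while this TensorFlow version expects keyword dicts, so the tuple must be converted first. Roughly (my reconstruction of the surrounding loop, with the output_stride branch omitted; only the marked line is new):

# tensorflow/contrib/slim/python/slim/nets/resnet_utils.py, stack_blocks_dense()
for i, unit in enumerate(block.args):
    with variable_scope.variable_scope('unit_%d' % (i + 1), values=[net]):
        # added: convert the repo's (depth, depth_bottleneck, stride) tuple
        # into the mapping that the newer bottleneck() signature expects
        unit = {'depth': unit[0], 'depth_bottleneck': unit[1], 'stride': unit[2]}
        net = block.unit_fn(net, rate=1, **unit)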
After all that fiddling, training finally runs.
Note: to retrain from scratch, delete data/cache and output first.
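From the repository root that is just:

rm -rf data/cache output

Otherwise the cached roidb and the previous snapshots get picked up again.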