最近使用yolov4训练了一个车辆检测模型,但是数据量不够,一张一张标注又很费时费力,所以想通过已有的模型来检测新的图像来自动生成一些标注信息。在网上找了很久发现一个代码,具体出处不记得了。贴上代码和使用方法:
import argparse
import os
import glob
import random
import darknet
import time
import cv2
import numpy as np
import darknet
from tqdm import tqdm
def parser():
    """Build and parse command-line arguments for the YOLO auto-labeling tool.

    Returns:
        argparse.Namespace with the input source, batch size, weight/config/
        data paths, display and label-saving flags, and detection threshold.
    """
    parser = argparse.ArgumentParser(description="YOLO Object Detection Label Tools")
    # Fixed: the original concatenated help fragments without separating
    # spaces ("a" + "txt" -> "atxt") and ended in a dangling sentence.
    parser.add_argument("--input", type=str, default="",
                        help="image source. It can be a single image, a "
                             "txt with paths to them, or a folder. Image valid "
                             "formats are jpg, jpeg or png. "
                             "If no input is given, no images are processed.")
    parser.add_argument("--batch_size", default=1, type=int,
                        help="number of images to be processed at the same time")
    parser.add_argument("--weights", default="yolov4.weights",
                        help="yolo weights path")
    # NOTE: store_false means args.show defaults to True and passing --show
    # flips it to False; main() opens display windows only when args.show is
    # False, so the flag actually ENABLES display. Kept for compatibility.
    parser.add_argument("--show", action='store_false',
                        help="window inference display. For headless systems")
    parser.add_argument("--ext_output", action='store_true',
                        help="display bbox coordinates of detected objects")
    parser.add_argument("--save_labels", action='store_true',
                        help="save detections bbox for each image in voc format")
    parser.add_argument("--config_file", default="./cfg/yolov4-defect.cfg",
                        help="path to config file")
    parser.add_argument("--data_file", default="./cfg/defect.data",
                        help="path to data file")
    parser.add_argument("--thresh", type=float, default=.25,
                        help="remove detections with lower confidence")
    return parser.parse_args()
def check_arguments_errors(args):
    """Validate parsed CLI arguments.

    Raises:
        AssertionError: if the threshold is outside (0, 1).
        ValueError: if any referenced file path does not exist.
    """
    assert 0 < args.thresh < 1, "Threshold should be a float between zero and one (non-inclusive)"
    # Mandatory files, checked in the same order as before.
    required = (
        (args.config_file, "Invalid config path {}"),
        (args.weights, "Invalid weight path {}"),
        (args.data_file, "Invalid data file path {}"),
    )
    for path, template in required:
        if not os.path.exists(path):
            raise ValueError(template.format(os.path.abspath(path)))
    # --input is optional; only validated when non-empty.
    if args.input and not os.path.exists(args.input):
        raise ValueError("Invalid image path {}".format(os.path.abspath(args.input)))
def check_batch_shape(images, batch_size):
    """Ensure every image shares one shape and the batch is large enough.

    Returns the common (height, width, channels) shape of the images.
    Raises ValueError on mismatched shapes or an oversized image list.
    """
    unique_shapes = {img.shape for img in images}
    if len(unique_shapes) > 1:
        raise ValueError("Images don't have same shape")
    if len(images) > batch_size:
        raise ValueError("Batch size higher than number of images")
    return images[0].shape
def load_images(images_path):
    """Resolve the --input argument into a list of image paths.

    A direct image path is returned as a one-element list; a .txt file is
    read as one path per line; anything else is treated as a directory and
    scanned for jpg, png and jpeg files (in that order).
    """
    extension = images_path.split('.')[-1]
    if extension in ('jpg', 'jpeg', 'png'):
        return [images_path]
    if extension == "txt":
        with open(images_path, "r") as handle:
            return handle.read().splitlines()
    found = []
    for pattern in ("*.jpg", "*.png", "*.jpeg"):
        found += glob.glob(os.path.join(images_path, pattern))
    return found
def prepare_batch(images, network, channels=3):
    """Pack a list of BGR images into a single darknet IMAGE for batch inference.

    Each image is converted to RGB, resized to the network input, transposed
    to CHW and normalized to [0, 1] float32 before being handed to darknet
    as one contiguous buffer.
    """
    width = darknet.network_width(network)
    height = darknet.network_height(network)
    chw_frames = []
    for frame in images:
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, (width, height), interpolation=cv2.INTER_LINEAR)
        chw_frames.append(resized.transpose(2, 0, 1))
    stacked = np.concatenate(chw_frames, axis=0)
    # NOTE(review): the float buffer is only referenced through a raw ctypes
    # pointer after this returns — same lifetime hazard as the original code.
    flat = np.ascontiguousarray(stacked.flat, dtype=np.float32) / 255.0
    pointer = flat.ctypes.data_as(darknet.POINTER(darknet.c_float))
    return darknet.IMAGE(width, height, channels, pointer)
def image_detection(image_path, network, class_names, class_colors, thresh):
    """Run single-image YOLO inference.

    Returns a (drawn_image, detections) pair where drawn_image is the
    network-sized image with boxes rendered and detections is darknet's
    list of (label, confidence, bbox) tuples.
    """
    width = darknet.network_width(network)
    height = darknet.network_height(network)
    buffer = darknet.make_image(width, height, 3)
    bgr = cv2.imread(image_path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (width, height), interpolation=cv2.INTER_LINEAR)
    darknet.copy_image_from_bytes(buffer, resized.tobytes())
    detections = darknet.detect_image(network, class_names, buffer, thresh=thresh)
    darknet.free_image(buffer)
    drawn = darknet.draw_boxes(detections, resized, class_colors)
    # Channel swap kept from the original code (it mirrors upstream darknet).
    return cv2.cvtColor(drawn, cv2.COLOR_BGR2RGB), detections
def batch_detection(network, images, class_names, class_colors,
                    thresh=0.25, hier_thresh=.5, nms=.45, batch_size=4):
    """Run darknet batch inference over a list of same-shaped BGR images.

    Returns (images, batch_predictions): the input list with boxes drawn onto
    each image in place, and a per-image list of prediction tuples.
    The C-side detection array is allocated by network_predict_batch and must
    be freed exactly once after all per-image processing — keep the order of
    the calls below intact.
    """
    image_height, image_width, _ = check_batch_shape(images, batch_size)
    darknet_images = prepare_batch(images, network)
    # Raw C detections for the whole batch (freed at the end of this function).
    batch_detections = darknet.network_predict_batch(network, darknet_images, batch_size, image_width,
                                                     image_height, thresh, hier_thresh, None, 0, 0)
    batch_predictions = []
    for idx in range(batch_size):
        num = batch_detections[idx].num
        detections = batch_detections[idx].dets
        if nms:
            # In-place class-wise non-max suppression on the C detection array.
            darknet.do_nms_obj(detections, num, len(class_names), nms)
        # Convert surviving C detections into Python (label, conf, bbox) tuples.
        predictions = darknet.remove_negatives(detections, class_names, num)
        images[idx] = darknet.draw_boxes(predictions, images[idx], class_colors)
        batch_predictions.append(predictions)
    darknet.free_batch_detections(batch_detections, batch_size)
    return images, batch_predictions
def image_classification(image, network, class_names):
    """Classify one BGR image.

    Returns (class_name, score) pairs sorted by descending score.
    """
    width = darknet.network_width(network)
    height = darknet.network_height(network)
    resized = cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
                         (width, height), interpolation=cv2.INTER_LINEAR)
    buffer = darknet.make_image(width, height, 3)
    darknet.copy_image_from_bytes(buffer, resized.tobytes())
    scores = darknet.predict_image(network, buffer)
    ranked = [(name, scores[idx]) for idx, name in enumerate(class_names)]
    darknet.free_image(buffer)
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
def convert2relative(image, bbox):
    """Convert an absolute (x, y, w, h) bbox into YOLO's relative coordinates.

    Coordinates are divided by the image's width/height, so the result is
    independent of the image resolution.
    """
    height, width, _ = image.shape
    x, y, w, h = bbox
    return x / width, y / height, w / width, h / height
def save_annotations(name, image, detections, class_names):
    """Write detections for one image as a YOLO-format .txt next to the image.

    Each line is `<class_index> <x> <y> <w> <h> <confidence>` with the bbox
    expressed relative to the image dimensions.

    Args:
        name: path of the source image; its extension is swapped for ".txt".
        image: image array whose height/width normalize the boxes.
        detections: iterable of (label, confidence, bbox) from darknet.
        class_names: list mapping class index -> label string.
    """
    # os.path.splitext handles extra dots in directory or base names correctly;
    # the original name.split(".")[:-1][0] truncated such names. The stray
    # debug print(bbox) was also removed.
    file_name = os.path.splitext(name)[0] + ".txt"
    with open(file_name, "w") as f:
        for label, confidence, bbox in detections:
            x, y, w, h = convert2relative(image, bbox)
            label = class_names.index(label)
            f.write("{} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}\n".format(label, x, y, w, h, float(confidence)))
def writexml(fileDir, imgInfo, rectLists):
    """Write one Pascal VOC annotation XML file.

    The pasted source lost every XML tag from the f.write calls (only the
    indentation and interpolated values survived); this reconstructs the
    standard VOC layout around those surviving values.

    Args:
        fileDir: output .xml path.
        imgInfo: [abs_image_path, height, width, depth] from getImgInfo().
        rectLists: list of [label, xmin, ymin, xmax, ymax] pixel boxes.
    """
    with open(fileDir, "w", encoding='UTF-8') as f:
        f.write("<annotation>\n")
        f.write("    <folder>train_images</folder>\n")
        f.write("    <filename>{}</filename>\n".format(os.path.basename(imgInfo[0])))
        f.write("    <path>{}</path>\n".format(imgInfo[0]))
        f.write("    <source>\n")
        # "Unknow" (sic) kept: it is the value that survived in the paste.
        f.write("        <database>Unknow</database>\n")
        f.write("    </source>\n")
        f.write("    <size>\n")
        f.write("        <width>%d</width>\n" % imgInfo[2])
        f.write("        <height>%d</height>\n" % imgInfo[1])
        f.write("        <depth>{}</depth>\n".format(imgInfo[3]))
        f.write("    </size>\n")
        f.write("    <segmented>0</segmented>\n")
        for rectList in rectLists:
            f.write("    <object>\n")
            f.write("        <name>{}</name>\n".format(rectList[0]))
            # The paste shows three literal "0" values after the name; VOC's
            # object fields in order are pose/truncated/difficult.
            f.write("        <pose>0</pose>\n")
            f.write("        <truncated>0</truncated>\n")
            f.write("        <difficult>0</difficult>\n")
            f.write("        <bndbox>\n")
            f.write("            <xmin>%d</xmin>\n" % rectList[1])
            f.write("            <ymin>%d</ymin>\n" % rectList[2])
            f.write("            <xmax>%d</xmax>\n" % rectList[3])
            f.write("            <ymax>%d</ymax>\n" % rectList[4])
            f.write("        </bndbox>\n")
            f.write("    </object>\n")
        f.write("</annotation>\n")
def getImgInfo(path):
    """Return [absolute_path, height, width, depth] for the image at *path*."""
    height, width, depth = cv2.imread(path).shape
    return [os.path.abspath(path), height, width, depth]
def save_xml(name, image, detections, xmlDir, labels):
    """Convert darknet detections for one image into a Pascal VOC .xml file.

    Args:
        name: source image path; its basename (with .xml) names the output.
        image: the (network-sized) image the detections refer to — only used
            to turn pixel bboxes into relative coordinates, which are then
            rescaled to the original image size from getImgInfo().
        detections: iterable of (label, confidence, bbox) from darknet.
        xmlDir: destination directory for the xml file.
        labels: whitelist of label names to keep; an empty list keeps all.
    """
    # splitext keeps extra dots in the base name intact, unlike split('.')[0].
    base = os.path.splitext(os.path.basename(name))[0]
    file_name = os.path.join(xmlDir, base + '.xml')
    imgInfo = getImgInfo(name)
    im_height, im_width = imgInfo[1:-1]
    rectLists = []
    for label, confidence, bbox in detections:
        # Guard clause replaces the nested labels-filter of the original.
        if labels and label not in labels:
            continue
        x, y, w, h = convert2relative(image, bbox)
        # Clamp the box to [0, 1] before scaling back to original pixels.
        xmin = int(im_width * max(float(x) - float(w) / 2, 0))
        xmax = int(im_width * min(float(x) + float(w) / 2, 1))
        ymin = int(im_height * max(float(y) - float(h) / 2, 0))
        ymax = int(im_height * min(float(y) + float(h) / 2, 1))
        rectLists.append([label, xmin, ymin, xmax, ymax])
    if rectLists:
        writexml(file_name, imgInfo, rectLists)
    else:
        print("No objects were detected!\n")
def batch_detection_example():
    """Demo: run batch_detection over three bundled darknet sample images."""
    args = parser()
    check_arguments_errors(args)
    batch_size = 3
    random.seed(3)  # deterministic class colors
    network, class_names, class_colors = darknet.load_network(
        args.config_file,
        args.data_file,
        args.weights,
        batch_size=batch_size
    )
    image_names = ['data/horses.jpg', 'data/horses.jpg', 'data/eagle.jpg']
    loaded = [cv2.imread(path) for path in image_names]
    annotated, detections = batch_detection(network, loaded, class_names,
                                            class_colors, batch_size=batch_size)
    # Save the annotated copies into the working directory.
    for path, frame in zip(image_names, annotated):
        cv2.imwrite(path.replace("data/", ""), frame)
    print(detections)
def main(labels):
    """Auto-label images: run YOLO over --input, optionally saving VOC xml.

    Args:
        labels: whitelist of class names to keep in the xml output; an empty
            list keeps all classes.
    """
    args = parser()
    check_arguments_errors(args)
    random.seed(3)  # deterministic class colors
    network, class_names, class_colors = darknet.load_network(
        args.config_file,
        args.data_file,
        args.weights,
        batch_size=args.batch_size
    )
    images = load_images(args.input)
    if args.save_labels:
        xmlDir = input("Please input saved path of xml: ")
        os.makedirs(xmlDir, exist_ok=True)
    for image_name in tqdm(sorted(images)):
        try:
            image, detections = image_detection(
                image_name, network, class_names, class_colors, args.thresh
            )
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit still
            # work; unreadable or broken images are simply skipped.
            print(image_name + ' cannot read, skip!\n')
            continue
        if args.save_labels:
            save_xml(image_name, image, detections, xmlDir, labels)
        # --show is store_false: args.show is False only when the flag was
        # passed, and that is when the display window is opened.
        if not args.show:
            cv2.imshow('Inference', image)
            if cv2.waitKey() & 0xFF == ord('q'):
                break
if __name__ == "__main__":
    # Class names to keep in the generated xml files; edit for your model.
    labels = ['vehicle']
    main(labels)
将代码保存至darknet主目录下,文件名为darknet_label.py
需要修改的地方是 parser() 中的以下默认参数:将 default="./cfg/yolov4-defect.cfg" 和 default="./cfg/defect.data" 修改为自己的配置文件和数据文件路径,thresh 阈值根据自己的需要设置。
parser. add_argument( "--config_file" , default= "./cfg/yolov4-defect.cfg" ,
help = "path to config file" )
parser. add_argument( "--data_file" , default= "./cfg/defect.data" ,
help = "path to data file" )
parser. add_argument( "--thresh" , type = float , default= .25 ,
help = "remove detections with lower confidence" )
将文件末尾 if __name__ == "__main__": 下的 labels 修改为想要检测的标签名称
labels = [ 'vehicle' ]
最后在终端输入:
python darknet_label.py --input <待检测图像存储路径的txt文件> --weights <自己的权重文件路径及名称> --save_labels
运行后会提示输入想要保存 xml 文件的路径,输入即可。到这里就结束了,在你刚才输入的路径下就可以找到检测完成并标注好的 xml 文件。虽然检测会有一些错误和漏检,但是可以减轻不少标注工作量。