深度学习(二)车牌识别(40个类别)

图片加标签——创建数据集——创建模型——开始训练
训练代码如下(保存为 train_car.py,后面的测试代码会从该文件导入 Net):

import cv2
import numpy as np
import os
import torch.nn as nn
import  torch.utils.data as Data
from torch.autograd import Variable
from torch.utils.data import Dataset,DataLoader,TensorDataset
import torch

# Label the images: each entry of train_path is a directory whose name is the
# integer class id of the images stored inside it.
train_path = 'train_car'
train_list = os.listdir(train_path)

labels_list = []
images_path = []
images_list = []
hight = 32
width = 32
for train in train_list:
	label = int(train)
	for image_name in os.listdir(train_path+'/'+train):
		path = train_path+'/'+train+'/'+image_name
		# np.fromfile + imdecode instead of cv2.imread so the raw bytes are read
		# by numpy. Decode as grayscale: the network takes one input channel and
		# the test script also loads grayscale images.
		raw = np.fromfile(path,dtype=np.uint8)
		image = cv2.imdecode(raw,cv2.IMREAD_GRAYSCALE) if raw.size else None
		if image is None:
			# Not a decodable image; skip it so images and labels stay aligned.
			print('skip unreadable image:',path)
			continue

		# cv2.resize expects dsize as (width, height).
		image = cv2.resize(image,(width,hight))
		images_path.append(path)
		images_list.append(image)
		labels_list.append(label)

#创建数据集
def data_loader(images_list,labels_list):
	"""Wrap images and labels in a shuffling DataLoader.

	Args:
		images_list: images convertible to one numpy array; converted to float32.
		labels_list: labels convertible to one numpy array, one per image.

	Returns:
		A DataLoader over a TensorDataset of (image, label) pairs with
		batch_size=100, shuffle=True and num_workers=2.
	"""
	image_tensor = torch.from_numpy(np.asarray(images_list).astype(np.float32))
	label_tensor = torch.from_numpy(np.asarray(labels_list))

	return DataLoader(
		TensorDataset(image_tensor,label_tensor),
		batch_size=100,
		shuffle=True,
		num_workers=2,
		)

#开始训练
def train_data(images_list,labels_list):
	"""Train a Net on the given images and save it to 'car_number.pkl'.

	Runs 200 epochs of SGD (lr=0.001) with cross-entropy loss and prints the
	loss of every batch. The whole model object is saved with torch.save, so
	the Net class must be importable wherever the file is loaded.

	Args:
		images_list: single-channel images of size hight x width (the
			module-level values), one per sample.
		labels_list: integer class ids, one per image.
	"""
	dataloader = data_loader(images_list,labels_list)
	net = Net(hight,width)
	optimizer = torch.optim.SGD(net.parameters(),lr=0.001)
	loss_function = torch.nn.CrossEntropyLoss()

	net.train()
	for epoch in range(200):
		for data_x,data_y in dataloader:
			# Reshape with the actual batch size: the last batch of an epoch is
			# smaller than 100 unless the dataset size is a multiple of 100.
			data_x = data_x.view(data_x.size(0),1,hight,width)
			prediction = net(data_x)
			loss = loss_function(prediction,data_y.long())
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()
			print(loss)
	# Errors are deliberately not swallowed: a failing batch should stop the
	# run instead of silently producing a badly trained model file.
	torch.save(net,'car_number.pkl')

#创建模型
class Net(nn.Module):
	"""Small CNN classifier for single-channel images.

	Four conv -> batch-norm -> ReLU stages (1->16->32->64->128 channels); the
	last three stages each end with a 2x2 max-pool. The flattened features go
	through a 1024 -> 256 -> num_classes fully connected head.

	Args:
		hight: input image height in pixels.
		width: input image width in pixels.
		num_classes: number of output logits (defaults to 40).
	"""

	def __init__(self,hight,width,num_classes=40):
		super(Net,self).__init__()
		self.body = self._conv_stage(1,16,pool=False)
		self.body1 = self._conv_stage(16,32,pool=True)
		self.body2 = self._conv_stage(32,64,pool=True)
		self.body3 = self._conv_stage(64,128,pool=True)

		# Three stride-2 pools floor each side to size//8, so compute the
		# flattened size per dimension; hight*width*128/64 is only right when
		# both sides are multiples of 8.
		flat_features = 128*(hight//8)*(width//8)
		self.tail = nn.Sequential(
			nn.Linear(flat_features,1024),
			nn.ReLU(True),
			nn.Linear(1024,256),
			nn.ReLU(True),
			nn.Linear(256,num_classes),
			)

	@staticmethod
	def _conv_stage(in_channels,out_channels,pool):
		"""Build a 3x3 conv + batch-norm + ReLU stage, optionally followed by a 2x2 max-pool."""
		layers = [
			nn.Conv2d(in_channels,out_channels,3,padding=1),
			nn.BatchNorm2d(out_channels),
			nn.ReLU(True),
			]
		if pool:
			layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
		return nn.Sequential(*layers)

	def forward(self,x):
		"""Return class logits of shape (batch, num_classes) for x of shape (batch, 1, hight, width)."""
		ret = self.body(x)
		ret = self.body1(ret)
		ret = self.body2(ret)
		ret = self.body3(ret)
		# Flatten everything except the batch dimension for the linear head.
		ret = ret.view(ret.size(0),-1)
		ret = self.tail(ret)
		return ret

def main():
	"""Entry point: train on the module-level images_list and labels_list."""
	train_data(images_list=images_list,labels_list=labels_list)

if __name__ == '__main__':
	main()

测试模型代码如下:

import numpy as np
import cv2
import torch
import os
from train_car import Net

def main():
	"""Classify every image in 'test_images/' with the saved model.

	Each image is loaded as grayscale, resized to 32x32, shown in a window for
	two seconds, and its predicted class index is printed.
	"""
	path = 'test_images/'
	image_list = os.listdir(path)
	# NOTE: torch.load unpickles the whole model object, so only load files you
	# trust. Recent PyTorch releases default to weights_only=True, which rejects
	# full-model pickles; pass weights_only=False there if loading fails.
	net = torch.load('car_number.pkl')
	# Inference mode: BatchNorm must use its running statistics rather than the
	# statistics of the single image being classified.
	net.eval()

	for img in image_list:
		image_path = path+img
		# Same loader as the training script (np.fromfile + imdecode), decoded
		# as grayscale; cv2.imread is reported to fail on non-ASCII paths on Windows.
		raw = np.fromfile(image_path,dtype=np.uint8)
		image = cv2.imdecode(raw,cv2.IMREAD_GRAYSCALE) if raw.size else None
		if image is None:
			print('skip unreadable image:',image_path)
			continue

		image = cv2.resize(image,(32,32))
		cv2.imshow('test',image)
		cv2.waitKey(2000)
		cv2.destroyAllWindows()

		h,w = image.shape
		batch = torch.from_numpy(image).float().view(1,1,h,w)
		with torch.no_grad():
			prediction = net(batch)

		# The class with the smallest cross-entropy loss is the one with the
		# largest logit, so argmax replaces the loop over all 40 classes.
		result = int(prediction.argmax(dim=1))
		print(result)

if __name__=='__main__':
	main()

你可能感兴趣的:(图像处理,车牌识别)