Traffic Sign Recognition (Neural Networks and Deep Learning)

Introduction

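This post covers another project from the Udacity Self-Driving Car Nanodegree: traffic sign recognition. The implementation is based on a convolutional neural network (CNN), and the network architecture follows the LeNet structure proposed by LeCun. Reference: LeCun Paper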
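To make the LeNet reference concrete, the sketch below traces feature-map sizes through a classic LeNet-5 style stack for a 32x32 input. The layer sizes (5x5 convolutions with 6 and 16 filters, 2x2 pooling, dense layers of 120 and 84 units) are the textbook LeNet-5 values and are an assumption here; the network used in this project may differ. The names `conv_out` and `lenet_shapes` are hypothetical helpers.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet_shapes(input_size=32, num_classes=43):
    """Trace output shapes through a textbook LeNet-5 style stack (assumed sizes)."""
    shapes = []
    s = conv_out(input_size, kernel=5)       # conv1: 5x5 kernel, 6 filters, no padding
    shapes.append(("conv1", (s, s, 6)))
    s = conv_out(s, kernel=2, stride=2)      # pool1: 2x2 window, stride 2
    shapes.append(("pool1", (s, s, 6)))
    s = conv_out(s, kernel=5)                # conv2: 5x5 kernel, 16 filters
    shapes.append(("conv2", (s, s, 16)))
    s = conv_out(s, kernel=2, stride=2)      # pool2: 2x2 window, stride 2
    shapes.append(("pool2", (s, s, 16)))
    shapes.append(("flatten", (s * s * 16,)))
    shapes.append(("fc1", (120,)))
    shapes.append(("fc2", (84,)))
    shapes.append(("logits", (num_classes,)))
    return shapes


for name, shape in lenet_shapes():
    print(name, shape)
```

The final layer has 43 outputs because the dataset contains 43 sign classes.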

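Project Flowchart

The implementation pipeline of this project is shown below: [Figure 1: project flowchart]

Code Implementation and Explanation

Next we implement the pipeline block by block, following the flowchart. Dataset for this project: German data. If that link does not open, use the backup link: http://benchmark.ini.rub.de/?section=gtsrb&subsection=dataset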

#import important packages/libraries
import numpy as np
import tensorflow as tf
import pickle
import matplotlib.pyplot as plt
import random
import csv
from sklearn.utils import shuffle
from tensorflow.contrib.layers import flatten
from skimage import transform as transf
from sklearn.model_selection import train_test_split
import cv2
from prettytable import PrettyTable
%matplotlib inline
SEED = 2018

# Load the data (visualized in the next section)
training_file = 'data/train.p'
testing_file = 'data/test.p'

with open(training_file,mode='rb') as f:
    train = pickle.load(f)
with open(testing_file,mode='rb') as f:
    test = pickle.load(f)

X_train,y_train = train['features'],train['labels']
X_test,y_test = test['features'],test['labels']
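The loading code above assumes each pickle file holds a dict with a 'features' array and a 'labels' array. As a minimal sketch of that layout, the example below writes a tiny synthetic stand-in to a temporary file and reads it back. The helper `load_split` and the demo data are hypothetical and not part of the original project.

```python
import os
import pickle
import tempfile

import numpy as np


def load_split(path):
    """Load one pickled split and return (features, labels) after a basic check."""
    with open(path, mode="rb") as f:
        data = pickle.load(f)
    X, y = data["features"], data["labels"]
    if len(X) != len(y):
        raise ValueError("features and labels differ in length")
    return X, y


# Tiny synthetic stand-in for data/train.p (hypothetical content).
demo = {
    "features": np.zeros((4, 32, 32, 3), dtype=np.uint8),
    "labels": np.array([0, 1, 1, 2]),
}
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.p")
    with open(demo_path, "wb") as f:
        pickle.dump(demo, f)
    X_demo, y_demo = load_split(demo_path)

print(X_demo.shape, y_demo.shape)
```

Checking the lengths right after loading catches a mismatched or truncated file before any training starts.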

Dataset Summary and Exploration

Below we summarize and visualize the German traffic sign dataset.

n_train = len(X_train)
n_test = len(X_test)

_,IMG_HEIGHT,IMG_WIDTH,IMG_DEPTH = X_train.shape
image_shape = (IMG_HEIGHT,IMG_WIDTH,IMG_DEPTH)

with open('data/signnames.csv','r') as sign_name:
    reader = csv.reader(sign_name)
    sign_names = list(reader)

# skip the CSV header row
sign_names = sign_names[1::]
NUM_CLASSES = len(sign_names)
print('Total number of classes:{}'.format(NUM_CLASSES))

n_classes = len(np.unique(y_train))
assert (NUM_CLASSES== n_classes) ,'1 or more class(es) not represented in training set'

print('Number of training examples =',n_train)
print('Number of testing examples =',n_test)
print('Image data shape=',image_shape)
print('Number of classes =',n_classes)
Total number of classes:43
Number of training examples = 34799
Number of testing examples = 12630
Image data shape= (32, 32, 3)
Number of classes = 43

# data visualization: show 20 random images
def visualize_random_images(list_imgs,X_dataset,y_dataset):
    # list_imgs: list of 20 image indices
    _,ax = plt.subplots(len(list_imgs)//5,5,figsize=(20,10))
    row,col = 0,0
    for idx in list_imgs:
        img = X_dataset[idx]
        ax[row,col].imshow(img)
        ax[row,col].annotate(str(int(y_dataset[idx])),xy=(2,5),color='red',fontsize=20)
        ax[row,col].axis('off')
        col+=1
        if col==5:
            row,col = row+1,0
    plt.show()
# random.randint includes both endpoints, so the upper bound must be len(y_train)-1
ls = [random.randint(0,len(y_train)-1) for i in range(20)]
visualize_random_images(ls,X_train,y_train)

[Figure 2: 20 random training images annotated with their class labels]

def get_count_imgs_per_class(y, verbose=False):
    num_classes = len(np.unique(y))
    count_imgs_per_class = np.zeros( num_classes )

    for this_class in range( num_classes ):
        if verbose: 
            print('class {} | count {}'.format(this_class, np.sum( y  == this_class )) )
        count_imgs_per_class[this_class] = np.sum(y == this_class )
    return count_imgs_per_class
class_freq = get_count_imgs_per_class(y_train)
print('------- ')
print('Highest count: {} (class {})'.format(np.max(class_freq), np.argmax(class_freq)))
print('Lowest count: {} (class {})'.format(np.min(class_freq), np.argmin(class_freq)))
print('------- ')
plt.bar(np.arange(NUM_CLASSES), class_freq , align='center')
plt.xlabel('class')
plt.ylabel('Frequency')
plt.xlim([-1, 43])
plt.title("class frequency in Training set")
plt.show()
sign_name_table = PrettyTable()
sign_name_table.field_names = ['class value', 'Name of Traffic sign']
for i in range(len(sign_names)):
    sign_name_table.add_row([sign_names[i][0], sign_names[i][1]] )
    
print(sign_name_table)
------- 
Highest count: 2010.0 (class 2)
Lowest count: 180.0 (class 0)
------- 

[Figure 3: bar chart of class frequency in the training set]

+-------------+----------------------------------------------------+
| class value |                Name of Traffic sign                |
+-------------+----------------------------------------------------+
|      0      |                Speed limit (20km/h)                |
|      1      |                Speed limit (30km/h)                |
|      2      |                Speed limit (50km/h)                |
|      3      |                Speed limit (60km/h)                |
|      4      |                Speed limit (70km/h)                |
|      5      |                Speed limit (80km/h)                |
|      6      |            End of speed limit (80km/h)             |
|      7      |               Speed limit (100km/h)                |
|      8      |               Speed limit (120km/h)                |
|      9      |                     No passing                     |
|      10     |    No passing for vechiles over 3.5 metric tons    |
|      11     |       Right-of-way at the next intersection        |
|      12     |                   Priority road                    |
|      13     |                       Yield                        |
|      14     |                        Stop                        |
|      15     |                    No vechiles                     |
|      16     |      Vechiles over 3.5 metric tons prohibited      |
|      17     |                      No entry                      |
|      18     |                  General caution                   |
|      19     |            Dangerous curve to the left             |
|      20     |            Dangerous curve to the right            |
|      21     |                    Double curve                    |
|      22     |                     Bumpy road                     |
|      23     |                   Slippery road                    |
|      24     |             Road narrows on the right              |
|      25     |                     Road work                      |
|      26     |                  Traffic signals                   |
|      27     |                    Pedestrians                     |
|      28     |                 Children crossing                  |
|      29     |                 Bicycles crossing                  |
|      30     |                 Beware of ice/snow                 |
|      31     |               Wild animals crossing                |
|      32     |        End of all speed and passing limits         |
|      33     |                  Turn right ahead                  |
|      34     |                  Turn left ahead                   |
|      35     |                     Ahead only                     |
|      36     |                Go straight or right                |
|      37     |                Go straight or left                 |
|      38     |                     Keep right                     |
|      39     |                     Keep left                      |
|      40     |                Roundabout mandatory                |
|      41     |                 End of no passing                  |
|      42     | End of no passing by vechiles over 3.5 metric tons |
+-------------+----------------------------------------------------+
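The per-class counts computed by the loop in get_count_imgs_per_class can also be obtained in one call with np.bincount. The sketch below is an alternative, not the author's code; `class_counts` and the toy labels are hypothetical. Passing `minlength` keeps a slot for classes that never occur, which the `np.unique`-based loop would silently drop.

```python
import numpy as np


def class_counts(labels, num_classes):
    """Count how many samples carry each integer label in [0, num_classes)."""
    return np.bincount(np.asarray(labels), minlength=num_classes)


# Toy labels (hypothetical): class 3 never appears.
toy_labels = np.array([0, 2, 2, 1, 2, 0])
counts = class_counts(toy_labels, num_classes=4)

print(counts)
print("most frequent class:", int(np.argmax(counts)))
print("least frequent class:", int(np.argmin(counts)))
```

The result is an integer array, so the highest and lowest counts print without the trailing `.0` seen in the output above.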

def histograms_randImgs(label,channel,n_imgs=5,ylim=50):
    '''
    Histogram (pixel intensity distribution) for a selection of images with the same label.
    For better visualization, the selected channel is shown in grayscale.
    label - the label of the images
    channel - channel used to compute the histogram (0, 1 or 2)
    n_imgs - number of images to show (default=5)
    ylim - upper limit of the y axis of the histogram plot (default=50)
    '''
    assert channel < 3,'images are RGB, choose a channel value in the range [0,2]'
    assert (np.sum(y_train==label))>=n_imgs,'reduce your number of images'
    
    all_imgs = np.ravel(np.argwhere(y_train==label))
    
    # randomly choose n_imgs images of this class
    ls_idx = np.random.choice(all_imgs,size=n_imgs,replace=False)
    _,ax = plt.subplots(n_imgs,2,figsize=(10,10))
    print('Histogram of selected images from the class{} ......'.format(label))
    for row,idx in enumerate(ls_idx):
        img = X_train[idx,:,:,channel]
        ax[row,0].imshow(img,cmap='gray')
        ax[row,0].axis('off')
        
        # plot the histogram directly from the pixel values;
        # passing the (counts, edges) tuple from np.histogram to hist() would
        # histogram the counts and bin edges instead of the intensities
        ax[row,1].hist(img.ravel(),bins=256,range=(0,256))
        ax[row,1].set_xlim([0,255])
        ax[row,1].set_ylim([0,ylim])
    plt.show()
histograms_randImgs(38,1)
Histogram of selected images from the class38 ......

[Figure 4: five class-38 images (single channel, shown in grayscale) with their pixel intensity histograms]
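For readers who want the histogram values rather than a plot, the sketch below computes the 256 intensity counts of one channel with np.histogram. The helper `channel_histogram` and the toy image are hypothetical. With `bins=256` and `range=(0, 256)` each bin covers exactly one integer intensity.

```python
import numpy as np


def channel_histogram(img, channel):
    """Return 256 intensity counts for one channel of a uint8 image."""
    counts, _ = np.histogram(img[:, :, channel], bins=256, range=(0, 256))
    return counts


# Toy 32x32 RGB image (hypothetical): upper half of the green channel is bright.
toy_img = np.zeros((32, 32, 3), dtype=np.uint8)
toy_img[:16, :, 1] = 200

hist = channel_histogram(toy_img, channel=1)
print("pixels with intensity 0:", hist[0])
print("pixels with intensity 200:", hist[200])
```

Note that np.histogram returns a tuple of counts and bin edges, so only the first element holds the distribution.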

Further Data Processing

We carry out the following steps:

  • Data augmentation

  • Conversion from RGB to grayscale

  • Data scaling
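The grayscale conversion and scaling steps listed above can be sketched as follows. The post does not show the exact formulas it uses, so both choices here are assumptions: the common luminance weights (0.299, 0.587, 0.114) for grayscale, and (x - 128) / 128 for scaling. The function names `to_grayscale` and `scale` are hypothetical.

```python
import numpy as np


def to_grayscale(images):
    """Convert (N, H, W, 3) RGB images to (N, H, W, 1) using luminance weights."""
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    gray = images.astype(np.float32) @ weights   # weighted sum over the channel axis
    return gray[..., np.newaxis]                 # keep a channel axis for the CNN


def scale(images):
    """Map pixel values from [0, 255] to roughly [-1, 1)."""
    return (images.astype(np.float32) - 128.0) / 128.0


# Two toy images (hypothetical): one black, one white.
batch = np.zeros((2, 32, 32, 3), dtype=np.uint8)
batch[1] = 255

processed = scale(to_grayscale(batch))
print(processed.shape, processed.min(), processed.max())
```

Keeping the trailing channel axis means the same convolution code works for one-channel and three-channel inputs.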

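Note: the dataset must be split into training and validation sets before data augmentation is applied, so that the validation set is not contaminated by synthetic images.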
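A minimal sketch of splitting before augmenting is shown below: it holds out a fixed fraction of every class as validation indices, so that augmentation can later be applied to the training indices only. The function `stratified_split`, the 20% validation fraction and the toy labels are assumptions for illustration. `train_test_split` from scikit-learn, which the imports above already include, can do the same job through its `stratify` argument.

```python
import numpy as np


def stratified_split(labels, val_fraction=0.2, seed=2018):
    """Return (train_idx, val_idx), holding out val_fraction of every class."""
    rng = np.random.RandomState(seed)
    labels = np.asarray(labels)
    train_idx, val_idx = [], []
    for cls in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_val = max(1, int(round(val_fraction * len(idx))))
        val_idx.extend(idx[:n_val])      # held out, never augmented
        train_idx.extend(idx[n_val:])    # only these are augmented later
    return np.array(train_idx), np.array(val_idx)


# Toy labels (hypothetical): 10 images of class 0 and 5 of class 1.
toy_labels = np.array([0] * 10 + [1] * 5)
train_idx, val_idx = stratified_split(toy_labels)
print(len(train_idx), len(val_idx))
```

Because the validation indices are fixed before any synthetic image is created, no transformed copy of a validation image can end up in the training set.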

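Data Augmentation in Detail

Data augmentation serves two purposes here: (1) it enlarges the training set, and (2) it rebalances the class distribution. The original class distribution is imbalanced, and the test set may follow a different distribution from the training set. We therefore train on a class-balanced dataset that gives every class the same weight, in the hope of getting better results when testing on an imbalanced dataset. After augmentation we have 4000 images per class. The augmentation method is to randomly pick images from the original dataset and apply an affine transformation to them:

  • The rotation angle is limited to [-10, 10] degrees. If the rotation is too large, the meaning of some traffic signs may change.

  • Horizontal and vertical shifts are limited to [-3, 3] px.

  • Scaling is limited to [0.8, 1.2].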
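The balancing arithmetic described above can be sketched as follows: count the images per class, work out how many synthetic images each class needs to reach the target, and pick source images at random for transformation. The helpers `augmentation_plan` and `pick_sources` and the toy labels are hypothetical. Sampling with replacement is an assumption made here, needed because small classes require more new images than they have originals.

```python
import numpy as np


def augmentation_plan(labels, target_per_class=4000):
    """Number of synthetic images each class needs to reach the target."""
    counts = np.bincount(np.asarray(labels))
    return np.maximum(target_per_class - counts, 0)


def pick_sources(labels, cls, n_new, rng):
    """Randomly pick (with replacement) original images of one class to transform."""
    candidates = np.flatnonzero(np.asarray(labels) == cls)
    return rng.choice(candidates, size=n_new, replace=True)


# Toy labels (hypothetical): three images of class 0, one of class 1.
toy_labels = np.array([0, 0, 0, 1])
plan = augmentation_plan(toy_labels, target_per_class=5)
print(plan)

rng = np.random.RandomState(2018)
sources = pick_sources(toy_labels, cls=1, n_new=int(plan[1]), rng=rng)
print(sources)
```

With the real training set, the smallest class (180 images) would need 3820 synthetic images to reach 4000, while the largest (2010 images) would need 1990.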

def random_transform(img,angle_range=[-10,10],
                    scale_range=[0.8,1.2],
                    translation_range=[-3,3]):
    '''
    Takes an image and applies a random affine transformation (rotation, scaling, translation).
    img: original image, shape (height, width, depth)
    angle_range: range of the rotation angle in degrees, e.g. [-10,10]
    scale_range: range of the scale factors, e.g. [0.8,1.2]
    translation_range: range of the shift in pixels, e.g. [-3,3]
    '''
    img_height,img_width,img_depth = img.shape
    # Generate random parameter values
    angle_value = np.random.uniform(low=angle_range[0],high=angle_range[1],size=None)
    scaleX = np.random.uniform(low=scale_range[0],high=scale_range[1],size=None)
    scaleY = np.random.uniform(low=scale_range[0],high=scale_range[1],size=None)
    translationX = np.random.randint(low=translation_range[0],high=translation_range[1]+1,size=None)
    translationY = np.random.randint(low=translation_range[0],high=translation_range[1]+1,size=None)
    
    # skimage transforms use (x, y) = (column, row) coordinates
    center_shift = np.array([img_width,img_height])/2. - 0.5
    transform_center = transf.SimilarityTransform(translation=-center_shift)
    transform_uncenter = transf.SimilarityTransform(translation=center_shift)
    
    transform_aug = transf.AffineTransform(rotation=np.deg2rad(angle_value),
                                          scale=(1/scaleY,1/scaleX),
                                          translation = (translationY,translationX))
    # Full transformation: move the image center to the origin, apply
    # rotation/scaling/translation, then move the center back
    full_transform = transform_center + transform_aug + transform_uncenter
    new_img = transf.warp(img,full_transform,preserve_range=True)
    
    return new_img.astype('uint8')
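The center, transform, uncenter composition used in random_transform can be checked with plain 3x3 homogeneous matrices. The sketch below is independent of skimage and uses hypothetical helper names. It shows that a rotation composed this way leaves the image center fixed, which is why the sign stays roughly in the middle of the frame after augmentation.

```python
import numpy as np


def translation(tx, ty):
    """Homogeneous matrix that shifts points by (tx, ty)."""
    return np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]], dtype=float)


def rotation(deg):
    """Homogeneous matrix that rotates points about the origin."""
    c, s = np.cos(np.deg2rad(deg)), np.sin(np.deg2rad(deg))
    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=float)


def rotate_about(center, deg):
    """Rotation about `center`: shift center to origin, rotate, shift back."""
    cx, cy = center
    return translation(cx, cy) @ rotation(deg) @ translation(-cx, -cy)


M = rotate_about((15.5, 15.5), 10)        # (15.5, 15.5) is the middle of a 32x32 grid
center_h = np.array([15.5, 15.5, 1.0])    # homogeneous coordinates (x, y, 1)
corner_h = np.array([0.0, 0.0, 1.0])

print(M @ center_h)   # the center maps to itself
print(M @ corner_h)   # a corner moves
```

Without the two centering translations, the rotation would pivot around pixel (0, 0) and push most of the sign out of the frame.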

def data_augmentation(X_dataset,y_dataset,augm_nbr,keep_dist=True):
    '''
    X_dataset:image dataset to augment
    y_dataset:label dataset
    keep_dist - True:keep class distributio
