一、创建数据集,作者整理了57种类别的交通标志图片
二、训练模型,作者使用TensorFlow深度学习框架训练所需的模型文件
def run(self):
# 1. 加载数据集
train_dataset, validate_dataset, class_names = self.m_data_load(self.train_dir, 224, 224, 16)
self.ui.train_dataset = train_dataset
self.ui.validate_dataset = validate_dataset
# 2. 加载模型
model = self.model_load(class_num=len(class_names))
self.ui.model = model
self.signal.emit(str(len(class_names)), class_names)
# 模型加载
def model_load(IMG_SHAPE=(224, 224, 3), class_num=214):
base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base_model.trainable = False
model = tf.keras.models.Sequential([
tf.keras.layers.experimental.preprocessing.Rescaling(1. / 127.5, offset=-1, input_shape=(224, 224, 3)),
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(class_num, activation='softmax')
])
# 输出模型信息
model.summary()
model.compile(optimizer='adam', loss='categorical_crossentropy',
metrics=['accuracy'])
return model
三、将训练好的模型文件导入Android工程的assets文件夹下。
四、编写Android代码
public class MainActivity extends Activity {
// 类别的数量
private int number = 57;
// 类别名称
private String class_names[] = {
"限速15", "限速30", "限速40", "限速50", "限速60", "限速70", "限速80", "禁止直行和左转", "禁止直行和右转", "禁止直行", "禁止左转",
"禁止左转和右转", "禁止右转", "禁止超车", "禁止掉头", "禁止机动车通行", "禁止鸣喇叭", "解除限速40", "解除限速50",
"直行和右转", "直行", "左转", "左转和右转", "右转", "靠左侧道路行驶", "靠右侧道路行驶", "环岛行驶", "机动车行驶", "鸣喇叭", "非机动车行驶 ", "掉头", "注意避让",
"注意红绿灯", "注意危险", "注意行人", "注意非机动车", "注意儿童", "注意急右转弯", "注意急左转弯", "注意下坡", "注意上坡",
"注意慢行", "T型交叉", "T型交叉", "村庄", "反向弯路", "无人看守铁路道口", "施工", "连续弯路", "有人看守铁路道口", "事故易发路段",
"停车让行", "禁止通行", "禁止停车", "禁止驶入", "减速让行", "停车检查"
};
// 输入
private int[] input = {1, 224, 224, 3};
// 输出
private float[][] output = new float[1][number];
private Interpreter interpreter;
private Bitmap bitmap;
private ImageView iv_vegetable;
private String[] neededPermissions = new String[]{
Manifest.permission.READ_PHONE_STATE
};
private TextView tv_text;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP) {
Window window = this.getWindow();
window.clearFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS);
window.getDecorView().setSystemUiVisibility(View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN
| View.SYSTEM_UI_FLAG_LAYOUT_STABLE);
window.addFlags(WindowManager.LayoutParams.FLAG_DRAWS_SYSTEM_BAR_BACKGROUNDS);
window.setStatusBarColor(Color.GRAY);
}
setContentView(R.layout.activity_main);
/*
* 在选择图片的时候,在android 7.0及以上通过FileProvider获取Uri,不需要文件权限
*/
if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) {
List permissionList = new ArrayList<>(Arrays.asList(neededPermissions));
permissionList.add(Manifest.permission.READ_EXTERNAL_STORAGE);
neededPermissions = permissionList.toArray(new String[0]);
}
initView();
TFLiteLoader loader = TFLiteLoader.newInstance(this);
interpreter = loader.get();
showToast("模型加载成功!");
bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.orange);
}
五、实现效果
基于TensorFlow的交通标志识别
六、完整源码下载
链接:https://pan.baidu.com/s/1vhtkevbQdbt3nuB6YOTzZA?pwd=99gp
提取码:99gp