一、开发环境
二、步骤
2.1 创建一个python 3.7.3的虚拟环境
conda create -n trash_gpu python==3.7.3
2.2 激活虚拟环境
conda activate trash_gpu
2.3 安装 tensorflow-gpu（需提前安装好 CUDA 10.1 和 cuDNN 7.6.5，这是 TensorFlow 2.3.0 GPU 版本所对应的版本）
pip install tensorflow-gpu==2.3.0
2.4 准备垃圾分类数据集
2.5 编写模型训练代码。为了让模型文件更加轻量，这里以 ImageNet 预训练的 MobileNetV2 作为基础网络（冻结其参数），只训练新增的分类层。
代码如下:
# Model construction
def model_load(IMG_SHAPE=(224, 224, 3), class_num=214):
    """Build and compile a MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained MobileNetV2 backbone (without its top) is frozen,
    and a global-average-pooling + softmax Dense head is trained on top of it.

    Args:
        IMG_SHAPE: Input image shape as (height, width, channels).
        class_num: Number of output classes for the softmax head.

    Returns:
        A compiled ``tf.keras`` Sequential model (Adam optimizer,
        categorical cross-entropy loss, accuracy metric). The loss assumes
        one-hot labels -- confirm against ``data_load``.
    """
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=IMG_SHAPE,
        include_top=False,
        weights='imagenet',
    )
    # Freeze the pretrained weights; only the new head is updated in training.
    backbone.trainable = False

    layers = tf.keras.layers
    model = tf.keras.models.Sequential()
    # x / 127.5 - 1: maps [0, 255] pixels to [-1, 1] (assumes raw 0-255 input).
    model.add(layers.experimental.preprocessing.Rescaling(
        1. / 127.5, offset=-1, input_shape=IMG_SHAPE))
    model.add(backbone)
    model.add(layers.GlobalAveragePooling2D())
    model.add(layers.Dense(class_num, activation='softmax'))

    # Print the layer summary before compiling.
    model.summary()
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Model training
def train(epochs, data_dir="E:/trash_image_set/data", batch_size=16,
          model_dir="models"):
    """Train the garbage classifier and export it as Keras .h5 and TFLite.

    Args:
        epochs: Number of training epochs passed to ``model.fit``.
        data_dir: Dataset root passed to ``data_load`` (default keeps the
            original hard-coded path).
        batch_size: Batch size passed to ``data_load`` as its 4th argument
            (presumably the batch size -- confirm against ``data_load``).
        model_dir: Output directory; it is created if missing. Receives
            ``trash_model.h5`` and ``model.tflite``.

    Returns:
        The Keras ``History`` object returned by ``model.fit``.
    """
    import os

    # 1. Load the dataset (224x224 images).
    train_dataset, validate_dataset, class_names = data_load(
        data_dir, 224, 224, batch_size)
    print('类别的个数-->')
    print(len(class_names))
    # 2. Build the model with one output unit per class.
    model = model_load(class_num=len(class_names))
    # 3. Train.
    history = model.fit(train_dataset, validation_data=validate_dataset,
                        epochs=epochs)
    # 4. Save the Keras model; create the output directory first so a fresh
    #    checkout does not fail here after a long training run.
    os.makedirs(model_dir, exist_ok=True)
    h5_path = os.path.join(model_dir, "trash_model.h5")
    model.save(h5_path)
    # 5. Convert to TFLite from the saved file, so the exported model matches
    #    what was written to disk.
    h5_model = tf.keras.models.load_model(h5_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(h5_model)
    tflite_model = converter.convert()
    # Use a context manager so the file is flushed and closed deterministically.
    with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
        f.write(tflite_model)
    return history
if __name__ == "__main__":
    # Script entry point: run a 30-epoch training and export.
    train(epochs=30)
2.6 经过较长时间的训练后，会在 models 文件夹中得到 Keras 模型文件 trash_model.h5 以及转换后的 model.tflite 模型文件。接下来将 model.tflite 导入 Android Studio 工程中。
三、编写Android APP
3.1 将 model.tflite 模型文件拷贝到 Android 工程的 assets 文件夹中，如图：
3.2 同时需要在 app 模块的 build.gradle 文件（android 代码块内）中添加如下配置，防止打包时压缩 .tflite 文件，以便运行时能直接从 assets 中读取模型：
// Keep *.tflite assets uncompressed in the APK. Presumably required because the
// model loader opens/maps the asset directly -- confirm against TFLiteLoader.
aaptOptions {
    noCompress "tflite"
}
3.3 编写activity_main.xml布局文件
3.4 编写 MainActivity.java 代码（以下仅列出主要方法，省略了 import 语句和字段声明，如 interpreter、mBitmap、class_names、waste_detail、input、output 等）：
@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    // API 21+: draw the layout behind a gray status bar.
    if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP) {
        Window window = this.getWindow();
        window.clearFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS);
        window.getDecorView().setSystemUiVisibility(View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN
                | View.SYSTEM_UI_FLAG_LAYOUT_STABLE);
        window.addFlags(WindowManager.LayoutParams.FLAG_DRAWS_SYSTEM_BAR_BACKGROUNDS);
        window.setStatusBarColor(Color.GRAY);
    }
    setContentView(R.layout.activity_main);
    /*
     * Below Android 7.0 (N), also add READ_EXTERNAL_STORAGE to the permission list
     * (presumably requested elsewhere). Original author's note: on 7.0+ the image
     * Uri is obtained via FileProvider, so no file permission is needed.
     */
    if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) {
        // Must be List<String>: on a raw List, toArray(T[]) erases to Object[],
        // which cannot be assigned back to the String[] field.
        List<String> permissionList = new ArrayList<>(Arrays.asList(neededPermissions));
        permissionList.add(Manifest.permission.READ_EXTERNAL_STORAGE);
        neededPermissions = permissionList.toArray(new String[0]);
    }
    initView();
    // Load the TFLite interpreter through the project's TFLiteLoader helper.
    TFLiteLoader loader = TFLiteLoader.newInstance(this);
    interpreter = loader.get();
    showToast("模型加载成功!");
    // Default image shown/classified until the user picks one.
    mBitmap = BitmapFactory.decodeResource(getResources(), R.drawable.cup);
}
/** Resolves the image and text views that onActivityResult updates after a prediction. */
private void initView() {
    iv_trash = findViewById(R.id.iv_trash);
    tv_waste_name = findViewById(R.id.tv_waste_name);
    tv_trash_detail = findViewById(R.id.tv_trash_detail);
}
/** Displays {@code text} in a long-duration toast. */
private void showToast(String text) {
    Toast toast = Toast.makeText(this, text, Toast.LENGTH_LONG);
    toast.show();
}
/**
 * Change image: opens the system picker for images in external media storage.
 * The {@code view} argument is unused (presumably wired via android:onClick).
 * The result is delivered to onActivityResult with request code 0.
 */
public void choose_image(View view) {
    Intent pickIntent = new Intent(Intent.ACTION_PICK);
    pickIntent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*");
    final int requestCode = 0;
    startActivityForResult(pickIntent, requestCode);
}
// Index into class_names of the highest-scoring class from the latest detect_image() run.
private int maxIndex = 0;
/**
 * Handles the image picked in choose_image(): decodes it, classifies it, and
 * updates the image, the predicted label, and the category description.
 */
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
    super.onActivityResult(requestCode, resultCode, data);
    // Only handle the picker request started by choose_image() (request code 0).
    if (requestCode != 0) {
        return;
    }
    if (data == null || data.getData() == null) {
        showToast("获取图片失败");
        return;
    }
    try {
        mBitmap = MediaStore.Images.Media.getBitmap(getContentResolver(), data.getData());
    } catch (IOException e) {
        // Stop here: continuing would re-classify and re-display the previous bitmap
        // as if it were the newly chosen image.
        e.printStackTrace();
        showToast("获取图片失败");
        return;
    }
    // Recognize the image; this sets maxIndex.
    detect_image();
    // Show the chosen image.
    iv_trash.setImageBitmap(mBitmap);
    // Show the predicted class name.
    String text = class_names[maxIndex];
    tv_waste_name.setText(text);
    // Show the description of the waste category the label belongs to.
    if (text.contains("厨余垃圾")) {
        tv_trash_detail.setText(waste_detail[0]);
    } else if (text.contains("有害垃圾")) {
        tv_trash_detail.setText(waste_detail[1]);
    } else if (text.contains("可回收物")) {
        tv_trash_detail.setText(waste_detail[2]);
    } else if (text.contains("其他垃圾")) {
        tv_trash_detail.setText(waste_detail[3]);
    } else {
        // Unknown category: clear rather than keep the previous image's description.
        tv_trash_detail.setText("");
    }
}
/**
 * Recognize image: runs the TFLite interpreter on mBitmap, stores the index of
 * the highest score in maxIndex, logs every score rounded to 3 decimals, and
 * toasts the predicted class name.
 */
public void detect_image() {
    // Convert the bitmap to the model input array (layout defined by getScaledMatrix).
    float[][][][] pixels = getScaledMatrix(mBitmap, input);
    interpreter.run(pixels, output);
    // maxIndex is a field: reset it, or a prediction of class 0 would leave the
    // previous image's index in place.
    maxIndex = 0;
    for (int j = 0; j < output[0].length; j++) {
        float rounded = new BigDecimal(output[0][j])
                .setScale(3, java.math.RoundingMode.HALF_UP)
                .floatValue();
        Log.i("Test", rounded + "--> " + j);
        // Strictly greater: on ties the first (lowest) index wins.
        if (output[0][j] > output[0][maxIndex]) {
            maxIndex = j;
        }
    }
    String text = class_names[maxIndex];
    // Show a toast with the predicted class.
    showToast(text);
}
基于TensorFlow2.3.0的垃圾分类
四、资料下载
APK下载:https://wwi.lanzoup.com/itLZV0a53qni 密码:1a9y
源码下载（包含数据集、模型文件、APP 源码）：https://item.taobao.com/item.htm?ft=t&id=681383960366