扑克牌识别

本文主要分为两部分:第一部分使用 MobileNetV2-SSD-Lite 训练一个识别扑克牌的模型,第二部分将训练好的模型部署到手机上运行。具体运行效果如下所示:

pc端运行效果:

20220316

b站地址:视频链接

移动端:

扑克牌识别

app下载地址:下载地址

在这里分享一下数据集链接:
https://cloud.189.cn/t/eYRzu27ne6za (访问码:upa5)
该数据集制作不易,如有需要可前往下载。

1 网络训练

可以使用不同的网络,本文使用的是 MobileNetV2-SSD-Lite,具体可参考该 GitHub 项目。
训练好模型后,将其转换为 ONNX 格式。
onnx模型下载地址:
链接:https://pan.baidu.com/s/1kchAu_0B3KfJ3WkYY9IDXQ
提取码:g9fq

使用 OpenCV 加载 ONNX 模型进行前向推理,代码如下所示:

import cv2
import numpy as np

# Class names in the order of the model's score columns (after the background
# column is dropped): ranks 1..13 for each of the four suit prefixes, followed
# by the two jokers.  Only the names are needed, so build the list directly.
_SUIT_PREFIXES = ('card', 'second', 'three', 'four')
label_dic = [f'{prefix}{rank}' for prefix in _SUIT_PREFIXES for rank in range(1, 14)]
label_dic += ['xiao', 'da']

# Path of the exported MobileNetV2-SSD-Lite ONNX model, relative to the working directory.
MODEL_PATH = './models/mb2-ssd-lite.onnx'
net = cv2.dnn.readNetFromONNX(MODEL_PATH)

def dector(ori_image, score_threshold=0.4, nms_threshold=0.45):
    """Run the SSD detector on one frame and draw the detections onto it in place.

    Args:
        ori_image: frame as read by OpenCV (BGR; `swapRB=True` converts it to RGB
            for the network). It is modified in place.
        score_threshold: minimum class score for a prior box to be kept.
        nms_threshold: IoU threshold used by non-maximum suppression.

    Returns:
        The same `ori_image` array, with boxes and labels drawn on it.
    """
    h, w = ori_image.shape[:2]
    # Preprocess: resize to 300x300, (pixel - 127) / 128.
    blob = cv2.dnn.blobFromImage(ori_image,
                                 scalefactor=1.0 / 128,
                                 size=(300, 300),
                                 mean=[127, 127, 127],
                                 swapRB=True,
                                 crop=False)
    net.setInput(blob)
    scores_out, boxes_out = net.forward(['scores', 'boxes'])

    # Drop column 0 (background) so that column k maps to label_dic[k].
    scores = scores_out[0][..., 1:]
    # Boxes are normalized (x1, y1, x2, y2) corners; scale them to pixels.
    boxes = boxes_out[0] * np.array([w, h, w, h], dtype=np.float32)

    max_score = np.max(scores, axis=-1)
    max_index = np.argmax(scores, axis=-1)

    # NMSBoxes expects (x, y, width, height), not corner coordinates; passing
    # corners would make it compute IoU on the wrong box sizes.
    xywh = np.column_stack((boxes[:, 0],
                            boxes[:, 1],
                            boxes[:, 2] - boxes[:, 0],
                            boxes[:, 3] - boxes[:, 1]))
    keep = cv2.dnn.NMSBoxes(xywh.tolist(), max_score.tolist(),
                            score_threshold, nms_threshold)
    # Older OpenCV versions return shape (N, 1), newer ones (N,) or an empty tuple.
    keep = np.asarray(keep, dtype=np.int64).reshape(-1)
    if keep.size == 0:
        return ori_image

    for box, confidence, la in zip(boxes[keep].astype(np.int32),
                                   max_score[keep],
                                   max_index[keep]):
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(ori_image, (x1, y1), (x2, y2), (255, 255, 0), 2)
        cv2.putText(ori_image, f'{label_dic[la]}:{round(float(confidence), 3)}',
                    (x1 + 20, y1 + 40),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,  # font scale
                    (255, 0, 255),
                    2)
    return ori_image


# Run the detector over a video file and show the annotated frames.
# Press 'q' or Esc to stop early; the loop also ends when the video runs out.
cap = cv2.VideoCapture(r'E:\card.mp4')
try:
    while True:
        r, ori_image = cap.read()
        if not r:
            # End of the video or a read failure: ori_image is None here.
            break
        dector(ori_image)
        cv2.imshow('a4232', ori_image)
        if (cv2.waitKey(1) & 0xFF) in (27, ord('q')):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()


2 移动端部署

使用 OpenCV 加载 ONNX 模型,进行移动端部署。

MyCamera.java 代码(主界面 Activity):

package com.myapp.puke;

import androidx.appcompat.app.AppCompatActivity;

import android.os.Bundle;
import android.Manifest;
import android.content.Context;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.media.MediaPlayer;
import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.view.Window;
import android.view.WindowManager;
import android.widget.Toast;


import org.opencv.android.BaseLoaderCallback;
import org.opencv.android.CameraActivity;
import org.opencv.android.CameraBridgeViewBase;
import org.opencv.android.JavaCameraView;
import org.opencv.android.LoaderCallbackInterface;
import org.opencv.android.OpenCVLoader;
import org.opencv.android.Utils;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.MatOfFloat;
import org.opencv.core.MatOfInt;
import org.opencv.core.MatOfRect2d;
import org.opencv.core.Point;
import org.opencv.core.Rect2d;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.dnn.Dnn;
import org.opencv.dnn.Net;
import org.opencv.imgproc.Imgproc;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class MyCamera extends CameraActivity implements CameraBridgeViewBase.CvCameraViewListener2 {

    static {
        System.loadLibrary("native-lib");
    }

    private static final String TAG = "MyCamera";

    /** Minimum best-class score for a prior box to become a candidate. */
    private static final double SCORE_THRESHOLD = 0.4;
    /** IoU threshold for non-maximum suppression. */
    private static final float NMS_THRESHOLD = 0.45f;

    private JavaCameraView mOpenCvCameraView;
    private int M_REQUEST_CODE = 203;
    private String[] permissions = {Manifest.permission.CAMERA};

    // Enables the camera view once OpenCV reports successful initialization;
    // any other status is ignored.
    private BaseLoaderCallback mLoaderCallback = new BaseLoaderCallback(this) {
        @Override
        public void onManagerConnected(int status) {
            if (status == LoaderCallbackInterface.SUCCESS) {
                mOpenCvCameraView.enableView();
            }
        }
    };

    private MediaPlayer mediaPlayer;
    // Current camera frame (RGB after conversion); read by detection().
    private Mat src;
    private Net net;
    // Preprocessing: (pixel - MEAN_VAL) * IN_SCALE_FACTOR at IN_WIDTH x IN_HEIGHT.
    final double IN_SCALE_FACTOR = 0.0078125;
    final double MEAN_VAL = 127.0;
    // Not used by detection(); SCORE_THRESHOLD is the active cutoff.
    final double THRESHOLD = 0.2;
    final int IN_WIDTH = 300;
    final int IN_HEIGHT = 300;
    // Reusable bitmap the size of the preview frame.
    private Bitmap bp;

    // Timestamps (ms) used to reset the "already announced" list every 5 seconds.
    private long time1 = 0;
    private long time2 = 0;

    @Override
    public void onResume() {
        super.onResume();
        if (OpenCVLoader.initDebug()) {
            mLoaderCallback.onManagerConnected(LoaderCallbackInterface.SUCCESS);
        } else {
            Log.e(TAG, "OpenCV initialization failed; camera view not enabled");
        }
    }

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_my_camera);

        // Transparent status and navigation bars.
        if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.LOLLIPOP) {
            Window window = getWindow();
            window.clearFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS
                    | WindowManager.LayoutParams.FLAG_TRANSLUCENT_NAVIGATION);
            window.getDecorView().setSystemUiVisibility(View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN
                    | View.SYSTEM_UI_FLAG_LAYOUT_HIDE_NAVIGATION
                    | View.SYSTEM_UI_FLAG_LAYOUT_STABLE);
            window.addFlags(WindowManager.LayoutParams.FLAG_DRAWS_SYSTEM_BAR_BACKGROUNDS);
            window.setStatusBarColor(Color.TRANSPARENT);
            window.setNavigationBarColor(Color.TRANSPARENT);
        }
        // Set up camera listener.
        mOpenCvCameraView = (JavaCameraView) findViewById(R.id.CameraView);
        mOpenCvCameraView.setVisibility(CameraBridgeViewBase.VISIBLE);
        mOpenCvCameraView.setCvCameraViewListener(this);

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            requestPermissions(permissions, M_REQUEST_CODE);
        }
    }

    @Override
    public void onCameraViewStarted(int width, int height) {
        // Load the model only once; the camera restarts on every resume.
        if (net == null) {
            String proto = getPath("mb2-ssd-lite.onnx", this);
            net = Dnn.readNetFromONNX(proto);
        }
        bp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
    }

    // Classes whose sound has already been played in the current 5-second window.
    private List<Integer> music_index = new ArrayList<>();

    @Override
    public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame inputFrame) {
        // Every 5 seconds, forget which cards were announced so they can be announced again.
        time2 = System.currentTimeMillis();
        if (time2 - time1 > 5000) {
            time1 = time2;
            music_index = new ArrayList<>();
        }

        // Object detection on the RGB frame; detection() reads `src` directly.
        src = inputFrame.rgba();
        Imgproc.cvtColor(src, src, Imgproc.COLOR_RGBA2RGB);

        Utils.matToBitmap(src, bp);
        bp = detection(bp);
        Utils.bitmapToMat(bp, src);

        // Classes detected at least twice since the counters were last reset.
        List<Integer> index_class = new ArrayList<>();
        for (int x = 0; x < number_class.length; x++) {
            if (number_class[x] >= 2) {
                index_class.add(x);
            }
        }

        if (!index_class.isEmpty()) {
            for (int j : index_class) {
                if (!music_index.contains(j)) {
                    music_index.add(j);
                    playSound(musices[j]);
                    Log.i("aa", "" + music_index);
                }
            }
            number_class = new int[classNames.length];
        }

        return src;
    }

    // Plays one announcement clip and releases the player once playback finishes.
    private void playSound(int resId) {
        MediaPlayer player = MediaPlayer.create(getApplicationContext(), resId);
        if (player == null) {
            Log.w(TAG, "Could not create MediaPlayer for resource " + resId);
            return;
        }
        player.setOnCompletionListener(new MediaPlayer.OnCompletionListener() {
            @Override
            public void onCompletion(MediaPlayer mp) {
                mp.release();
            }
        });
        mediaPlayer = player;
        player.start();
    }

    @Override
    public void onCameraViewStopped() {
    }

    @Override
    public void onPause() {
        super.onPause();
        if (mOpenCvCameraView != null) {
            mOpenCvCameraView.disableView();
        }
    }

    @Override
    public void onDestroy() {
        super.onDestroy();
        if (mOpenCvCameraView != null) {
            mOpenCvCameraView.disableView();
        }
    }

    @Override
    protected List<? extends CameraBridgeViewBase> getCameraViewList() {
        return Collections.singletonList(mOpenCvCameraView);
    }

    /**
     * Copies an asset into the app's files directory so native code can open it by path.
     *
     * @param file asset name, also used as the output file name
     * @param context context providing the assets and the files directory
     * @return absolute path of the copy, or "" if the copy failed (the failure is logged)
     */
    private static String getPath(String file, Context context) {
        AssetManager assetManager = context.getAssets();
        File outFile = new File(context.getFilesDir(), file);
        try (BufferedInputStream inputStream = new BufferedInputStream(assetManager.open(file));
             FileOutputStream os = new FileOutputStream(outFile)) {
            // available() is not guaranteed to report the full length, so copy until EOF.
            byte[] buffer = new byte[8192];
            int n;
            while ((n = inputStream.read(buffer)) != -1) {
                os.write(buffer, 0, n);
            }
        } catch (IOException ex) {
            Log.e(TAG, "Failed to copy asset " + file, ex);
            return "";
        }
        return outFile.getAbsolutePath();
    }

    private static final String[] classNames = {"黑桃A",
            "黑桃2", "黑桃3", "黑桃4", "黑桃5","黑桃6", "黑桃7", "黑桃8", "黑桃9","黑桃10", "黑桃J", "黑桃Q", "黑桃K",
            "红桃A", "红桃2","红桃3","红桃4","红桃5","红桃6","红桃7","红桃8","红桃9","红桃10","红桃J","红桃Q","红桃K",
            "方块A","方块2","方块3","方块4","方块5","方块6","方块7","方块8","方块9","方块10","方块J","方块Q","方块K",
            "梅花A","梅花2","梅花3","梅花4","梅花5","梅花6","梅花7","梅花8","梅花9","梅花10","梅花J","梅花Q","梅花K",
            "小王","大王"};

    // One sound resource per class, in the same order as classNames.
    private static final int[] musices = {R.raw.m1,
            R.raw.m2, R.raw.m3, R.raw.m4, R.raw.m5,R.raw.m6, R.raw.m7, R.raw.m8, R.raw.m9,R.raw.m10, R.raw.m11, R.raw.m12,R.raw.m13,
            R.raw.m14, R.raw.m15, R.raw.m16, R.raw.m17,R.raw.m18, R.raw.m19, R.raw.m20, R.raw.m21,R.raw.m22, R.raw.m23, R.raw.m24,R.raw.m25,R.raw.m26,
            R.raw.m27, R.raw.m28, R.raw.m29, R.raw.m30,R.raw.m31, R.raw.m32, R.raw.m33, R.raw.m34,R.raw.m35, R.raw.m36, R.raw.m37,R.raw.m38,R.raw.m39,
            R.raw.m40, R.raw.m41, R.raw.m42, R.raw.m43,R.raw.m44, R.raw.m45, R.raw.m46, R.raw.m47,R.raw.m48, R.raw.m49, R.raw.m50,R.raw.m51,R.raw.m52,
            R.raw.m53,R.raw.m54};

    public native String stringFromJNI();

    // Per-class detection counts accumulated across frames.
    private int[] number_class = new int[54];

    /**
     * Runs the detector on the current frame (`src`) and draws the results onto a bitmap.
     * Also increments {@code number_class} for every kept detection.
     *
     * @param bp bitmap holding the current frame; drawn on directly if mutable, otherwise copied
     * @return the bitmap with boxes and labels drawn on it
     */
    public Bitmap detection(Bitmap bp) {
        Bitmap canvasBitmap = bp.isMutable() ? bp : bp.copy(bp.getConfig(), true);
        Canvas can = new Canvas(canvasBitmap);
        Paint p = new Paint();
        p.setAntiAlias(true);
        p.setStrokeWidth(5);
        p.setTextAlign(Paint.Align.LEFT);
        p.setTextSize(50);

        Mat blob = Dnn.blobFromImage(src, IN_SCALE_FACTOR,
                new Size(IN_WIDTH, IN_HEIGHT),
                new Scalar(MEAN_VAL, MEAN_VAL, MEAN_VAL), false);
        net.setInput(blob);
        blob.release();

        // Request outputs by name so their order does not depend on the graph layout.
        List<String> outNames = new ArrayList<>();
        outNames.add("scores");
        outNames.add("boxes");
        List<Mat> detections = new ArrayList<>();
        net.forward(detections, outNames);
        Mat scoreOut = detections.get(0);
        Mat boxOut = detections.get(1);

        // One row per prior: boxes have 4 normalized corner values (x1, y1, x2, y2);
        // scores have a background column 0 followed by one column per class.
        int numPriors = (int) (boxOut.total() / 4);
        int numCols = numPriors == 0 ? 0 : (int) (scoreOut.total() / numPriors);
        int numClasses = Math.min(numCols - 1, classNames.length);
        float[] scoreData = new float[numPriors * numCols];
        float[] boxData = new float[numPriors * 4];
        if (numPriors > 0) {
            // Copy everything out in one call instead of creating a submat per row.
            Mat scoreRows = scoreOut.reshape(1, numPriors);
            Mat boxRows = boxOut.reshape(1, numPriors);
            scoreRows.get(0, 0, scoreData);
            boxRows.get(0, 0, boxData);
            scoreRows.release();
            boxRows.release();
        }
        for (Mat m : detections) {
            m.release();
        }

        List<Rect2d> rect2dList = new ArrayList<>();    // candidate boxes (x, y, w, h), normalized
        List<Float> confList = new ArrayList<>();       // best class score per candidate
        List<Integer> objIndexList = new ArrayList<>(); // best class index per candidate

        for (int i = 0; i < numPriors; i++) {
            int rowBase = i * numCols;
            int bestClass = -1;
            float bestScore = -Float.MAX_VALUE;
            for (int c = 1; c <= numClasses; c++) {
                float s = scoreData[rowBase + c];
                if (s > bestScore) {
                    bestScore = s;
                    bestClass = c - 1;
                }
            }
            if (bestClass >= 0 && bestScore > SCORE_THRESHOLD) {
                confList.add(bestScore);
                objIndexList.add(bestClass);
                int b = i * 4;
                double x1 = boxData[b];
                double y1 = boxData[b + 1];
                double x2 = boxData[b + 2];
                double y2 = boxData[b + 3];
                // Rect2d is (x, y, width, height); NMS must see the real box size.
                rect2dList.add(new Rect2d(x1, y1, x2 - x1, y2 - y1));
            }
        }

        if (rect2dList.isEmpty()) {
            return canvasBitmap;
        }

        // Non-maximum suppression over the candidates.
        MatOfRect2d boxe = new MatOfRect2d(rect2dList.toArray(new Rect2d[0]));
        float[] confArr = new float[confList.size()];
        for (int j = 0; j < confList.size(); j++) {
            confArr[j] = confList.get(j);
        }
        MatOfFloat con = new MatOfFloat(confArr);
        MatOfInt index = new MatOfInt();
        Dnn.NMSBoxes(boxe, con, (float) SCORE_THRESHOLD, NMS_THRESHOLD, index);
        int[] ints = index.toArray();
        boxe.release();
        con.release();
        index.release();

        int frameW = src.width();
        int frameH = src.height();
        for (int x : ints) {
            Rect2d r = rect2dList.get(x);
            float left = (float) (r.x * frameW);
            float top = (float) (r.y * frameH);
            float right = (float) ((r.x + r.width) * frameW);
            float bottom = (float) ((r.y + r.height) * frameH);
            int cls = objIndexList.get(x);

            // Bounding box outline.
            p.setStyle(Paint.Style.STROKE);
            p.setColor(0xFF33FFFF);
            can.drawRect(left, top, right, bottom, p);
            // Filled label background above the box.
            p.setStyle(Paint.Style.FILL);
            p.setColor(0xFFFFCC00);
            can.drawRect(left, top - 60, right + 150, top, p);
            // Label text.
            p.setColor(0xFFFF0000);
            can.drawText(classNames[cls] + ": " + String.format("%.3f", confList.get(x)), left, top - 10, p);

            // Count this class for the announcement logic in onCameraFrame.
            number_class[cls] += 1;
        }

        return canvasBitmap;
    }
}

布局代码:


<!-- Full-screen OpenCV camera preview used by MyCamera (R.id.CameraView). -->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:opencv="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:visibility="visible">

    <org.opencv.android.JavaCameraView
        android:id="@+id/CameraView"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        opencv:camera_id="any"
        opencv:show_fps="true" />

</RelativeLayout>

你可能感兴趣的:(深度学习,计算机视觉,opencv,pytorch,python)