使用 java-onnx 部署 Meta-ai Segment anything 分割一切

使用 java-onnx 部署 Meta-ai Segment anything 分割一切_第1张图片

 

近日,Meta AI 在官网发布了基础模型 Segment Anything Model(SAM)并开源,其本质是用类似 GPT 的方式(基于 Transformer 模型架构)让计算机具备理解图像中一个个“对象”的通用能力。SAM 模型是一个可以接受提示(包括文本提示)、基于海量数据训练而获得泛化能力的图像分割大模型。图像分割是计算机视觉中的一项重要任务,有助于识别和确认图像中的不同物体,把它们从背景中分离出来,这在自动驾驶(检测其他汽车、行人和障碍物)、医学成像(提取特定结构或潜在病灶)等应用中特别重要。

下面是 Java 使用 ONNX Runtime 进行推理的分割代码,通过提示点(抠图点)进行分割,目前还没有文本交互式提示的部署案例。代码如下:

package tool.deeplearning;

import ai.onnxruntime.*;
import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.*;


/**
 * Meta AI SAM (Segment Anything) demo: segments an image from prompt points
 * using two ONNX Runtime sessions.
 *
 * @desc : meta-ai sam, segmentation driven by prompt (cut-out) points
 * @auth : tyf
 * @date : 2023-04-25  09:34:40
 */
public class metaai_sam_test {


    // Model 1: environment and session created by init1(String).
    public static OrtEnvironment env1;
    public static OrtSession session1;

    // Model 2: environment and session created by init2(String).
    public static OrtEnvironment env2;
    public static OrtSession session2;


    /**
     * Loads the OpenCV native library, creates the ONNX Runtime session for model 1
     * and prints the name, shape and element type of each of its inputs and outputs.
     *
     * @param weight path of the ONNX model file to load into {@link #session1}
     * @throws Exception if the session cannot be created or its metadata cannot be read
     */
    public static void init1(String weight) throws Exception{
        // OpenCV native library; loading it again from init2 is harmless.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        env1 = OrtEnvironment.getEnvironment();
        session1 = env1.createSession(weight, new OrtSession.SessionOptions());

        printModelInfo("模型1", session1);
    }

    /**
     * Loads the OpenCV native library, creates the ONNX Runtime session for model 2
     * and prints the name, shape and element type of each of its inputs and outputs.
     *
     * @param weight path of the ONNX model file to load into {@link #session2}
     * @throws Exception if the session cannot be created or its metadata cannot be read
     */
    public static void init2(String weight) throws Exception{
        // OpenCV native library; loading it again after init1 is harmless.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        env2 = OrtEnvironment.getEnvironment();
        session2 = env2.createSession(weight, new OrtSession.SessionOptions());

        printModelInfo("模型2", session2);
    }

    /**
     * Prints the input and output descriptions of a session under labelled banners.
     *
     * @param label   text inserted into the banner lines to identify the model
     * @param session session whose inputs and outputs are listed
     * @throws OrtException if the session metadata cannot be queried
     */
    private static void printModelInfo(String label, OrtSession session) throws OrtException{
        System.out.println("---------"+label+"输入-----------");
        printNodes(session.getInputInfo());
        System.out.println("---------"+label+"输出-----------");
        printNodes(session.getOutputInfo());
    }

    /**
     * Prints one line per node as {@code name -> shape -> type}.
     *
     * @param nodes node descriptions keyed by node name
     */
    private static void printNodes(Map<String, NodeInfo> nodes){
        nodes.forEach((name, node) -> {
            ValueInfo valueInfo = node.getInfo();
            if(valueInfo instanceof TensorInfo){
                TensorInfo tensor = (TensorInfo) valueInfo;
                System.out.println(name+" -> "+Arrays.toString(tensor.getShape())+" -> "+tensor.type.toString());
            }else{
                // Non-tensor values (maps, sequences) have no shape; print their description
                // instead of failing with a ClassCastException.
                System.out.println(name+" -> "+valueInfo);
            }
        });
    }


    public static class ImageObj{

        // Original image as returned by readImg(String).
        Mat src;
        // src resized to 1024x1024 by resizeWithoutPadding.
        Mat dst_3_1024_1024;
        // dst_3_1024_1024 converted to a BufferedImage; show() paints the mask onto it.
        BufferedImage image_3_1024_1024;
        // NOTE(review): not assigned anywhere in the visible code; presumably filled by the
        // first inference step, whose source text is missing from this file — confirm.
        float[][][] image_embeddings;
        // Prompt points supplied through setPoints; main adds float[]{x, y, 1} entries.
        ArrayList points;
        // First mask taken from the second model's output; read by show().
        float[][] info;
        /**
         * Reads the image at the given path and prepares its 1024x1024 working copies.
         *
         * @param image path of the image file to load
         * @throws IllegalArgumentException if the file cannot be read as an image
         */
        public ImageObj(String image) {
            this.src = this.readImg(image);
            // imread signals failure with an empty Mat rather than an exception; fail here
            // with the offending path instead of letting the resize below raise an
            // unrelated OpenCV error.
            if(this.src == null || this.src.empty()){
                throw new IllegalArgumentException("cannot read image: " + image);
            }
            this.dst_3_1024_1024 = this.resizeWithoutPadding(src,1024,1024);
            this.image_3_1024_1024 = mat2BufferedImage(this.dst_3_1024_1024);
        }
        /**
         * Stores the prompt points used for segmentation.
         *
         * @param points list of prompt points; kept by reference, not copied
         */
        public void setPoints(ArrayList points) {
            this.points = points;
        }

        /**
         * Loads an image file through OpenCV.
         *
         * @param path location of the image file
         * @return the matrix produced by {@code Imgcodecs.imread}
         */
        public Mat readImg(String path){
            return Imgcodecs.imread(path);
        }
        /**
         * Scales an image straight to the requested size (aspect ratio is not preserved
         * and no padding is added), using area interpolation.
         *
         * @param src         image to scale
         * @param inputWidth  target width in pixels
         * @param inputHeight target height in pixels
         * @return a new matrix holding the scaled image
         */
        public Mat resizeWithoutPadding(Mat src,int inputWidth,int inputHeight){
            Mat scaled = new Mat();
            Imgproc.resize(src, scaled, new Size(inputWidth, inputHeight), 0, 0, Imgproc.INTER_AREA);
            return scaled;
        }
        // NOTE(review): this span is damaged. On the `for` line below, the text between
        // `i` and `in = new HashMap<>()` is missing (it looks like everything between a
        // `<` and a later `>` was stripped during extraction). The rest of chw2chw and the
        // beginning of the method that builds the `in` map are therefore not visible here;
        // main() calls infenence1() and infenence2(), neither of which has a visible
        // declaration. Restore the missing code from the original source before compiling.
        public float[] chw2chw(float[][][] chw,int c,int h,int w){
            // Flat output buffer with one slot per element of a c x h x w volume.
            float[] res = new float[ c * h * w ];

            int index = 0;
            for(int i=0;i in = new HashMap<>();
            // Inputs for session2, keyed by name. The values (_image_embeddings, _point_coords, ...)
            // are created in the missing portion of the source.
            in.put("image_embeddings",_image_embeddings);
            in.put("point_coords", _point_coords);
            in.put("point_labels",_point_labels);
            in.put("has_mask_input",_has_mask_input);
            in.put("orig_im_size",_orig_im_size);
            in.put("mask_input",_mask_input);


            // Run inference on the second model.
            OrtSession.Result res = session2.run(in);

            // Output 0 is read as a 4-D float array and only its first element is kept;
            // outputs 1 and 2 are read but not used further in the visible code.
            float[][][] masks  = ((float[][][][])(res.get(0)).getValue())[0];
            float[][] iou_predictions  = ((float[][])(res.get(1)).getValue());
            float[][][][] low_res_masks  = ((float[][][][])(res.get(2)).getValue());


            int count = masks.length;

            // Only the first mask is kept: the loop stores masks[0] in this.info and exits.
            for(int i=0;i < count;i++){
                float[][] info = masks[i];
                this.info = info;
                break;

            }

        }

        /**
         * Converts an OpenCV matrix to a {@link BufferedImage} by encoding it as JPEG in
         * memory and decoding the bytes with ImageIO.
         *
         * @param mat image matrix to convert
         * @return the decoded image, never {@code null}
         * @throws IllegalStateException if OpenCV cannot encode the matrix or ImageIO
         *                               cannot decode the encoded bytes
         */
        public BufferedImage mat2BufferedImage(Mat mat){
            MatOfByte encoded = new MatOfByte();
            // imencode reports failure through its return value; ignoring it would hand
            // an empty buffer to ImageIO.
            if(!Imgcodecs.imencode(".jpg", mat, encoded)){
                throw new IllegalStateException("OpenCV could not encode the image as JPEG");
            }
            try {
                BufferedImage decoded = ImageIO.read(new ByteArrayInputStream(encoded.toArray()));
                // ImageIO.read returns null when no registered reader accepts the data.
                if(decoded == null){
                    throw new IllegalStateException("ImageIO could not decode the JPEG produced by OpenCV");
                }
                return decoded;
            } catch (java.io.IOException e) {
                // Surface the failure with its cause instead of printing it and returning
                // null, which only postponed the error to a later NullPointerException.
                throw new IllegalStateException("failed to decode the JPEG produced by OpenCV", e);
            }
        }

        /**
         * Returns a copy of the image scaled to the given size.
         *
         * <p>The scaling is done by {@code Graphics2D.drawImage} directly on the source
         * {@link BufferedImage}. {@code Image.getScaledInstance} is avoided because the
         * image it returns may be loaded asynchronously, so drawing it with a
         * {@code null} observer can paint an incomplete picture.
         *
         * @param img       image to scale
         * @param newWidth  target width in pixels, must be positive
         * @param newHeight target height in pixels, must be positive
         * @return a new {@code TYPE_INT_ARGB} image of the requested size
         */
        public BufferedImage resize(BufferedImage img, int newWidth, int newHeight) {
            BufferedImage scaledBufferedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_ARGB);
            Graphics2D g2d = scaledBufferedImage.createGraphics();
            try {
                // Smooth interpolation, matching the intent of the former SCALE_SMOOTH hint.
                g2d.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
                g2d.setRenderingHint(RenderingHints.KEY_RENDERING, RenderingHints.VALUE_RENDER_QUALITY);
                g2d.drawImage(img, 0, 0, newWidth, newHeight, null);
            } finally {
                g2d.dispose();
            }
            return scaledBufferedImage;
        }

        public void show(){


            int sub_w = info.length;
            int sub_h = info[0].length;

            for(int j=0;j0.5){
                        // 修改颜色为绿色
                        image_3_1024_1024.setRGB(k,j, Color.GREEN.getRGB());
                    }
                }
            }

            BufferedImage showImg = resize(image_3_1024_1024,src.width(),src.height());

            // 弹窗显示
            JFrame frame = new JFrame();
            frame.setTitle("Meta-ai: SAM");
            JPanel content = new JPanel();
            content.add(new JLabel(new ImageIcon(showImg)));
            frame.add(content);
            frame.pack();
            frame.setVisible(true);


        }

    }


    /**
     * Demo entry point: loads both models, segments {@code truck.jpg} from three prompt
     * points and displays the result.
     *
     * @param args unused
     * @throws Exception if a model cannot be loaded or inference fails
     */
    public static void main(String[] args) throws Exception{

        // Build paths with File so the platform separator is used; the previous
        // hard-coded "\\" separators only resolved on Windows.
        File resourceDir = new File(new File("").getCanonicalPath(),
                String.join(File.separator, "src", "main", "resources", "deeplearning", "metaai_sam"));

        init1(new File(resourceDir, "encoder-vit_b.quant.onnx").getPath());

        init2(new File(resourceDir, "decoder-vit_b.quant.onnx").getPath());


        // Image to segment.
        ImageObj imageObj = new ImageObj(new File(resourceDir, "truck.jpg").getPath());


        // Prompt points; several may be given. Each entry is {x, y, 1}.
        ArrayList<float[]> points = new ArrayList<>();
        points.add(new float[]{514,357,1});// truck window
        points.add(new float[]{555,377,1});// truck window
        points.add(new float[]{556,387,1});// truck window
        imageObj.setPoints(points);

        // Inference: run both stages in order.
        imageObj.infenence1();
        imageObj.infenence2();

        // Display the result.
        imageObj.show();

    }


}

你可能感兴趣的:(人工智能,opencv,meta-ai,sam,分割一切)