AI应用基于DJL开发WEB应用对鞋分类进行预测和推理------AI

package com.alatus.djl.app;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.alatus.djl.service.InterferenceService;
import com.alatus.djl.service.TrainService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadPoolExecutor;

@RestController
public class DeepLearning {
//    线性函数,就是一条线过去,现在我有两个点的坐标,我要知道第三个点的坐标
//
//            我只需要套用线性的数学公式,如Y=ax+b,只要能顺利得到前两个点的坐标
//
//    就能使用这个公式得到第三个点的坐标(我只需要套用一下第三个点,如果套不上,就说明他们不在一条线上)
//
//    机器学习是计算这些点的关系,如果我现在给出一堆点,我要求你计算这些点之间的关系
//
//            现在是很多很多点,如何能让他们成为一个线或者说一个集合
//
//    需要你从这些点的矩阵数据集合里找相同的规律确切的说是
//
//            因为这里的点很多很多,机器学习就需要不断的调整和计算,最终就会得到一个范围
//
//    再有新的数据进来,我就看你是不是在我的范围内即可
//            机器学习大概就是四类算法
//
//    Classification分类算法
//
//            Regression回归算法回归问题
//
//    clustering聚类问题
//
//    dimensionality reduction降维问题
//
//    机器学习算法选择
//
//    样本数是否大于50?是继续,否去收集数据去
//
//是否推理的是分类问题?是就看我们的数据是否标注完成
//
//            不是分类问题的就属于回归问题和降维问题
//
//    是分类问题,我们的数据是否有人工标注?
//
//    是就是监督学习的类型,走分类算法部分
//
//            没有就属于聚类问题,走聚类算法
    @Autowired
    private TrainService trainService;
    @Autowired
    private InterferenceService interferenceService;
    @Autowired
    private ThreadPoolTaskExecutor threadPoolExecutor;
    //训练模型的接口
    @GetMapping("/train")
    public String train() {
        try {
            trainService.train("build/ut-zap50k-images-square","build/models","footWeaver");
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (TranslateException e) {
            throw new RuntimeException(e);
        }
        return "ok";
    }
    //进行推理
    @PostMapping("/result")
    public String predict(@RequestPart("file") MultipartFile image) {
        try {
            return trainService.predict(image);
        } catch (MalformedModelException e) {
            throw new RuntimeException(e);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (TranslateException e) {
            throw new RuntimeException(e);
        }
//        返回结果的分类
    }
}
package com.alatus.djl.app;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.alatus.djl.service.InterferenceService;
import com.alatus.djl.service.TrainService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadPoolExecutor;

@RestController
public class DeepLearning {
//    线性函数,就是一条线过去,现在我有两个点的坐标,我要知道第三个点的坐标
//
//            我只需要套用线性的数学公式,如Y=ax+b,只要能顺利得到前两个点的坐标
//
//    就能使用这个公式得到第三个点的坐标(我只需要套用一下第三个点,如果套不上,就说明他们不在一条线上)
//
//    机器学习是计算这些点的关系,如果我现在给出一堆点,我要求你计算这些点之间的关系
//
//            现在是很多很多点,如何能让他们成为一个线或者说一个集合
//
//    需要你从这些点的矩阵数据集合里找相同的规律确切的说是
//
//            因为这里的点很多很多,机器学习就需要不断的调整和计算,最终就会得到一个范围
//
//    再有新的数据进来,我就看你是不是在我的范围内即可
//            机器学习大概就是四类算法
//
//    Classification分类算法
//
//            Regression回归算法回归问题
//
//    clustering聚类问题
//
//    dimensionality reduction降维问题
//
//    机器学习算法选择
//
//    样本数是否大于50?是继续,否去收集数据去
//
//是否推理的是分类问题?是就看我们的数据是否标注完成
//
//            不是分类问题的就属于回归问题和降维问题
//
//    是分类问题,我们的数据是否有人工标注?
//
//    是就是监督学习的类型,走分类算法部分
//
//            没有就属于聚类问题,走聚类算法
    @Autowired
    private TrainService trainService;
    @Autowired
    private InterferenceService interferenceService;
    @Autowired
    private ThreadPoolTaskExecutor threadPoolExecutor;
    //训练模型的接口
    @GetMapping("/train")
    public String train() {
        try {
            trainService.train("build/ut-zap50k-images-square","build/models","footWeaver");
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (TranslateException e) {
            throw new RuntimeException(e);
        }
        return "ok";
    }
    //进行推理
    @PostMapping("/result")
    public String predict(@RequestPart("file") MultipartFile image) {
        try {
            return trainService.predict(image);
        } catch (MalformedModelException e) {
            throw new RuntimeException(e);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (TranslateException e) {
            throw new RuntimeException(e);
        }
//        返回结果的分类
    }
}
package com.alatus.djl.service.impl;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import com.alatus.djl.Models;
import com.alatus.djl.service.TrainService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

@Service
@Slf4j
public class TrainServiceImpl implements TrainService {
    @Override
    public void train(String datasetPath, String modelName, String modelPath) throws IOException, TranslateException {
//        获取数据集
        ImageFolder imageFolder = initDataset(Paths.get(datasetPath));
//        分隔数据集
//        这可以把图片数据集做切分,80%的数据用于训练,20%的数据用于测试
        RandomAccessDataset[] randomAccessDatasets = imageFolder.randomSplit(8, 2);
//        训练集
        RandomAccessDataset trainDatasets = randomAccessDatasets[0];
//        验证集
        RandomAccessDataset validationDatasets = randomAccessDatasets[1];
//        定义模型(对多层神经网络进行封装,封装数学函数)
        try(Model model = Models.getModel();){
//            模型训练器,模型训练配置
//            首先准备好我们的训练配置,然后获取训练器
//            训练配置需要我们传递损失函数
            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .addEvaluator(new Accuracy())//计算精度
                    .addTrainingListeners(TrainingListener.Defaults.logging());//训练监听器,我们这里想要训练日志输出
//            获取训练器
            Trainer trainer = model.newTrainer(config);
            trainer.setMetrics(new Metrics());
//            初始化使用一张图片,三种RGB色彩,100*100像素(先启动)
            trainer.initialize(new Shape(1, 3, 100, 100));
//            训练,使用快速训练
//            使用刚刚生成的训练器进行拟合训练
            EasyTrain.fit(trainer, 2, trainDatasets, validationDatasets);
//            训练完成保存模型
            model.save(Paths.get(modelPath), modelName);
            List synset = imageFolder.getSynset();
//            获取模型保存路径
            Path modelDir = Paths.get(modelPath);
//            直接在同一个目录下创建一个synset.txt文件
            Path resolve = modelDir.resolve("synset.txt");
            try(BufferedWriter bufferedWriter = Files.newBufferedWriter(resolve)){
//                把可迭代集合内容转为一个以换行符分隔的字符串,然后写入文件
                bufferedWriter.write(String.join("\n", synset));
            }
        };
        log.info("训练完成");
    }

    @Override
    public String predict(MultipartFile image) throws MalformedModelException, IOException, TranslateException {
        Image predictImage = ImageFactory.getInstance().fromInputStream(image.getInputStream());
//        拿到模型,算法模型只是一个公式,但是模型公式中每一个变量最终应该是多少都是训练得到的
        try(Model model = Models.getModel()){
//            这里的模型加载的本质就是我们训练得到的那一大堆的参数,所以加载模型就是把参数加载到内存中
//            再结合模型和公式,模型加载完成之后,模型就可以开始预测了
            model.load(Paths.get("build/models"));
//            模型是什么?
//            封装的计算函数和算法公式加上训练得到的一大堆参数数据,合在一起就是我们所称的模型
//            未训练的模型就是封装起来的算法公式和计算规则
//            训练好的模型就是公式和训练得到的一大堆参数数据,但是它本身没有推测作用,需要结合算法公式,才能进行预测

//            预测器里面需要一个转换器,来把需要预测的数据转换为模型认识可以理解的数据
            ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder()
                    .addTransform(new Resize(100, 100))//调整尺寸
                    .optApplySoftmax(true)
                    .addTransform(new ToTensor())//将图像N维数组从预处理的格式转换为神经网络格式的变换
                    .build();
//            模型需要一个预测器
            try(Predictor predictor = model.newPredictor(classificationTranslator);){
                Classifications predict = predictor.predict(predictImage);
                return predict.toString();
            }
        }
    }

    private ImageFolder initDataset(Path path) throws IOException, TranslateException {
//        当数据集完成加载以后,数据集就知道我们的图片要分成几类了,因为它是根据文件夹加载的
        ImageFolder folder = ImageFolder.builder().setRepositoryPath(path)//指定数据集的位置
                .setSampling(128, true)
                .optMaxDepth(10)
                .addTransform(new Resize(100, 100))
                .addTransform(new ToTensor())//把图片转为向量
                .build();
        folder.prepare(new ProgressBar());
        List synset = folder.getSynset();
        for (String s : synset) {
            log.info(s);
        }
//        之前的数字分类案例我们保存了一个synset的文件,这个文件就是用来保存图片的类别的枚举值
//        顺手保存一下syset文件也可以
        return folder;
    }
}

 

package com.alatus.djl.service.impl;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import com.alatus.djl.Models;
import com.alatus.djl.service.TrainService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

@Service
@Slf4j
public class TrainServiceImpl implements TrainService {
    @Override
    public void train(String datasetPath, String modelName, String modelPath) throws IOException, TranslateException {
//        获取数据集
        ImageFolder imageFolder = initDataset(Paths.get(datasetPath));
//        分隔数据集
//        这可以把图片数据集做切分,80%的数据用于训练,20%的数据用于测试
        RandomAccessDataset[] randomAccessDatasets = imageFolder.randomSplit(8, 2);
//        训练集
        RandomAccessDataset trainDatasets = randomAccessDatasets[0];
//        验证集
        RandomAccessDataset validationDatasets = randomAccessDatasets[1];
//        定义模型(对多层神经网络进行封装,封装数学函数)
        try(Model model = Models.getModel();){
//            模型训练器,模型训练配置
//            首先准备好我们的训练配置,然后获取训练器
//            训练配置需要我们传递损失函数
            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .addEvaluator(new Accuracy())//计算精度
                    .addTrainingListeners(TrainingListener.Defaults.logging());//训练监听器,我们这里想要训练日志输出
//            获取训练器
            Trainer trainer = model.newTrainer(config);
            trainer.setMetrics(new Metrics());
//            初始化使用一张图片,三种RGB色彩,100*100像素(先启动)
            trainer.initialize(new Shape(1, 3, 100, 100));
//            训练,使用快速训练
//            使用刚刚生成的训练器进行拟合训练
            EasyTrain.fit(trainer, 2, trainDatasets, validationDatasets);
//            训练完成保存模型
            model.save(Paths.get(modelPath), modelName);
            List synset = imageFolder.getSynset();
//            获取模型保存路径
            Path modelDir = Paths.get(modelPath);
//            直接在同一个目录下创建一个synset.txt文件
            Path resolve = modelDir.resolve("synset.txt");
            try(BufferedWriter bufferedWriter = Files.newBufferedWriter(resolve)){
//                把可迭代集合内容转为一个以换行符分隔的字符串,然后写入文件
                bufferedWriter.write(String.join("\n", synset));
            }
        };
        log.info("训练完成");
    }

    @Override
    public String predict(MultipartFile image) throws MalformedModelException, IOException, TranslateException {
        Image predictImage = ImageFactory.getInstance().fromInputStream(image.getInputStream());
//        拿到模型,算法模型只是一个公式,但是模型公式中每一个变量最终应该是多少都是训练得到的
        try(Model model = Models.getModel()){
//            这里的模型加载的本质就是我们训练得到的那一大堆的参数,所以加载模型就是把参数加载到内存中
//            再结合模型和公式,模型加载完成之后,模型就可以开始预测了
            model.load(Paths.get("build/models"));
//            模型是什么?
//            封装的计算函数和算法公式加上训练得到的一大堆参数数据,合在一起就是我们所称的模型
//            未训练的模型就是封装起来的算法公式和计算规则
//            训练好的模型就是公式和训练得到的一大堆参数数据,但是它本身没有推测作用,需要结合算法公式,才能进行预测

//            预测器里面需要一个转换器,来把需要预测的数据转换为模型认识可以理解的数据
            ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder()
                    .addTransform(new Resize(100, 100))//调整尺寸
                    .optApplySoftmax(true)
                    .addTransform(new ToTensor())//将图像N维数组从预处理的格式转换为神经网络格式的变换
                    .build();
//            模型需要一个预测器
            try(Predictor predictor = model.newPredictor(classificationTranslator);){
                Classifications predict = predictor.predict(predictImage);
                return predict.toString();
            }
        }
    }

    private ImageFolder initDataset(Path path) throws IOException, TranslateException {
//        当数据集完成加载以后,数据集就知道我们的图片要分成几类了,因为它是根据文件夹加载的
        ImageFolder folder = ImageFolder.builder().setRepositoryPath(path)//指定数据集的位置
                .setSampling(128, true)
                .optMaxDepth(10)
                .addTransform(new Resize(100, 100))
                .addTransform(new ToTensor())//把图片转为向量
                .build();
        folder.prepare(new ProgressBar());
        List synset = folder.getSynset();
        for (String s : synset) {
            log.info(s);
        }
//        之前的数字分类案例我们保存了一个synset的文件,这个文件就是用来保存图片的类别的枚举值
//        顺手保存一下syset文件也可以
        return folder;
    }
}
package com.alatus.djl.service;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;

/**
 * @author: Alatus
 * @create: 2025-02-12 17:48
 * @description: DJL训练模型
 **/
public interface TrainService {
    /**
     *
     * @数据集的位置 datasetPath
     * @模型名字 modelName
     * @模型存放路径 modelPath
     */
    void train(String datasetPath,String modelName,String modelPath) throws IOException, TranslateException;
    /**
     *
     * @图片 图片
     */
    String predict(MultipartFile image) throws MalformedModelException, IOException, TranslateException;
}

 

package com.alatus.djl.service;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;

/**
 * @author: Alatus
 * @create: 2025-02-12 17:48
 * @description: DJL训练模型
 **/
public interface TrainService {
    /**
     *
     * @数据集的位置 datasetPath
     * @模型名字 modelName
     * @模型存放路径 modelPath
     */
    void train(String datasetPath,String modelName,String modelPath) throws IOException, TranslateException;
    /**
     *
     * @图片 图片
     */
    String predict(MultipartFile image) throws MalformedModelException, IOException, TranslateException;
}


    4.0.0
    
        org.springframework.boot
        spring-boot-starter-parent
        3.4.2
         
    
    com.alatus
    DJL
    0.0.1-SNAPSHOT
    DJL
    DJL
    
    
        
    
    
        
    
    
        
        
        
        
    
    
        17
        0.26.0
    
    
        
            
                ai.djl
                bom
                ${djl.version}
                pom
                import
            
        
    
    

        
            org.springframework.boot
            spring-boot-starter-web
        

        
            org.projectlombok
            lombok
            true
        

        
            ai.djl
            api
        

        
            ai.djl
            basicdataset
        

        
            ai.djl
            model-zoo
        

        
            ai.djl.opencv
            opencv
        
        
            ai.djl.mxnet
            mxnet-engine
        

        
            ai.djl.pytorch
            pytorch-engine
            0.26.0
        
        
            ai.djl.pytorch
            pytorch-jni
            2.1.1-0.26.0
        

        
            ai.djl.onnxruntime
            onnxruntime-engine
            0.32.0
            runtime
            
                
                    com.microsoft.onnxruntime
                    onnxruntime
                
            
        
        
            com.microsoft.onnxruntime
            onnxruntime_gpu
            1.18.0
            runtime
        








        
            ai.djl.pytorch
            pytorch-native-cpu
            win-x86_64
            2.1.1
        
        
            ai.djl.mxnet
            mxnet-native-auto
            1.8.0
        
    
    
        
            
                org.springframework.boot
                spring-boot-maven-plugin
                
                    
                        
                            org.projectlombok
                            lombok
                        
                    
                
            
        
    

 



    4.0.0
    
        org.springframework.boot
        spring-boot-starter-parent
        3.4.2
         
    
    com.alatus
    DJL
    0.0.1-SNAPSHOT
    DJL
    DJL
    
    
        
    
    
        
    
    
        
        
        
        
    
    
        17
        0.26.0
    
    
        
            
                ai.djl
                bom
                ${djl.version}
                pom
                import
            
        
    
    

        
            org.springframework.boot
            spring-boot-starter-web
        

        
            org.projectlombok
            lombok
            true
        

        
            ai.djl
            api
        

        
            ai.djl
            basicdataset
        

        
            ai.djl
            model-zoo
        

        
            ai.djl.opencv
            opencv
        
        
            ai.djl.mxnet
            mxnet-engine
        

        
            ai.djl.pytorch
            pytorch-engine
            0.26.0
        
        
            ai.djl.pytorch
            pytorch-jni
            2.1.1-0.26.0
        

        
            ai.djl.onnxruntime
            onnxruntime-engine
            0.32.0
            runtime
            
                
                    com.microsoft.onnxruntime
                    onnxruntime
                
            
        
        
            com.microsoft.onnxruntime
            onnxruntime_gpu
            1.18.0
            runtime
        








        
            ai.djl.pytorch
            pytorch-native-cpu
            win-x86_64
            2.1.1
        
        
            ai.djl.mxnet
            mxnet-native-auto
            1.8.0
        
    
    
        
            
                org.springframework.boot
                spring-boot-maven-plugin
                
                    
                        
                            org.projectlombok
                            lombok
                        
                    
                
            
        
    

package com.alatus.djl;

import ai.djl.Model;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.SequentialBlock;

public class Models {
    private static final String modelName = "footWeaver";
    public static Model getModel() {
//        创建模型
        Model model = Model.newInstance(modelName);
//        设置神经网络,定义神经网络
//        残差网络,是一种CNN模型,专门用来解决图像识别问题
//        RGB有三种值,然后100和100是图像尺寸
//        这里表示的就是100*100的像素,每个像素都由RBG三种值构成
//        万物的世界,就是各种神经网络,因此我们需要把所有的内容都表示为某种章量向量

//        当我们知道每一个网络都是起到什么作用的时候,剩下就是我们的参数调整或者说调优即可
//        应用层只是套用算法,开发算法不在这里
//        机器学习应用层就是套现成的模型,然后调参,底层才需要开发算法
        SequentialBlock sequentialBlock = ResNetV1.builder()
                .setImageShape(new Shape(1, 3, 100, 100))
                .setNumLayers(50)
                .setOutSize(4)//分类数量
                .build();
//        以后训练就是这个模型,模型里面用的就是咱们这个神经网络(残差网络)
        model.setBlock(sequentialBlock);
        return model;
    }
}

 

package com.alatus.djl;

import ai.djl.Model;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.SequentialBlock;

public class Models {
    private static final String modelName = "footWeaver";
    public static Model getModel() {
//        创建模型
        Model model = Model.newInstance(modelName);
//        设置神经网络,定义神经网络
//        残差网络,是一种CNN模型,专门用来解决图像识别问题
//        RGB有三种值,然后100和100是图像尺寸
//        这里表示的就是100*100的像素,每个像素都由RBG三种值构成
//        万物的世界,就是各种神经网络,因此我们需要把所有的内容都表示为某种章量向量

//        当我们知道每一个网络都是起到什么作用的时候,剩下就是我们的参数调整或者说调优即可
//        应用层只是套用算法,开发算法不在这里
//        机器学习应用层就是套现成的模型,然后调参,底层才需要开发算法
        SequentialBlock sequentialBlock = ResNetV1.builder()
                .setImageShape(new Shape(1, 3, 100, 100))
                .setNumLayers(50)
                .setOutSize(4)//分类数量
                .build();
//        以后训练就是这个模型,模型里面用的就是咱们这个神经网络(残差网络)
        model.setBlock(sequentialBlock);
        return model;
    }
}
package com.alatus.djl.service.impl;

import com.alatus.djl.service.InterferenceService;
import org.springframework.stereotype.Service;

@Service
public class InterferenceServiceImpl implements InterferenceService {
}

 

package com.alatus.djl.service.impl;

import com.alatus.djl.service.InterferenceService;
import org.springframework.stereotype.Service;

@Service
public class InterferenceServiceImpl implements InterferenceService {
}
package com.alatus.djl.app;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * author: Alatus
 * date: 2025/2/2
 * email: [email protected]
 * description:使用DJL训练大模型
 */
@RestController
@Slf4j
public class DJL {
//    测试模型
    @GetMapping("/predict")
    public String predict() throws IOException, MalformedModelException, TranslateException {
//        搞个图片先,准备测试数据
        Image image = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");

//        加载模型
        Path path = Paths.get("build/mlp");
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[]{128,64}));
        model.load(path);

//        预测(给模型一个新的输入,让它来判断我们的输入内容)

//        获取一个转换器
        ImageClassificationTranslator build = ImageClassificationTranslator.builder()
                .addTransform(new ToTensor())//转换器
                .addTransform(new Resize(28, 28))//设置图片尺寸
                .build();
//        获取预测器
        Predictor predictor = model.newPredictor(build);
//        预测图片分类
        Classifications predict = predictor.predict(image);
        log.info(predict.toString());
        return predict.toString();
    }

    @GetMapping("/predictPic")
    public String predictImage() throws IOException, MalformedModelException, TranslateException {
        InputStream imageStream = getClass().getClassLoader().getResourceAsStream("static/3.png");
        Image image = null;
        if (imageStream == null) {
            // 处理图片没有找到的情况
            System.out.println("Image not found!");
        } else {
            image = ImageFactory.getInstance().fromInputStream(imageStream);
        }
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[]{128,64}));
        model.load(Paths.get("build/mlp"));
        ImageClassificationTranslator build = ImageClassificationTranslator.builder().addTransform(new Resize(28, 28))
                .addTransform(new CenterCrop())//中心裁剪
                .addTransform(new ToTensor())
                .build();
        Predictor predictor = model.newPredictor(build);
        Classifications predict = predictor.predict(image);
        log.info(predict.toString());
        return predict.toString();
    }

    @GetMapping("/fullModel")
    public String fullModel() throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
//        完全训练一个模型
//        准备数据集(这里我们用的是官方自带的),用自己的数据集就自定义DataSet
        RandomAccessDataset trainDataset = getDataset(Dataset.Usage.TRAIN);
        RandomAccessDataset validationDataset = getDataset(Dataset.Usage.TEST);

//        自定义数据集的例子,这里的set方法的都是必须要填的
//        ImageFolder build = ImageFolder.builder()
//                .addTransform()//添加转换器
//                .optImageSize()//设置图片尺寸大小
//                .optImageWidth()
//                .optImageHeight()
//                .setSampling(64,true)//设置采样信息,一次64张图片,随机采样
//                .setRepositoryPath()
//                .build();//设置数据集的存储路径
//        build.getData()//获取数据集就得到了

//        构建神经网络,这边我们采用直接用多层感知机的方式,而不是Block块的方式
//        因为这个案例是自带的,我们直接使用Mnist的参数即可
//        这里的神经网路MLP本身就是一个Block块:public class Mlp extends SequentialBlock
        Mlp mlp = new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[]{128,64});
//        这里我们的层数如果太多了,他的精度会比较高,但也会导致一些过拟合的情况存在
//        过拟合就是过度敏感了,图片里面一点点接近的信息他都认为是它正确的
//        比如说照片里面有一个鸟,但是鸟的图片和某一个他认识的东西近似,他就认为是它认为的那个东西然后会认为它是正确的
//        这就叫过拟合
//        欠拟合就是训练量太少了,他认不出来,啥东西他都认为是自行车

        //        构建模型(模型应用上面的神经网络)
        try(Model model = Model.newInstance("mlp")){
//            设置神经网络
            model.setBlock(mlp);
//            接下来训练这个神经网络,配置了损失函数,精度,训练监听器
            String output = "build/mlp";

//            训练配置信息
            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .addEvaluator(new Accuracy())
                    .addTrainingListeners(TrainingListener.Defaults.logging(output));


            //        训练配置(训练集)
//            基于给到的训练配置信息开始训练
            try(Trainer trainer = model.newTrainer(config)){
//                查看训练期间的详细指标数据
                trainer.setMetrics(new Metrics());
//                初始化训练器
                trainer.initialize(new Shape(1, Mnist.IMAGE_HEIGHT, Mnist.IMAGE_WIDTH));
//                接下来做拟合(拟合也就是训练)
//                trainer.fit(trainDataset, 10);
//                这里我们用EasyTrain来训练
//                训练5次
                EasyTrain.fit(trainer,17, trainDataset, validationDataset);
                TrainingResult result = trainer.getTrainingResult();
                log.info("训练结果:"+result.toString());

                //        保存模型
                model.save(Paths.get(output),"mlp");
                return "模型训练完成,并保存成功";
            }
        }
    }

    private RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException {
        Mnist build = Mnist.builder().setSampling(64, true)//设置采样信息
                .optUsage(usage)
                .optLimit(64)
                .build();
//        弄个进度条
        build.prepare(new ProgressBar());
        return build;
    }


    //    机器学习最基本的,我们要把我们需要处理的数据给转为一个N维向量
//    只有转为一个向量了,才能继续向下处理
    @GetMapping("/test01")
    public String test01() {
        try(NDManager manager = NDManager.newBaseManager()){
//            我们通过这个manager创建向量
//            这里的Shape就是N维数组的形状
//            我们这里创建的是一个2乘以3的矩阵(N维向量)
//            这里的ones指的是内容都是1填充的
//            输出的1.代表这是一个1的float值
            NDArray ones = manager.ones(new Shape(2, 3));
            log.info(ones.toString());
//            这里,我们同样可以自己创建一个矩阵
//            通过创建对应的数组和给予我们需要的形状来创建一个矩阵
            NDArray array = manager.create(new float[]{1.14F, 5.14F, 1.9F, 1.9F, 8.10F, 1.14f}, new Shape(2, 3));
            log.info(array.toString());
//            矩阵计算
//            如矩阵转质,这里我们的矩阵二乘三的矩阵就变成了三乘二
            NDArray transpose = array.transpose();
            log.info(transpose.toString());
            return "矩阵"+ones+"和"+array+"的转置为"+transpose;
        }
    }
//    我们这里的这些矩阵你可以模拟为从数据集中加载的
//    数据集是用于训练机器学习模型的数据集合
//    机器学习通常使用三个数据集,训练集,验证集和测试集

//    训练集是我们用来训练的实际数据集,模型从这些数据中学习权重和参数

//    验证集用来在训练过程中评估给定模型,它帮助机器学习工程师在模型开发阶段微调超参数
//    模型不从验证数据集学习,验证数据集是可选的

//    测试数据集提供了用于评估模型性能的黄金标准,它只在模型完全训练完成后使用
//    测试数据集应该更准确的评估模型将如何在新数据上执行

//    当我们有了数据集以后
//    数据集加载为N维向量,我们需要通过Translator来转换数据集

    @GetMapping("/test02")
    public String test02() {
//        输入的图片像素
        long inputSize = 28 * 28;
//        输出的图片类型
        long outputSize = 10;
//        整一个批量扁平块,把二维图像输入转为一维特征向量
        SequentialBlock block = new SequentialBlock();
//        添加扁平块
        block.add(Blocks.batchFlattenBlock(inputSize));
//        添加一个隐藏层,线性变化大小为128
        block.add(Linear.builder().setUnits(128).build());
//        添加相应的激活函数
        block.add(Activation::relu);

//        第二个隐藏层的激活函数,这一层是大小为64的变化
        block.add(Linear.builder().setUnits(64).build());
        block.add(Activation::relu);

//        我试着添加一个32的隐藏层激活函数
        block.add(Linear.builder().setUnits(32).build());
        block.add(Activation::relu);

//        最后输出10大小的特征向量
        block.add(Linear.builder().setUnits(outputSize).build());

//        这些大小是在实验过程中选择的
//        围绕块,可以构建我们的模型,添加一些重要的元数据,如可以在训练和推理时使用的超参数
        Model model = Model.newInstance("mlp");
        model.setBlock(block);
//        现在就拥有了块和模型了,剩下的就是如何进行训练的部分
        return "构建块和模型";
    }

//    因此本质上,我们的模型训练就是,我们构建一个函数量,再经由我们创建的这个模型
//    模型内部使用的就是我们配置的激活函数
//    一层一层训练,直到最后精度损失控制到一定程度,停止训练
//    然后我们就可以使用这个模型进行预测了
//    模型的工作原理就是,由一个N维数组经过训练和优化,变成一个N-1维数组(另外一个N维数组)
//    也可以说,所谓的模型就是对我们的Block量的一个封装

//    也就是我们自己配置和封装对应的算法和输入输出,得到我们需要的模型,接下来就是使用这个模型进行训练,让它的精度损失控制到我们需要的水平
//    再经过正向传播,反向传播的多轮训练,直到精度损失控制到我们想要的水平

//    除此以外就是需要把我们的数据或者说训练内容进行转换,使用Translator,得到对应的N维数组
    @GetMapping("/test03")
    public String test03() {
//        这里我们使用Translator进行数据预处理的主要目的是,解决训练数据的格式不一致的问题
//        毕竟输入的数据集和模型训练的维度不一定完全一致,所以需要将我们的数据集进行预处理,使其符合训练的维度
//        比如说这里我们如果不是28*28的图像,但是模型训练的维度是28*28,所以需要将我们的数据集进行预处理,使其符合训练的维度
        ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder().addTransform(new CenterCrop())//中心裁剪
                .addTransform(new Resize(28, 28))//调整尺寸
                .addTransform(new ToTensor())//将图像N维数组从预处理的格式转换为神经网络格式的变换
                .build();
        return "参数预处理";
    }

//    最终达到的效果就是,把输入数据,不管是图片视频文字语言还是任何其他的东西转为机器能识别的数据量或者说数字量
//    更直接就是N维数组,将输出的内容转为人类可以识别的内容,和对应的概率
//    数字化,数字化,任何东西严格意义上都可以用数字来表示,只要你有办法去表示,那么你就可以用数字来表示

    @GetMapping("/test04")
    public String test04() throws TranslateException, ModelNotFoundException, MalformedModelException, IOException {
        ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder().addTransform(new CenterCrop())//中心裁剪
                .addTransform(new Resize(28, 28))//调整尺寸
                .addTransform(new ToTensor())//将图像N维数组从预处理的格式转换为神经网络格式的变换
                .build();
        Criteria criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .optFilter("layer","50")
                .optTranslator(classificationTranslator)
                .optProgress(new ProgressBar())
                .build();
        ZooModel model = ModelZoo.loadModel(criteria);
        log.info(model.getName());
        Predictor predictor = model.newPredictor();
//        这里我们传递一个图片给他去做预测
//        假装有这么一个图片
        Classifications predict = predictor.predict(null);
        return "使用模型";
    }
}

 

package com.alatus.djl.app;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * author: Alatus
 * date: 2025/2/2
 * email: [email protected]
 * description:使用DJL训练大模型
 */
@RestController
@Slf4j
public class DJL {
//    测试模型
    @GetMapping("/predict")
    public String predict() throws IOException, MalformedModelException, TranslateException {
//        搞个图片先,准备测试数据
        Image image = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");

//        加载模型
        Path path = Paths.get("build/mlp");
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[]{128,64}));
        model.load(path);

//        预测(给模型一个新的输入,让它来判断我们的输入内容)

//        获取一个转换器
        ImageClassificationTranslator build = ImageClassificationTranslator.builder()
                .addTransform(new ToTensor())//转换器
                .addTransform(new Resize(28, 28))//设置图片尺寸
                .build();
//        获取预测器
        Predictor predictor = model.newPredictor(build);
//        预测图片分类
        Classifications predict = predictor.predict(image);
        log.info(predict.toString());
        return predict.toString();
    }

    @GetMapping("/predictPic")
    public String predictImage() throws IOException, MalformedModelException, TranslateException {
        InputStream imageStream = getClass().getClassLoader().getResourceAsStream("static/3.png");
        Image image = null;
        if (imageStream == null) {
            // 处理图片没有找到的情况
            System.out.println("Image not found!");
        } else {
            image = ImageFactory.getInstance().fromInputStream(imageStream);
        }
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[]{128,64}));
        model.load(Paths.get("build/mlp"));
        ImageClassificationTranslator build = ImageClassificationTranslator.builder().addTransform(new Resize(28, 28))
                .addTransform(new CenterCrop())//中心裁剪
                .addTransform(new ToTensor())
                .build();
        Predictor predictor = model.newPredictor(build);
        Classifications predict = predictor.predict(image);
        log.info(predict.toString());
        return predict.toString();
    }

    @GetMapping("/fullModel")
    public String fullModel() throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
//        完全训练一个模型
//        准备数据集(这里我们用的是官方自带的),用自己的数据集就自定义DataSet
        RandomAccessDataset trainDataset = getDataset(Dataset.Usage.TRAIN);
        RandomAccessDataset validationDataset = getDataset(Dataset.Usage.TEST);

//        自定义数据集的例子,这里的set方法的都是必须要填的
//        ImageFolder build = ImageFolder.builder()
//                .addTransform()//添加转换器
//                .optImageSize()//设置图片尺寸大小
//                .optImageWidth()
//                .optImageHeight()
//                .setSampling(64,true)//设置采样信息,一次64张图片,随机采样
//                .setRepositoryPath()
//                .build();//设置数据集的存储路径
//        build.getData()//获取数据集就得到了

//        构建神经网络,这边我们采用直接用多层感知机的方式,而不是Block块的方式
//        因为这个案例是自带的,我们直接使用Mnist的参数即可
//        这里的神经网路MLP本身就是一个Block块:public class Mlp extends SequentialBlock
        Mlp mlp = new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[]{128,64});
//        这里我们的层数如果太多了,他的精度会比较高,但也会导致一些过拟合的情况存在
//        过拟合就是过度敏感了,图片里面一点点接近的信息他都认为是它正确的
//        比如说照片里面有一个鸟,但是鸟的图片和某一个他认识的东西近似,他就认为是它认为的那个东西然后会认为它是正确的
//        这就叫过拟合
//        欠拟合就是训练量太少了,他认不出来,啥东西他都认为是自行车

        //        构建模型(模型应用上面的神经网络)
        try(Model model = Model.newInstance("mlp")){
//            设置神经网络
            model.setBlock(mlp);
//            接下来训练这个神经网络,配置了损失函数,精度,训练监听器
            String output = "build/mlp";

//            训练配置信息
            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .addEvaluator(new Accuracy())
                    .addTrainingListeners(TrainingListener.Defaults.logging(output));


            //        训练配置(训练集)
//            基于给到的训练配置信息开始训练
            try(Trainer trainer = model.newTrainer(config)){
//                查看训练期间的详细指标数据
                trainer.setMetrics(new Metrics());
//                初始化训练器
                trainer.initialize(new Shape(1, Mnist.IMAGE_HEIGHT, Mnist.IMAGE_WIDTH));
//                接下来做拟合(拟合也就是训练)
//                trainer.fit(trainDataset, 10);
//                这里我们用EasyTrain来训练
//                训练5次
                EasyTrain.fit(trainer,17, trainDataset, validationDataset);
                TrainingResult result = trainer.getTrainingResult();
                log.info("训练结果:"+result.toString());

                //        保存模型
                model.save(Paths.get(output),"mlp");
                return "模型训练完成,并保存成功";
            }
        }
    }

    private RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException {
        Mnist build = Mnist.builder().setSampling(64, true)//设置采样信息
                .optUsage(usage)
                .optLimit(64)
                .build();
//        弄个进度条
        build.prepare(new ProgressBar());
        return build;
    }


    //    机器学习最基本的,我们要把我们需要处理的数据给转为一个N维向量
//    只有转为一个向量了,才能继续向下处理
    @GetMapping("/test01")
    public String test01() {
        try(NDManager manager = NDManager.newBaseManager()){
//            我们通过这个manager创建向量
//            这里的Shape就是N维数组的形状
//            我们这里创建的是一个2乘以3的矩阵(N维向量)
//            这里的ones指的是内容都是1填充的
//            输出的1.代表这是一个1的float值
            NDArray ones = manager.ones(new Shape(2, 3));
            log.info(ones.toString());
//            这里,我们同样可以自己创建一个矩阵
//            通过创建对应的数组和给予我们需要的形状来创建一个矩阵
            NDArray array = manager.create(new float[]{1.14F, 5.14F, 1.9F, 1.9F, 8.10F, 1.14f}, new Shape(2, 3));
            log.info(array.toString());
//            矩阵计算
//            如矩阵转质,这里我们的矩阵二乘三的矩阵就变成了三乘二
            NDArray transpose = array.transpose();
            log.info(transpose.toString());
            return "矩阵"+ones+"和"+array+"的转置为"+transpose;
        }
    }
//    我们这里的这些矩阵你可以模拟为从数据集中加载的
//    数据集是用于训练机器学习模型的数据集合
//    机器学习通常使用三个数据集,训练集,验证集和测试集

//    训练集是我们用来训练的实际数据集,模型从这些数据中学习权重和参数

//    验证集用来在训练过程中评估给定模型,它帮助机器学习工程师在模型开发阶段微调超参数
//    模型不从验证数据集学习,验证数据集是可选的

//    测试数据集提供了用于评估模型性能的黄金标准,它只在模型完全训练完成后使用
//    测试数据集应该更准确的评估模型将如何在新数据上执行

//    当我们有了数据集以后
//    数据集加载为N维向量,我们需要通过Translator来转换数据集

    @GetMapping("/test02")
    public String test02() {
//        输入的图片像素
        long inputSize = 28 * 28;
//        输出的图片类型
        long outputSize = 10;
//        整一个批量扁平块,把二维图像输入转为一维特征向量
        SequentialBlock block = new SequentialBlock();
//        添加扁平块
        block.add(Blocks.batchFlattenBlock(inputSize));
//        添加一个隐藏层,线性变化大小为128
        block.add(Linear.builder().setUnits(128).build());
//        添加相应的激活函数
        block.add(Activation::relu);

//        第二个隐藏层的激活函数,这一层是大小为64的变化
        block.add(Linear.builder().setUnits(64).build());
        block.add(Activation::relu);

//        我试着添加一个32的隐藏层激活函数
        block.add(Linear.builder().setUnits(32).build());
        block.add(Activation::relu);

//        最后输出10大小的特征向量
        block.add(Linear.builder().setUnits(outputSize).build());

//        这些大小是在实验过程中选择的
//        围绕块,可以构建我们的模型,添加一些重要的元数据,如可以在训练和推理时使用的超参数
        Model model = Model.newInstance("mlp");
        model.setBlock(block);
//        现在就拥有了块和模型了,剩下的就是如何进行训练的部分
        return "构建块和模型";
    }

//    因此本质上,我们的模型训练就是,我们构建一个函数量,再经由我们创建的这个模型
//    模型内部使用的就是我们配置的激活函数
//    一层一层训练,直到最后精度损失控制到一定程度,停止训练
//    然后我们就可以使用这个模型进行预测了
//    模型的工作原理就是,由一个N维数组经过训练和优化,变成一个N-1维数组(另外一个N维数组)
//    也可以说,所谓的模型就是对我们的Block量的一个封装

//    也就是我们自己配置和封装对应的算法和输入输出,得到我们需要的模型,接下来就是使用这个模型进行训练,让它的精度损失控制到我们需要的水平
//    再经过正向传播,反向传播的多轮训练,直到精度损失控制到我们想要的水平

//    除此以外就是需要把我们的数据或者说训练内容进行转换,使用Translator,得到对应的N维数组
    @GetMapping("/test03")
    public String test03() {
//        这里我们使用Translator进行数据预处理的主要目的是,解决训练数据的格式不一致的问题
//        毕竟输入的数据集和模型训练的维度不一定完全一致,所以需要将我们的数据集进行预处理,使其符合训练的维度
//        比如说这里我们如果不是28*28的图像,但是模型训练的维度是28*28,所以需要将我们的数据集进行预处理,使其符合训练的维度
        ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder().addTransform(new CenterCrop())//中心裁剪
                .addTransform(new Resize(28, 28))//调整尺寸
                .addTransform(new ToTensor())//将图像N维数组从预处理的格式转换为神经网络格式的变换
                .build();
        return "参数预处理";
    }

//    最终达到的效果就是,把输入数据,不管是图片视频文字语言还是任何其他的东西转为机器能识别的数据量或者说数字量
//    更直接就是N维数组,将输出的内容转为人类可以识别的内容,和对应的概率
//    数字化,数字化,任何东西严格意义上都可以用数字来表示,只要你有办法去表示,那么你就可以用数字来表示

    @GetMapping("/test04")
    public String test04() throws TranslateException, ModelNotFoundException, MalformedModelException, IOException {
        ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder().addTransform(new CenterCrop())//中心裁剪
                .addTransform(new Resize(28, 28))//调整尺寸
                .addTransform(new ToTensor())//将图像N维数组从预处理的格式转换为神经网络格式的变换
                .build();
        Criteria criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .optFilter("layer","50")
                .optTranslator(classificationTranslator)
                .optProgress(new ProgressBar())
                .build();
        ZooModel model = ModelZoo.loadModel(criteria);
        log.info(model.getName());
        Predictor predictor = model.newPredictor();
//        这里我们传递一个图片给他去做预测
//        假装有这么一个图片
        Classifications predict = predictor.predict(null);
        return "使用模型";
    }
}

你可能感兴趣的:(#,AI,#,Spring-Boot框架,spring,boot,微服务,spring,cloud,后端,mybatis,stable,diffusion,chatgpt)