一个案例搞定策略模式

提到设计模式,只要是有过开发经验的开发人员都或多或少听过&用过设计模式,比如我们都能信手拈来的「单例模式」、「观察者模式」等等。当然也有我们平时不太常用,但众多优秀的开源框架中广泛使用的设计模式,例如著名的网络框架retrofit使用的「代理模式」、okhttp使用的「责任链模式」。

关于设计模式的文章,网上一搜一大堆,各位前辈都总结得非常好。可以说前人的技术分享大大降低了后人的学习门槛,使中国互联网整体技术水平成指数上升,感谢每一位热爱分享的Coder!

很早以前我就准备写一篇介绍策略模式的文章,但是始终没有一个较好的例子。在最近的项目中,我再次用到了策略模式,于是我决定将其作为本文讲解策略模式的案例。

本文会先直接通过实际案例的形式逐步带入策略模式,最后再给出策略模式的完整定义,这样更容易理解。

为了不偏离主题,提升阅读体验,本文所有代码都经过精简处理。

一、案例前戏

公司参加了某人工智能比赛,AI部门的同事使用TensorFlow训练了一个能根据呼吸音推测患上“肺气肿”概率的模型。需要在Android设备上通过APP + 听诊器完成呼吸音的采集,然后通过模型给出结论。

由于时间紧迫,不知同事从哪里搞来了一个半成品项目,该项目已经实现了呼吸音的采集功能,将音频保存为.wav文件。而我需要做的,就是将采集到的呼吸音交给模型,得出结论。

而TensorFlow模型在Android上是没办法直接使用的,必须要将TensorFlow模型转换为TensorFlow Lite才可以在Android上使用(反正又不是我转o( ̄︶ ̄)o)。

二、案例中期

趁着AI部门的同事还没有将TensorFlow转为TensorFlow Lite之前,我去TensorFlow Lite官网看了看使用文档,一切都是如此美妙,仅需3步就可以搞定一切:

  1. 加载并初始化模型文件
  2. 调用run方法传入inputObject和outputObject
  3. 得到结果

在这个项目中,inputObject其实就是.wav文件的字节数组,由于模型的返回结果是JSON,因此outputObject其实就是String。现在我就放心的摸鱼去了。

摸鱼的时光总是短暂的,不一会儿,同事就给了我TensorFlow Lite的模型文件,后缀是.tflite。按照官方文档的要求,我将模型文件audio.tflite放到了assets目录中,然后编写了如下代码:

private void stopRecord() {
    // 省略其它代码,这里已经获取到了wav录音的File对象
    File wavFile = ...;
    aiCheckLocal(wavFile);
}

private void aiCheckLocal(File file) {
    ByteBuffer buffer = loadModelFile(getAssets());
    // 初始化模型
    Interpreter tfLite = new Interpreter(buffer);
    byte[] inputBytes = FileUtils.fileToByteArray(file);
    String outputStr = "";
    // 调用模型
    tfLite.run(inputBytes, outputStr);
    // 输出模型的结论
    Log.d(TAG, outputStr);
}

private MappedByteBuffer loadModelFile(AssetManager assetManager) {
    // 读取assets目录下tflite模型的代码,无需关心
}

aiCheckLocal方法就是我编写的根据手机采集到的声音数据,利用TensorFlow Lite进行结果推测。大家在看代码的时候,不需要关注具体的细节。

呵呵,不到10分钟就搞定了。一运行,???:

Internal error: Unexpected failure when preparing tensor allocations: Encountered unresolved custom op: Switch.
    Node number 10 (Switch) failed to prepare.

最终查明这个错是因为AI部门同事给我的模型有问题,加载不了。短时间内他们也没办法解决,于是他们提出了让我调用接口完成音频数据的推测,也就是将这个过程变为了在线而不是本地模型推测了。同时,还需要我保留本地模型推测的代码,以便后期他们修复模型问题后,还是可以切换为本地模型推测。

这有什么难的,再加个在线推测的方法不就行了:

private void stopRecord() {
    // 省略其它代码,这里已经获取到了wav录音的File对象
    File wavFile = ...;
    aiCheckOnline(wavFile);
}

private void aiCheckOnline(File file) {
    // 通过网络推测代码
    OkHttpClient client = new OkHttpClient();
    MediaType mediaType = MediaType.Companion.parse("multipart/form-data");
    RequestBody fileBody = RequestBody.Companion.create(file, mediaType);

    RequestBody requestBody = new MultipartBody.Builder()
        .setType(MultipartBody.FORM)
        .addFormDataPart("sound", file.getName(), fileBody)
        .build();

    Request request = new Request.Builder()
        .url(REQUEST_RUL)
        .post(requestBody)
        .build();
    
    client.newCall(request).enqueue(new Callback() {
        @Override
        public void onFailure(@NotNull Call call, @NotNull IOException e) {
        }

        @Override
        public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
            String jsonStr = response.body().string();
            Log.d(TAG, jsonStr);
        }
    });
}

private void aiCheckLocal(File file) {
    // 本地tflite推测代码 省略
}

我又新增了一个aiCheckOnline的方法用于在线推测。这样,当使用在线推测的时候,调用aiCheckOnline,使用本地模型推测的时候,调用aiCheckLocal即可。但是我们回过头来想想,功能虽然是实现了,但这样真的好吗?现在的代码有如下问题:

  • 所有的实现都放在了同一个类中,导致这个类臃肿
  • 对于调用者来说,对于同一个行为(音频推测)的调用需要知道实现细节(本地 or 网络 or other)
  • 如果多处都需要使用这个推测功能,那么将来切换推测方式(如网络换为本地)的时候,需要修改多处代码

也许你觉得这些都不是问题,因为你已经有解决方案了:对于第一个问题,我们把这些方法写在一个单独的类中不就行了;对于第二个问题,我们在类中对外界提供一个统一个调用方法,在这个方法内部进行判断到底是需要何种实现方案;对于第三个问题,由于我们提供了统一的调用方法,因此这个问题也就不存在了。所以,我们的代码可以改成这样:

public class AiCheck {
    /**
     * 执行推测时的方式
     */
    private int type;

    /**
     * 方式1 本地模型推测
     */
    public static final int TYPE_LOCAL = 1;

    /**
     * 方式二 网络接口推测
     */
    public static final int TYPE_NETWORK = 2;

    public AiCheck(int type) {
        this.type = type;
    }

    public void check(File file) {
        // 根据方式,执行对应的方法
        if (type == TYPE_LOCAL) {
            aiCheckLocal(file);
        } else if (type == TYPE_NETWORK) {
            aiCheckOnline(file);
        }
    }

    private void aiCheckOnline(File file) {
        // 通过网络推测代码
        OkHttpClient client = new OkHttpClient();
        MediaType mediaType = MediaType.Companion.parse("multipart/form-data");
        RequestBody fileBody = RequestBody.Companion.create(file, mediaType);

        RequestBody requestBody = new MultipartBody.Builder()
                .setType(MultipartBody.FORM)
                .addFormDataPart("sound", file.getName(), fileBody)
                .build();

        Request request = new Request.Builder()
                .url(REQUEST_RUL)
                .post(requestBody)
                .build();

        client.newCall(request).enqueue(new Callback() {
            @Override
            public void onFailure(@NotNull Call call, @NotNull IOException e) {
            }

            @Override
            public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
                String jsonStr = response.body().string();
                Log.d(TAG, jsonStr);
            }
        });
    }

    private void aiCheckLocal(File file) {
        ByteBuffer buffer = loadModelFile(getAssets());
        // 初始化模型
        Interpreter tfLite = new Interpreter(buffer);
        byte[] inputBytes = FileUtils.fileToByteArray(file);
        String outputStr = "";
        // 调用模型
        tfLite.run(inputBytes, outputStr);
        // 输出模型的结论
        Log.d(TAG, outputStr);
    }

    private MappedByteBuffer loadModelFile(AssetManager assetManager) {
        // 读取assets目录下tflite模型的代码,无需关心
    }
}

这样,我们的调用就可以变为这样:

private AiCheck aiCheck = new AiCheck(AiCheck.TYPE_NETWORK);

private void stopRecord() {
    // 省略其它代码,这里已经获取到了wav录音的File对象
    File wavFile = ...;
    aiCheck.check(wavFile);
}

调用就变得如此清爽了。但是这样真的就没有任何问题了吗?其实还是有的:

  • 无论我们要修改本地推测还是在线推测,我们都需要在AiCheck这个类中直接修改
  • 假如我们要新增另一种方式,我们还是只能在AiCheck中新增,并且还要提供对应的type类型,还要增加一次if判断
  • 如果日后针对音频推测,变种出10种方式,那么AiCheck类中就会有10个方法、10个type,10个if判断与之对应,才能满足需要

这...怎么办?

三、案例高潮

到了这个时间点,必须要放出大招了——策略模式。我们知道设计模式分为了三类:创建型、行为型、结构型,而策略模式属于行为型。先不谈其定义,我们来看看如何使用策略模式改进当前的问题。

首先,对于我们要实现的这个功能,行为只有一个,那就是音频推测,因此我们可以将这个行为抽象成一个接口:

/**
 * 行为接口
 */
public interface IAudioCheckBehavior {
    void check(File wavFile);
}

接着,在当前的情况下,针对这个行为,我们需要两种实现方案:本地模型推测 和 在线推测。我们分别写两个类来实现这两种方案:

  • 本地模型推测方案:

    public class LocalAudioCheck implements IAudioCheckBehavior {
      @Override
        public void check(File wavFile) {
          ByteBuffer buffer = loadModelFile(getAssets());
            // 初始化模型
            Interpreter tfLite = new Interpreter(buffer);
            byte[] inputBytes = FileUtils.fileToByteArray(file);
            String outputStr = "";
            // 调用模型
            tfLite.run(inputBytes, outputStr);
            // 输出模型的结论
            Log.d(TAG, outputStr);
        }
        
        private MappedByteBuffer loadModelFile(AssetManager assetManager) {
          // 读取assets目录下tflite模型的代码,无需关心
      }
    }
    
  • 在线接口推测方案:

    public class OnlineAudioCheck implements IAudioCheckBehavior {
      @Override
        public void check(File wavFile) {
            // 通过网络推测代码
            OkHttpClient client = new OkHttpClient();
            MediaType mediaType = MediaType.Companion.parse("multipart/form-data");
            RequestBody fileBody = RequestBody.Companion.create(file, mediaType);
    
            RequestBody requestBody = new MultipartBody.Builder()
                    .setType(MultipartBody.FORM)
                    .addFormDataPart("sound", file.getName(), fileBody)
                    .build();
    
            Request request = new Request.Builder()
                    .url(REQUEST_RUL)
                    .post(requestBody)
                    .build();
    
            client.newCall(request).enqueue(new Callback() {
                @Override
                public void onFailure(@NotNull Call call, @NotNull IOException e) {
                }
    
                @Override
                public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
                    String jsonStr = response.body().string();
                    Log.d(TAG, jsonStr);
                }
            });
        }
    }
    

这两种方案都实现了IAudioCheckBehavior接口,并各自用自己的方式实现了接口中的check方法。如果后期要增加新的实现方案呢?我们可以再定义一个类,并实现IAudioCheckBehavior接口就行了,不会在原有类中去添加或修改代码。

最后,我们还需要定义一个Context类,这个类的作用是设置一个具体的策略供外界使用:

public class AudioCheckContext {
    private IAudioCheckBehavior behavior;

    public AudioCheckContext() {
        this.behavior = new OnlineAudioCheck();
    }

    public void setAudioCheckBehavior(IAudioCheckBehavior behavior) {
        this.behavior = behavior;
    }

    public void check(File wavFile) {
        behavior.check(wavFile);
    }
}

注意,这个behaviorIAudioCheckBehavior类型的,可赋值为它的任意子类,比如OnlineAudioCheck或是LocalAudioCheck,这其实就是多态的思想。如果外界没有明确指定,则behavior的默认实现是OnlineAudioCheck,当然外界也可手动指定behavior的实现类。

没错,这就是策略模式在本案例中的完整实现了,下面用起来就很爽了:

private AudioCheckContext audioCheckContext = new AudioCheckContext();

private void stopRecord() {
    // 省略其它代码,这里已经获取到了wav录音的File对象
    File wavFile = ...;
    audioCheckContext.check(wavFile);
}

首先创建了AudioCheckContext对象,然后调用了AudioCheckContextcheck方法,在check方法内部,会去调用IAudioCheckBehaviorcheck方法,由于我们没有在外界设置IAudioCheckBehavior,因此它的实现类默认是OnlineAudioCheck,转而就会去走在线推测音频的逻辑。

真爽,如果哪天AI部门的同事把模型搞定了,要求我改为本地模型推测,我只需新增或修改一行代码:

  • 方式一:新增一行代码:

    private AudioCheckContext audioCheckContext = new AudioCheckContext();
    
    private void stopRecord() {
        // 省略其它代码,这里已经获取到了wav录音的File对象
        File wavFile = ...;
        // 新增的代码在这里
        audioCheckContext.setAudioCheckBehavior(new LocalAudioCheck());
        audioCheckContext.check(wavFile);
    }
    
  • 方式二:修改一行代码:

    public class AudioCheckContext {
        private IAudioCheckBehavior behavior;
    
        public AudioCheckContext() {
          // 修改的代码在这里,原来是OnlineAudioCheck,现在是LocalAudioCheck
            this.behavior = new LocalAudioCheck();
        }
    
        public void setAudioCheckBehavior(IAudioCheckBehavior behavior) {
            this.behavior = behavior;
        }
    
        public void check(File wavFile) {
            behavior.check(wavFile);
        }
    }
    

如果我发现本地推测的实现代码有问题,直接去对应的LocalAudioCheck类修复就好了,不会影响到其它功能。

四、策略模式的定义

好了,现在是时候来看看策略模式的定义了:

定义了算法族,分别封装起来,让它们之间可以互相替换,此模式让算法的变化独立于使用算法的客户

有了上面的案例,理解这个定义就容易多了。

总结一下策略模式的核心思想:将变化的部分独立出来,将它们单独实现成算法类,并且这些算法是可以相互替换且对调用者隐藏实现细节。

你可能感兴趣的:(一个案例搞定策略模式)