TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11

原文:Mobile Deep Learning with TensorFlow Lite, ML Kit and Flutter

协议:CC BY-NC-SA 4.0

译者:飞龙

本文来自【ApacheCN 深度学习 译文集】,采用译后编辑(MTPE)流程来尽可能提升效率。

不要担心自己的形象,只关心如何实现目标。——《原则》,生活原则 2.3.c

六、构建人工智能认证系统

认证是任何应用中最突出的功能之一,无论它是本机移动软件还是网站,并且自从保护数据的需求以及与机密有关的隐私需求开始以来,认证一直是一个活跃的领域。 在互联网上共享的数据。 在本章中,我们将从基于 Firebase 的简单登录到应用开始,然后逐步改进以包括基于人工智能(AI)的认证置信度指标和 Google 的 ReCaptcha。 所有这些认证方法均以深度学习为核心,并提供了一种在移动应用中实现安全性的最新方法。

在本章中,我们将介绍以下主题:

  • 一个简单的登录应用
  • 添加 Firebase 认证
  • 了解用于认证的异常检测
  • 用于认证用户的自定义模型
  • 实现 ReCaptcha 来避免垃圾邮件
  • 在 Flutter 中部署模型

技术要求

对于移动应用,需要具有 Flutter 的 Visual Studio Code 和 Dart 插件以及 Firebase Console

GitHub 网址。

一个简单的登录应用

我们将首先创建一个简单的认证应用,该应用使用 Firebase 认证对用户进行认证,然后再允许他们进入主屏幕。 该应用将允许用户输入其电子邮件和密码来创建一个帐户,然后使他们随后可以使用此电子邮件和密码登录。

以下屏幕快照显示了应用的完整流程:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第1张图片

该应用的小部件树如下:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第2张图片

现在让我们详细讨论每个小部件的实现。

创建 UI

让我们从创建应用的登录屏幕开始。 用户界面UI)将包含两个TextFormField来获取用户的电子邮件 ID 和密码,RaisedButton进行注册/登录,以及FlatButton进行注册和登录操作之间的切换。

以下屏幕快照标记了将用于应用的第一个屏幕的小部件:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第3张图片

现在让我们创建应用的 UI,如下所示:

  1. 我们首先创建一个名为signup_signin_screen.dart的新 dart 文件。 该文件包含一个有状态的小部件– SignupSigninScreen
  2. 第一个屏幕中最上面的窗口小部件是TextField,用于获取用户的邮件 ID。 _createUserMailInput()方法可帮助我们构建窗口小部件:
 Widget _createUserMailInput() {
  return Padding(
     padding: const EdgeInsets.fromLTRB(0.0, 100.0, 0.0, 0.0),
     child: new TextFormField(
       maxLines: 1,
       keyboardType: TextInputType.emailAddress,
       autofocus: false,
       decoration: new InputDecoration(
           hintText: 'Email',
           icon: new Icon(
             Icons.mail,
             color: Colors.grey,
           )),
       validator: (value) => value.isEmpty ? 'Email can\'t be empty' : null,
       onSaved: (value) => _usermail = value.trim(),
     ),
   );
 }

首先,我们使用EdgeInsets.fromLTRB()为小部件提供了填充。 这有助于我们在四个基本方向的每个方向(即左,上,右和下)上创建具有不同值的偏移量。 接下来,我们使用maxLines(输入的最大行数)创建了TextFormField,其值为1作为子级,它接收用户的电子邮件地址。 另外,根据输入类型TextInputType.emailAddress,我们指定了将在属性keyboardType中使用的键盘类型。 然后,将autoFocus设置为false。 然后,我们在装饰属性中使用InputDecoration提供hintText "Email"和图标Icons.mail。 为了确保用户在没有输入电子邮件地址或密码的情况下不要尝试登录,我们添加了一个验证器。 当尝试使用空字段登录时,将显示警告“电子邮件不能为空”。 最后,我们通过使用trim()删除所有尾随空格来修剪输入的值,然后将输入的值存储在_usermail字符串变量中。

  1. 与“步骤 2”中的TextField相似,我们定义了下一个方法_createPasswordInput(),以创建用于输入密码的TextFormField()
Widget _createPasswordInput() {
   return Padding(
     padding: const EdgeInsets.fromLTRB(0.0, 15.0, 0.0, 0.0),
     child: new TextFormField(
       maxLines: 1,
       obscureText: true,
       autofocus: false,
       decoration: new InputDecoration(
           hintText: 'Password',
           icon: new Icon(
             Icons.lock,
             color: Colors.grey,
           )),
       validator: (value) => value.isEmpty ? 'Password can\'t be empty' : null,
       onSaved: (value) => _userpassword = value.trim(),
     ),
   );
 }

我们首先使用EdgeInsets.fromLTRB()在所有四个基本方向上提供填充,以在顶部提供15.0的偏移量。 接下来,我们创建一个TextFormField,其中maxLines1,并将obscureText设置为true,将autofocus设置为falseobscureText用于隐藏正在键入的文本。 我们使用InputDecoration提供hintText密码和一个灰色图标Icons.lock。 为确保文本字段不为空,使用了一个验证器,当传递空值时,该警告器会发出警告Password can't be empty,即用户尝试在不输入密码的情况下登录/注册。 最后,trim()用于删除所有尾随空格,并将密码存储在_userpassword字符串变量中。

  1. 接下来,我们在_SignupSigninScreenState外部声明FormMode枚举,该枚举在两种模式SIGNINSIGNUP之间运行,如以下代码片段所示:
enum FormMode { SIGNIN, SIGNUP }

我们将对该按钮使用此枚举,该按钮将使用户既可以登录又可以注册。 这将帮助我们轻松地在两种模式之间切换。 枚举是一组用于表示常量值的标识符。

使用enum关键字声明枚举类型。 在enum内部声明的每个标识符都代表一个整数值; 例如,第一标识符具有值0,第二标识符具有值1。 默认情况下,第一个标识符的值为0

  1. 让我们定义一个_createSigninButton()方法,该方法返回按钮小部件以使用户注册并登录:
 Widget _createSigninButton() {
   return new Padding(
       padding: EdgeInsets.fromLTRB(0.0, 45.0, 0.0, 0.0),
       child: SizedBox(
         height: 40.0,
         child: new RaisedButton(
           elevation: 5.0,
           shape: new RoundedRectangleBorder(borderRadius: new BorderRadius.circular(30.0)),
           color: Colors.blue,
           child: _formMode == FormMode.SIGNIN
               ? new Text('SignIn',
                   style: new TextStyle(fontSize: 20.0, color: Colors.white))
               : new Text('Create account',
                   style: new TextStyle(fontSize: 20.0, color: Colors.white)),
           onPressed: _signinSignup,
         ),
       ));
 }

我们从Padding开始,将45.0的按钮offset置于顶部,然后将SizedBox40.0height作为子项,并将RaisedButton作为其子项。 使用RoundedRectangleBorder()为凸起的按钮赋予圆角矩形形状,其边框半径为30.0,颜色为blue。 作为子项添加的按钮的文本取决于_formMode的当前值。 如果_formMode的值(FormMode枚举的一个实例)为FormMode.SIGNIN,则按钮显示SignIn,否则创建帐户。 按下按钮时将调用_signinSignup方法,该方法将在后面的部分中介绍。

  1. 现在,我们将第四个按钮添加到屏幕上,以使用户在SIGNINSIGNUP表单模式之间切换。 我们定义返回FlatButton_createSigninSwitchButton()方法,如下所示:
 Widget _createSigninSwitchButton() {
   return new FlatButton(
     child: _formMode == FormMode.SIGNIN
         ? new Text('Create an account',
             style: new TextStyle(fontSize: 18.0, fontWeight: FontWeight.w300))
         : new Text('Have an account? Sign in',
             style:
                 new TextStyle(fontSize: 18.0, fontWeight: FontWeight.w300)),
     onPressed: _formMode == FormMode.SIGNIN
         ? _switchFormToSignUp
         : _switchFormToSignin,
   );
 }

如果_formMode的当前值为SIGNIN并按下按钮,则应更改为SIGNUP并显示Create an account。 否则,如果_formModeSIGNUP作为其当前值,并且按下按钮,则该值应切换为由文本Have an account? Sign in表示的SIGNIN。 使用三元运算符创建RaisedButtonText子级时,添加了在文本之间切换的逻辑。 onPressed属性使用非常相似的逻辑,该逻辑再次检查_formMode的值以在模式之间切换并使用_switchFormToSignUp_switchFormToSignin方法更新_formMode的值。 我们将在“步骤 7”和 8 中定义_switchFormToSignUp_switchFormToSignin方法。

  1. 现在,我们定义_switchFormToSignUp()如下:
 void _switchFormToSignUp() {
   _formKey.currentState.reset();
   setState(() {
     _formMode = FormMode.SIGNUP;
   });
 }

此方法重置_formMode的值并将其更新为FormMode.SIGNUP。 更改setState()内部的值有助于通知框架该对象的内部状态已更改,并且 UI 可能需要更新。

  1. 我们以与_switchFormToSignUp()非常相似的方式定义_switchFormToSignin()
 void _switchFormToSignin() {
   _formKey.currentState.reset();
   setState(() {
     _formMode = FormMode.SIGNIN;
   });
 }

此方法重置_formMode的值并将其更新为FormMode.SIGNIN。 更改setState()内部的值有助于通知框架该对象的内部状态已更改,并且 UI 可能需要更新。

  1. 现在,让我们将所有屏幕小部件Email TextFieldPassword TextFiedSignIn ButtonFlatButton切换为在单个容器中进行注册和登录。 为此,我们定义了一种方法createBody(),如下所示:
 Widget _createBody(){
   return new Container(
       padding: EdgeInsets.all(16.0),
       child: new Form(
         key: _formKey,
         child: new ListView(
           shrinkWrap: true,
           children: <Widget>[
             _createUserMailInput(),
             _createPasswordInput(),
             _createSigninButton(),
             _createSigninSwitchButton(),
             _createErrorMessage(),
           ],
         ),
       )
    );
 }

此方法返回一个以Form作为子元素的新Container并为其填充16.0。 表单使用_formKey作为其键,并添加ListView作为其子级。 ListView的元素是我们在前述方法中创建的用于添加TextFormFieldsButtons的小部件。 shrinkWrap设置为true,以确保ListView仅占用必要的空间,并且不会尝试扩展和填充整个屏幕

Form类用于将多个FormFields一起分组和验证。 在这里,我们使用Form将两个TextFormFields,一个RaisedButton和一个FlatButton包装在一起。

  1. 这里要注意的一件事是,由于进行认证,因此用户最终将成为网络操作,因此可能需要一些时间来发出网络请求。 在此处添加进度条可防止在进行网络操作时 UI 的死锁。 我们声明boolean标志_loading,当网络操作开始时将其设置为true。 现在,我们定义一种_createCircularProgress()方法,如下所示:
 Widget _createCircularProgress(){
   if (_loading) {
     return Center(child: CircularProgressIndicator());
   } return Container(height: 0.0, width: 0.0,);
 }

仅当_loadingtrue并且正在进行网络操作时,该方法才返回CircularProgressIndicator()

  1. 最后,让我们在build()方法内添加所有小部件:
 @override
 Widget build(BuildContext context) {
   return new Scaffold(
       appBar: new AppBar(
         title: new Text('Firebase Authentication'),
       ),
       body: Stack(
         children: <Widget>[
           _createBody(),
           _createCircularProgress(),
         ],
       ));
 }

build()内部,添加包含应用标题的AppBar变量后,我们返回一个支架。 支架的主体包含一个带有子项的栈,这些子项是_createBody()_createCircularProgress() 函数调用返回的小部件。

现在,我们已经准备好应用的主要 UI 结构。

可以在这个页面中找到SignupSigninScreen的完整代码。

在下一部分中,我们将介绍将 Firebase 认证添加到应用中涉及的步骤。

添加 Firebase 认证

如前所述,在“简单登录应用”部分中,我们将使用用户的电子邮件和密码通过 Firebase 集成认证。

要在 Firebase 控制台上创建和配置 Firebase 项目,请参考“附录”。

以下步骤详细讨论了如何在 Firebase Console 上设置项目:

  1. 我们首先在 Firebase 控制台上选择项目:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第4张图片

  1. 接下来,我们将在Develop菜单中单击Authentication选项:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第5张图片

这将带我们进入认证屏幕。

  1. 迁移到登录标签并启用登录提供者下的“电子邮件/密码”选项:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第6张图片

这是设置 Firebase 控制台所需的全部。

接下来,我们将 Firebase 集成到代码中。 这样做如下:

  1. 迁移到 Flutter SDK 中的项目,然后将firebase-auth添加到应用级别build.gradle文件中:
implementation 'com.google.firebase:firebase-auth:18.1.0'
  1. 为了使FirebaseAuthentication在应用中正常工作,我们将在此处使用firebase_auth插件。 在pubspec.yaml文件的依赖项中添加插件依赖项:
firebase_auth: 0.14.0+4

确保运行flutter pub get以安装依赖项。

现在,让我们编写一些代码以在应用内部提供 Firebase 认证功能。

创建auth.dart

现在,我们将创建一个 Dart 文件auth.dart。 该文件将作为访问firebase_auth插件提供的认证方法的集中点:

  1. 首先,导入firebase_auth插件:
import 'package:firebase_auth/firebase_auth.dart';
  1. 现在,创建一个抽象类BaseAuth,该类列出了所有认证方法,并充当 UI 组件和认证方法之间的中间层:
abstract class BaseAuth {
 Future<String> signIn(String email, String password);
 Future<String> signUp(String email, String password);
 Future<String> getCurrentUser();
 Future<void> signOut();
}

顾名思义,这些方法将使用认证的四个主要函数:

  • signIn():使用电子邮件和密码登录已经存在的用户
  • signUp():使用电子邮件和密码为新用户创建帐户
  • getCurrentUser():获取当前登录的用户
  • signOut():注销已登录的用户

这里要注意的重要一件事是,由于这是网络操作,因此所有方法都异步操作,并在执行完成后返回Future值。

  1. 创建一个实现BaseAuthAuth类:
class Auth implements BaseAuth {
    //. . . . . 
}

在接下来的步骤中,我们将定义BaseAuth中声明的所有方法。

  1. 创建FirebaseAuth的实例:
final FirebaseAuth _firebaseAuth = FirebaseAuth.instance;
  1. signIn()方法实现如下:
 Future<String> signIn(String email, String password) async {
     AuthResult result = await _firebaseAuth.signInWithEmailAndPassword(email: email, password: password);
    FirebaseUser user = result.user;
    return user.uid;
}

此方法接收用户的电子邮件和密码,然后调用signInWithEmailAndPassword(),并传递电子邮件和密码以登录已经存在的用户。 登录操作完成后,将返回AuthResult实例。 我们将其存储在result中,还使用result.user,它返回FirebaseUser.。它可用于获取与用户有关的信息,例如他们的uidphoneNumberphotoUrl。 在这里,我们返回user.uid,它是每个现有用户的唯一标识。 如前所述,由于这是网络操作,因此它异步运行,并在执行完成后返回Future

  1. 接下来,我们将定义signUp()方法以添加新用户:
Future<String> signUp(String email, String password) async {
    AuthResult result = await _firebaseAuth.createUserWithEmailAndPassword(email: email, password: password);
    FirebaseUser user = result.user;
    return user.uid;
 }

前面的方法接收在注册过程中使用的电子邮件和密码,并将其值传递给createUserWithEmailAndPassword。 类似于上一步中定义的对象,此调用还返回AuthResult对象,该对象还用于提取FirebaseUser。 最后,signUp方法返回新创建的用户的uid

  1. 现在,我们将定义getCurrentUser()
 Future<String> getCurrentUser() async {
   FirebaseUser user = await _firebaseAuth.currentUser();
   return user.uid;
 }

在先前定义的函数中,我们使用_firebaseAuth.currentUser()提取当前登录用户的信息。 此方法返回包装在FirebaseUser对象中的完整信息。 我们将其存储在user变量中。 最后,我们使用user.uid返回用户的uid

  1. 接下来,我们执行signOut()
Future<void> signOut() async {
   return _firebaseAuth.signOut();
 }

此函数仅在当前FirebaseAuth实例上调用signOut()并注销已登录的用户。

至此,我们已经完成了用于实现 Firebase 认证的所有基本编码。

可以在这个页面中查看auth.dart中的整个代码。

现在让我们看看如何在应用内部使认证生效。

SignupSigninScreen中添加认证

在本节中,我们将在SignupSigninScreen中添加 Firebase 认证。

我们在signup_signin_screen.dart文件中定义了_signinSignup()方法。 当按下登录按钮时,将调用该方法。 该方法的主体如下所示:

 void _signinSignup() async {
   setState(() {
     _loading = true;
   });
     String userId = "";     
       if (_formMode == FormMode.SIGNIN) {
         userId = await widget.auth.signIn(_usermail, _userpassword);
       } else {
         userId = await widget.auth.signUp(_usermail, _userpassword);
       }
       setState(() {
         _loading = false;
       });
       if (userId.length > 0 && userId != null && _formMode == FormMode.SIGNIN) {
         widget.onSignedIn();
       }
}

在上述方法中,我们首先将_loading的值设置为true,以便进度条显示在屏幕上,直到登录过程完成。 接下来,我们创建一个userId字符串,一旦登录/登录操作完成,该字符串将存储userId的值。 现在,我们检查_formMode的当前值。 如果等于FormMode.SIGNIN,则用户希望登录到现有帐户。 因此,我们使用传递到SignupSigninScreen构造器中的实例来调用Auth类内部定义的signIn()方法。

这将在后面的部分中详细讨论。 否则,如果_formMode的值等于FormMode.SIGNUP,则将调用Auth类的signUp()方法,并传递用户的邮件和密码以创建新帐户。 一旦成功完成登录/注册,userId变量将用于存储用户的 ID。 整个过程完成后,将_loading设置为false,以从屏幕上删除循环进度指示器。 另外,如果在用户登录到现有帐户时userId具有有效值,则将调用onSignedIn(),这会将用户定向到应用的主屏幕。

此方法也传递给SignupSigninScreen的构造器,并将在后面的部分中进行讨论。 最后,我们将整个主体包裹在try-catch块中,以便在登录过程中发生的任何异常都可以捕获而不会导致应用崩溃,并可以在屏幕上显示。

创建主屏幕

我们还需要确定认证状态,即用户在启动应用时是否已登录,如果已经登录,则将其定向到主屏幕。如果尚未登录,则应显示SignInSignupScreen 首先,在完成该过程之后,将启动主屏幕。 为了实现这一点,我们在新的 dart 文件main_screen.dart中创建一个有状态的小部件MainScreen,然后执行以下步骤:

  1. 我们将从定义枚举AuthStatus开始,该枚举表示用户的当前认证状态,可以登录或不登录:
enum AuthStatus {
 NOT_SIGNED_IN,
 SIGNED_IN,
}
  1. 现在,我们创建enum类型的变量来存储当前认证状态,其初始值设置为NOT_SIGNED_IN
AuthStatus authStatus = AuthStatus.NOT_SIGNED_IN;
  1. 初始化小部件后,我们将通过覆盖initState()方法来确定用户是否已登录:
 @override
 void initState() {
   super.initState();
   widget.auth.getCurrentUser().then((user) {
     setState(() {
       if (user != null) {
         _userId = user;
       }
       authStatus =
           user == null ? AuthStatus.NOT_SIGNED_IN : AuthStatus.SIGNED_IN;
     });
   });
 }

使用在构造器中传递的类的实例调用Auth类的getCurrentUser()。 如果该方法返回的值不为null,则意味着用户已经登录。因此,_userId字符串变量的值设置为返回的值。 另外,将authStatus设置为AuthStatus.SIGNED_IN.,否则,如果返回的值为null,则意味着没有用户登录,因此authStatus的值设置为AuthStatus.NOT_SIGNED_IN

  1. 现在,我们将定义另外两个方法onSignIn()onSignOut(),以确保将认证状态正确存储在变量中,并相应地更新用户界面:
void _onSignedIn() {
   widget.auth.getCurrentUser().then((user){
     setState(() {
       _userId = user;
     });
   });
   setState(() {
     authStatus = AuthStatus.SIGNED_IN;
   });
 }
 void _onSignedOut() {
   setState(() {
     authStatus = AuthStatus.NOT_SIGNED_IN;
     _userId = "";
   });
 }

_onSignedIn()方法检查用户是否已经登录,并将authStatus设置为AuthStatus.SIGNED_IN._onSignedOut()方法检查用户是否已注销,并将authStatus设置为AuthStatus.SIGNED_OUT

  1. 最后,我们重写build方法将用户定向到正确的屏幕:
 @override
 Widget build(BuildContext context) {
   if(authStatus == AuthStatus.SIGNED_OUT) {
     return new SignupSigninScreen(
       auth: widget.auth,
       onSignedIn: _onSignedIn,
     );
   } else {
     return new HomeScreen(
       userId: _userId,
       auth: widget.auth,
       onSignedOut: _onSignedOut,
       );
   }
 }

如果authStatusAuthStatus.SIGNED_OUT,则返回SignupSigninScreen,并传递auth实例和_onSignedIn()方法。 否则,将直接返回HomeScreen,并传递已登录用户的userIdAuth实例类和_onSignedOut()方法。

可以在此处查看main_screen.dart的完整代码。

在下一部分中,我们将为应用添加一个非常简单的主屏幕。

创建主屏幕

由于我们对认证部分更感兴趣,因此主屏幕(即成功登录后指向用户的屏幕)应该非常简单。 它仅包含一些文本和一个注销选项。 正如我们对所有先前的屏幕和小部件所做的一样,我们首先创建一个home_screen.dart文件和一个有状态的HomeScreen小部件。

主屏幕将显示如下:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第7张图片

此处的完整代码位于重写的build()方法内部:

 @override
 Widget build(BuildContext context) {
     return new Scaffold(
         appBar: new AppBar(
             title: new Text('Firebase Authentication'),
             actions: <Widget>[
                 new FlatButton(
                     child: new Text('Logout',
                     style: new TextStyle(fontSize: 16.0, color: Colors.white)),
                     onPressed: _signOut
                 )
             ],
         ),
         body: Center(child: new Text('Hello User', 
         style: new TextStyle(fontSize: 32.0))
         ),
     );
 }

我们在此处返回Scaffold,其中包含标题为Text Firebase AuthenticationAppBaractions属性的小部件列表。 actions用于在应用标题旁边添加小部件列表到应用栏中。 在这里,它仅包含FlatButtonLogout,在按下时将调用_signOut

_signOut()方法显示如下:

 _signOut() async {
   try {
     await widget.auth.signOut();
     widget.onSignedOut();
   } catch (e) {
     print(e);
   }
 }

该方法主要是调用Auth类中定义的signOut()方法,以将用户从应用中注销。 回忆传入HomeScreenMainScreen_onSignedOut()方法。 当用户退出时,该方法在此处用作widget.onSignedOut()来将authStatus更改为SIGNED_OUT。 同样,它包装在try-catch块中,以捕获并打印此处可能发生的任何异常。

可以在此处查看home_screen.dart的整个代码。

至此,应用的主要组件已经准备就绪,现在让我们创建最终的材质应用。

创建main.dart

main.dart内部,我们创建Stateless WidgetApp,并覆盖build()方法,如下所示:

 @override
 Widget build(BuildContext context) {
   return new MaterialApp(
       title: 'Firebase Authentication',
       debugShowCheckedModeBanner: false,
       theme: new ThemeData(
         primarySwatch: Colors.blue,
       ),
       home: new MainScreen(auth: new Auth()));
 }

该方法从主屏幕返回MaterialApp,以提供标题,主题。

可以在此处查看main.dart文件。

了解用于认证的异常检测

异常检测是机器学习的一个备受关注的分支。 该术语含义简单。 基本上,它是用于检测异常的方法的集合。 想象一袋苹果。 识别并挑选坏苹果将是异常检测的行为。

异常检测以几种方式执行:

  • 通过使用列的最小最大范围来识别数据集中与其余样本非常不同的数据样本
  • 通过将数据绘制为线形图并识别图中的突然尖峰
  • 通过围绕高斯曲线绘制数据并将最末端的点标记为离群值(异常)

一些常用的方法是支持向量机,贝叶斯网络和 K 最近邻。 在本节中,我们将重点介绍与安全性相关的异常检测。

假设您通常在家中登录应用上的帐户。 如果您突然从数千英里外的位置登录帐户,或者在另一种情况下,您以前从未使用过公共计算机登录帐户,那将是非常可疑的,但是突然有一天您这样做。 另一个可疑的情况可能是您尝试 10-20 次密码,每次在成功成功登录之前每次都输入错误密码。 当您的帐户遭到盗用时,所有这些情况都是可能的行为。 因此,重要的是要合并一个能够确定您的常规行为并对异常行为进行分类的系统。 换句话说,即使黑客使用了正确的密码,企图破坏您的帐户的尝试也应标记为异常。

这带给我们一个有趣的观点,即确定用户的常规行为。 我们如何做到这一点? 什么是正常行为? 它是针对每个用户的还是一个通用概念? 问题的答案是它是非常特定于用户的。 但是,行为的某些方面对于所有用户而言都可以相同。 一个应用可能会在多个屏幕上启动登录。 单个用户可能更喜欢其中一种或两种方法。 这将导致特定于该用户的特定于用户的行为。 但是,如果尝试从未由开发人员标记为登录屏幕的屏幕进行登录,则无论是哪个用户尝试登录,都肯定是异常的。

在我们的应用中,我们将集成一个这样的系统。 为此,我们将记录一段时间内我们应用的许多用户进行的所有登录尝试。 我们将特别注意他们尝试登录的屏幕以及它们传递给系统的数据类型。 一旦收集了很多这些样本,就可以根据用户执行的任何操作来确定系统对认证的信心。 如果系统在任何时候认为用户表现出的行为与他们的惯常行为相差很大,则该用户将未经认证并被要求验证其帐户详细信息。

让我们从创建预测模型开始,以确定用户认证是常规的还是异常的。

用于认证用户的自定义模型

我们将本节分为两个主要子节:

  • 构建用于认证有效性检查的模型
  • 托管自定义认证验证模型

让我们从第一部分开始。

构建用于认证有效性检查的模型

在本部分中,我们将构建模型来确定是否有任何用户正在执行常规登录或异常登录:

  1. 我们首先导入必要的模块,如下所示:
import sys
import os
import json
import pandas
import numpy
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
from keras.layers.embeddings import Embedding
from keras.preprocessing import sequence
from keras.preprocessing.text import Tokenizer
from collections import OrderedDict
  1. 现在,我们将数据集导入到项目中。 可以在这里中找到该数据集:
csv_file = 'data.csv'

dataframe = pandas.read_csv(csv_file, engine='python', quotechar='|', header=None)
count_frame = dataframe.groupby([1]).count()
print(count_frame)
total_req = count_frame[0][0] + count_frame[0][1]
num_malicious = count_frame[0][1]

print("Malicious request logs in dataset: {:0.2f}%".format(float(num_malicious) / total_req * 100))

前面的代码块将 CSV 数据集加载到项目中。 它还会打印一些与数据有关的统计信息,如下所示:

  1. 我们在上一步中加载的数据目前尚无法使用,无法进行深度学习。 在此步骤中,我们将其分为特征列和标签列,如下所示:
X = dataset[:,0]
Y = dataset[:,1]
  1. 接下来,我们将删除数据集中包含的某些列,因为我们不需要所有这些列来构建简单的模型:
for index, item in enumerate(X):
    reqJson = json.loads(item, object_pairs_hook=OrderedDict)
    del reqJson['timestamp']
    del reqJson['headers']
    del reqJson['source']
    del reqJson['route']
    del reqJson['responsePayload']
    X[index] = json.dumps(reqJson, separators=(',', ':'))
  1. 接下来,我们将在剩余的请求正文上执行分词。 分词是一种用于将大文本块分解为较小文本的方法,例如将段落分成句子,将句子分成单词。 我们这样做如下:
tokenizer = Tokenizer(filters='\t\n', char_level=True)
tokenizer.fit_on_texts(X)
  1. 分词之后,我们将请求正文中的文本转换为单词向量,如下一步所示。 我们将数据集和DataFrame标签分为两部分,即 75%-25%,以进行训练和测试:
num_words = len(tokenizer.word_index)+1
X = tokenizer.texts_to_sequences(X)

max_log_length = 1024
train_size = int(len(dataset) * .75)

X_processed = sequence.pad_sequences(X, maxlen=max_log_length)
X_train, X_test = X_processed[0:train_size], X_processed[train_size:len(X_processed)]
Y_train, Y_test = Y[0:train_size], Y[train_size:len(Y)]
  1. 接下来,我们基于长短期记忆LSTM)创建基于循环神经网络RNN)的学习方法,来识别常规用户行为。 将单词嵌入添加到层中,以帮助维持单词向量和单词之间的关系:
model = Sequential()
model.add(Embedding(num_words, 32, input_length=max_log_length))
model.add(Dropout(0.5))
model.add(LSTM(64, recurrent_dropout=0.5))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

我们的输出是单个神经元,在正常登录的情况下,该神经元保存0;在登录异常的情况下,则保存1

  1. 现在,我们以精度作为度量标准编译模型,而损失则作为二进制交叉熵来计算:
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print(model.summary())
  1. 现在,我们准备进行模型的训练:
model.fit(X_train, Y_train, validation_split=0.25, epochs=3, batch_size=128)
  1. 我们将快速检查模型所达到的准确率。 当前模型的准确率超过 96%:
score, acc = model.evaluate(X_test, Y_test, verbose=1, batch_size=128)
print("Model Accuracy: {:0.2f}%".format(acc * 100))

下面的屏幕快照显示了前面代码块的输出:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第8张图片

  1. 现在,我们保存模型权重和模型定义。 我们稍后将它们加载到 API 脚本中,以验证用户的认证:
model.save_weights('lstm-weights.h5')
model.save('lstm-model.h5')

现在,我们可以将认证模型作为 API 进行托管,我们将在下一部分中进行演示。

托管自定义认证验证模型

在本节中,我们将创建一个 API,用于在用户向模型提交其登录请求时对其进行认证。 请求标头将被解析为字符串,并且模型将使用它来预测登录是否有效:

  1. 我们首先导入创建 API 服务器所需的模块:
from sklearn.externals import joblib
from flask import Flask, request, jsonify
from string import digits

import sys
import os
import json
import pandas
import numpy
import optparse
from keras.models import Sequential, load_model
from keras.preprocessing import sequence
from keras.preprocessing.text import Tokenizer
from collections import OrderedDict
  1. 现在,我们实例化一个Flask应用对象。 我们还将从上一节“构建用于认证有效性检查的模型”中加载保存的模型定义和模型权重。然后,我们重新编译模型,并使用_make_predict_function( )方法创建其预测方法,如以下步骤所示:
app = Flask(__name__)

model = load_model('lstm-model.h5')
model.load_weights('lstm-weights.h5')
model.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
model._make_predict_function()
  1. 然后,我们创建一个remove_digits()函数,该函数用于从提供给它的输入中去除所有数字。 这将用于在将请求正文文本放入模型之前清除它:
def remove_digits(s: str) -> str:
    remove_digits = str.maketrans('', '', digits)
    res = s.translate(remove_digits)
    return res
  1. 接下来,我们将在 API 服务器中创建/login路由。 该路由由login()方法处理,并响应GETPOST请求方法。 正如我们对训练输入所做的那样,我们删除了请求标头中的非必要部分。 这可以确保模型将对数据进行预测,类似于对其进行训练的数据:
@app.route('/login', methods=['GET, POST'])
def login():
    req = dict(request.headers)
    item = {}
    item["method"] = str(request.method)
    item["query"] = str(request.query_string)
    item["path"] = str(request.path)
    item["statusCode"] = 200
    item["requestPayload"] = []

    ## MORE CODE BELOW THIS LINE

    ## MORE CODE ABOVE THIS LINE

    response = {'result': float(prediction[0][0])}
    return jsonify(response)
  1. 现在,我们将代码添加到login()方法中,该方法将标记请求正文并将其传递给模型以执行有关登录请求有效性的预测,如下所示:
@app.route('/login', methods=['GET, POST'])
def login():
    ...
    ## MORE CODE BELOW THIS LINE
    X = numpy.array([json.dumps(item)])
    log_entry = "store"

    tokenizer = Tokenizer(filters='\t\n', char_level=True)
    tokenizer.fit_on_texts(X)
    seq = tokenizer.texts_to_sequences([log_entry])
    max_log_length = 1024
    log_entry_processed = sequence.pad_sequences(seq, maxlen=max_log_length)

    prediction = model.predict(log_entry_processed)
    ## MORE CODE ABOVE THIS LINE
    ...

最后,应用以 JSON 字符串的形式返回其对用户进行认证的信心。

  1. 最后,我们使用apprun()方法启动服务器脚本:
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000)
  1. 将此文件另存为main.py。 要开始执行服务器,请打开一个新终端并使用以下命令:
python main.py

服务器监听其运行系统的所有传入 IP。 通过在0.0.0.0 IP 上运行它,可以实现这一点。 如果我们希望稍后在基于云的服务器上部署脚本,则需要这样做。 如果不指定0.0.0.0主机,则默认情况下会使它监听127.0.0.1,这不适合在公共服务器上进行部署。 您可以在此处详细了解这些地址之间的区别。

在下一节中,我们将看到如何将 ReCaptcha 集成到迄今为止在该项目中构建的应用中。 之后,我们将把本节中构建的 API 集成到应用中。

实现 ReCaptcha 来保护垃圾邮件

为了为 Firebase 认证增加另一层安全性,我们将使用 ReCaptcha。 这是 Google 所支持的一项测试,可帮助我们保护数据免受垃圾邮件和滥用行为的自动 bot 攻击。 该测试很简单,很容易被人类解决,但是却阻碍了漫游器和恶意用户的使用。

要了解有关 ReCaptcha 及其用途的更多信息,请访问这里。

ReCAPTCHA v2

在本节中,我们将把 ReCaptcha 版本 2 集成到我们的应用中。 在此版本中,向用户显示一个简单的复选框。 如果刻度变为绿色,则表明用户已通过验证。

另外,还可以向用户提出挑战,以区分人和机器人。 这个挑战很容易被人类解决。 他们要做的就是根据说明选择一堆图像。 使用 ReCaptcha 进行认证的传统流程如下所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第9张图片

一旦用户能够验证其身份,他们就可以成功登录。

获取 API 密钥

要在我们的应用内部使用 ReCaptcha,我们需要在reCAPTCHA管理控制台中注册该应用,并获取站点密钥和秘密密钥。 为此,请访问这里并注册该应用。 您将需要导航到“注册新站点”部分,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第10张图片

我们可以通过以下两个简单步骤来获取 API 密钥:

  1. 首先提供一个域名。 在这里,我们将在 reCAPTCHA v2 下选择 reCAPTCHA Android。
  2. 选择 Android 版本后,添加项目的包名称。 正确填写所有信息后,单击“注册”。

这将引导您到显示站点密钥和秘密密钥的屏幕,如以下屏幕快照所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第11张图片

站点密钥秘密密钥复制并保存到安全位置。 我们将在编码应用时使用它们。

代码整合

为了在我们的应用中包含 ReCaptcha v2,我们将使用 Flutter 包flutter_recaptcha_v2。 将flutter_recaptcha_v2:0.1.0依赖项添加到pubspec.yaml文件中,然后在终端中运行flutter packages get以获取所需的依赖项。 以下步骤详细讨论了集成:

  1. 我们将代码添加到signup_signin_screen.dart。 首先导入依赖项:
import 'package:flutter_recaptcha_v2/flutter_recaptcha_v2.dart';
  1. 接下来,创建一个RecaptchaV2Controller实例:
RecaptchaV2Controller recaptchaV2Controller = RecaptchaV2Controller();
  1. reCAPTCHA 复选框将添加为小部件。 首先,让我们定义一个返回小部件的_createRecaptcha()方法:
 Widget _createRecaptcha() {
   return RecaptchaV2(
     apiKey: "Your Site Key here", 
     apiSecret: "Your API Key here", 
     controller: recaptchaV2Controller,
     onVerifiedError: (err){
       print(err);
     },
     onVerifiedSuccessfully: (success) {
       setState(() {
       if (success) {
         _signinSignup();
       } else {
         print('Failed to verify');
       }
       });
     },
   );
 }

在上述方法中,我们仅使用RecaptchaV2()构造器,即可为特定属性指定值。 添加您先前在apiKeyapiSecret属性中注册时保存的站点密钥和秘密密钥。 我们使用先前为属性控制器创建的recaptcha控制器recaptchaV2Controller的实例。 如果成功验证了用户,则将调用_signinSignup()方法以使用户登录。如果在验证期间发生错误,我们将打印错误。

  1. 现在,由于在用户尝试登录时应显示reCaptcha,因此我们将createSigninButton()中的登录凸起按钮的onPressed属性修改为recaptchaV2Controller
Widget _createSigninButton() {
    . . . . . . .
    return new Padding(
        . . . . . . .
        child: new RaisedButton(
            . . . . . . 
            //Modify the onPressed property
            onPressed: recaptchaV2Controller.show
        )
    )
}
  1. 最后,我们将_createRecaptcha()添加到build()内部的主体栈中:
 @override
 Widget build(BuildContext context) {
    . . . . . . .
    return new Scaffold(
        . . . . . . .
        body: Stack(
            children: <Widget>[
                _createBody(),
                _createCircularProgress(),

                //Add reCAPTCHA Widget
                _createRecaptcha()
                 ],
       ));
 }

这就是一切! 现在,我们具有比 Firebase 认证更高的安全级别,可以保护应用的数据免受自动机器人的攻击。 现在让我们看一下如何集成定制模型以检测恶意用户。

在 Flutter 中部署模型

至此,我们的 Firebase 认证应用与 ReCaptcha 保护一起运行。 现在,让我们添加最后的安全层,该层将不允许任何恶意用户进入应用。

我们已经知道该模型位于以下端点。 我们只需从应用内部进行 API 调用,传入用户提供的电子邮件和密码,并从模型中获取结果值。 该值将通过使用阈值结果值来帮助我们判断登录是否是恶意的。

如果该值小于 0.20,则认为该登录名是恶意的,并且屏幕上将显示以下消息:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第12张图片

现在,让我们看一下在 Flutter 应用中部署模型的步骤:

  1. 首先,由于我们正在获取数据并且将使用网络调用(即 HTTP 请求),因此我们需要向pubspec.yaml文件添加http依赖项,并按以下方式导入:
import 'package:http/http.dart' as http;
  1. 首先在auth.dart:内部定义的BaseAuth抽象类中添加以下函数声明
 Future<double> isValidUser(String email, String password);
  1. 现在,让我们在Auth类中定义isValidUser()函数:
 Future<double> isValidUser(String email, String password) async{
   final response = await http.Client()
       .get('http://34.67.160.232:8000/login?user=$email&password=$password');
     var jsonResponse = json.decode(response.body);
     var val = '${jsonResponse["result"]}';
     double result = double.parse(val);     
     return result;
   }

此函数将用户的电子邮件和密码作为参数,并将它们附加到请求 URL,以便为特定用户生成输出。 get request响应存储在变量响应中。 由于响应为 JSON 格式,因此我们使用json.decode()对其进行解码,并将解码后的响应存储在另一个变量响应中。 现在,我们使用‘${jsonResponse["result"]}'访问jsonResponse中的结果值,使用double.parse()将其转换为双精度类型整数,并将其存储在结果中。 最后,我们返回结果的值。

  1. 为了激活代码内部的恶意检测,我们从SigninSignupScreen调用了isValidUser()方法。 当具有现有帐户的用户选择从if-else块内部登录时,将调用此方法:
if (_formMode == FormMode.SIGNIN) {

    var val = await widget.auth.isValidUser(_usermail, _userpassword);

    . . . .
    } else {
      . . . .   
    }

isValidUser返回的值存储在val变量中。

  1. 如果该值小于 0.20,则表明登录活动是恶意的。 因此,我们将异常抛出并在 catch 块内抛出catch并在屏幕上显示错误消息。 这可以通过创建自定义异常类MalicousUserException来完成,该类在实例化时返回一条错误消息:
class MaliciousUserException implements Exception {
  String message() => 'Malicious login! Please try later.';
}
  1. 现在,我们将在调用isValidUser()之后添加if块,以检查是否需要抛出异常:
var val = await widget.auth.isValidUser(_usermail, _userpassword);
//Add the if block 
if(val < 0.20) {
    throw new MaliciousUserException();
}
  1. 现在,该异常已捕获在catch块内,并且不允许用户继续登录。此外,我们将_loading设置为false以表示不需要进一步的网络操作:
catch(MaliciousUserException) {
       setState(() {
         _loading = false;
           _errorMessage = 'Malicious user detected. Please try again later.';
       });

这就是一切! 我们之前基于 Firebase 认证创建的 Flutter 应用现在可以在后台运行智能模型的情况下找到恶意用户。

总结

在本章中,我们了解了如何使用 Flutter 和由 Firebase 支持的认证系统构建跨平台应用,同时结合了深度学习的优势。 然后,我们了解了如何将黑客攻击尝试归类为一般用户行为中的异常现象,并创建了一个模型来对这些异常现象进行分类以防止恶意用户登录。最后,我们使用了 Google 的 ReCaptcha 来消除对该应用的垃圾邮件使用,因此,使其在自动垃圾邮件或脚本化黑客攻击方面更具弹性。

在下一章中,我们将探索一个非常有趣的项目–使用移动应用上的深度学习生成音乐成绩单。

七、语音/多媒体处理 - 使用 AI 生成音乐

鉴于人工智能AI)的应用越来越多,将 AI 与音乐结合使用的想法已经存在了很长时间,并且受到了广泛的研究。 由于音乐是一系列音符,因此它是时间序列数据集的经典示例。 最近证明时间序列数据集在许多预测领域中非常有用–股市,天气模式,销售模式以及其他基于时间的数据集。 循环神经网络RNN)是处理时间序列数据集的最多模型之一。 对 RNN 进行的流行增强称为长短期记忆LSTM)神经元。 在本章中,我们将使用 LSTM 处理音符。

多媒体处理也不是一个新话题。 在本项目系列的早期,我们在多章中详细介绍了图像处理。 在本章中,我们将讨论并超越图像处理,并提供一个带有音频的深度学习示例。 我们将训练 Keras 模型来生成音乐样本,每次都会生成一个新样本。 然后,我们将此模型与 Flutter 应用结合使用,以通过 Android 和 iOS 设备上的音频播放器进行部署。

在本章中,我们将介绍以下主题:

  • 设计项目的架构
  • 了解多媒体处理
  • 开发基于 RNN 的音乐生成模型
  • 在 Android 和 iOS 上部署音频生成 API

让我们首先概述该项目的架构。

设计项目的架构

该项目的架构与作为应用部署的常规深度学习项目略有不同。 我们将有两组不同的音乐样本。 第一组样本将用于训练可以生成音乐的 LSTM 模型。 另一组样本将用作 LSTM 模型的随机输入,该模型将输出生成的音乐样本。 我们稍后将开发和使用的基于 LSTM 的模型将部署在 Google Cloud PlatformGCP)上。 但是,您可以将其部署在 AWS 或您选择的任何其他主机上。

下图总结了将在本项目中使用的不同组件之间的交互:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pw7KUWRD-1681785128417)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/50f17dc4-2658-4211-a7ff-9c41daafd884.png)]

移动应用要求部署在服务器上的模型生成新的音乐样本。 该模型使用随机音乐样本作为输入,以使其通过预先训练的模型来生成新的音乐样本。 然后,新的音乐样本由移动设备获取并播放给用户。

您可以将此架构与我们之前介绍的架构进行比较,在该架构中,将有一组用于训练的数据样本,然后将模型部署在云上或本地,并用于作出预测。

我们还可以更改此项目架构,以在存在为 Dart 语言编写的 midi 文件处理库的情况下在本地部署模型。 但是,在撰写本文时,还没有与我们在开发模型时使用的 Python midi 文件库的要求兼容的稳定库。

让我们从学习多媒体处理的含义以及如何使用 OpenCV 处理多媒体文件开始。

了解多媒体处理

多媒体是几乎所有形式的视觉,听觉或两者兼有的内容的总称。 术语多媒体处理本身非常模糊。 讨论该术语的更精确方法是将其分解为两个基本部分-视觉或听觉。 因此,我们将讨论多媒体处理的术语,即图像处理和音频处理。 这些术语的混合产生了视频处理,这只是多媒体的另一种形式。

在以下各节中,我们将以单独的形式讨论它们。

图像处理

图像处理或计算机视觉是迄今为止人工智能研究最多的分支之一。 在过去的几十年中,它发展迅速,并在以下几种技术的进步中发挥了重要作用:

  • 图像过滤器和编辑器
  • 面部识别
  • 数字绘画
  • 自动驾驶汽车

我们在较早的项目中讨论了图像处理的基础知识。 在这个项目中,我们将讨论一个非常流行的用于执行图像处理的库-OpenCV。 OpenCV 是开源计算机视觉的缩写。 它由 Intel 开发,并由 Willow Garage 和 Itseez(后来被 Intel 收购)推动。 毫无疑问,由于它与所有主要的机器学习框架(例如 TensorFlow,PyTorch 和 Caffe)兼容,因此它是执行图像处理的全球大多数开发人员的首要选择。 除此之外,OpenCV 还可以使用多种语言,例如 C++,Java 和 Python。

要在 Python 环境中安装 OpenCV,可以使用以下命令:

pip install opencv-contrib-python

前面的命令将同时安装主 OpenCV 模块和contrib模块。 您可以在此处找到更多模块供您选择。 有关更多安装说明,如果前面的链接不符合您的要求,则可以在此处遵循官方文档。

让我们为您介绍一个非常简单的示例,说明如何使用 OpenCV 执行图像处理。 创建一个新的 Jupyter 笔记本,并从以下步骤开始:

  1. 要将 OpenCV 导入笔记本,请使用以下代码行:
import cv2
  1. 我们还要将 matplotlib 导入笔记本,因为如果您尝试使用本机 OpenCV 图像显示功能,Jupyter 笔记本将会崩溃:
from matplotlib import pyplot as plt
%matplotlib inline
  1. 让我们使用 matplotlib 为 OpenCV 的本机图像显示功能创建一个替代函数,以方便在笔记本中显示图像:
def showim(image):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.imshow(image)
    plt.show()

请注意,我们将图像的配色方案从蓝色绿色红色BGR)转换为红色绿色蓝色RGB)。 这是由于默认情况下 OpenCV 使用 BGR 配色方案。 但是,matplotlib 在显示图片时会使用 RGB 方案,并且如果不进行这种转换,我们的图像就会显得奇怪。

  1. 现在,让我们将图像读取到 Jupyter 笔记本中。 完成后,我们将能够看到加载的图像:
image = cv2.imread("Image.jpeg")
showim(image)

前面代码的输出取决于您选择加载到笔记本中的图像:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第13张图片

在我们的示例中,我们加载了柑橘类水果切片的图像,这是艾萨克·奎萨达(Isaac Quesada)在“Unsplash”上拍摄的惊人照片。

您可以在这里找到上一张图片。

  1. 让我们通过将之前的图像转换为灰度图像来进行简单的操作。 为此,我们就像在声明的showim()函数中那样简单地使用转换方法:
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
showim(gray_image)

这将产生以下输出:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第14张图片

  1. 现在让我们执行另一种常见的操作,即图像模糊。 在图像处理中通常采用模糊处理,以消除图像中信息的不必要的细节(此时)。 我们使用高斯模糊过滤器,这是在图像上创建模糊的最常见算法之一:
blurred_image = cv2.GaussianBlur(image, (7, 7), 0)
showim(blurred_image)

这将产生以下输出:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第15张图片

请注意,前面的图像不如原始图像清晰。 但是,它很容易达到愿意计算此图像中对象数量的目的。

  1. 为了在图像中定位对象,我们首先需要标记图像中的边缘。 为此,我们可以使用Canny()方法,该方法是 OpenCV 中可用的其他选项之一,用于查找图像的边缘:
canny = cv2.Canny(blurred_image, 10, 50)
showim(canny)

这将产生以下输出:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第16张图片

请注意,在上图中找到的边缘数量很高。 虽然这会显示图像的细节,但是如果我们尝试对边缘进行计数以尝试确定图像中的对象数量,这将无济于事。

  1. 让我们尝试计算上一步生成的图像中不同项目的数量:
contours, hierarchy= cv2.findContours(canny, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
print("Number of objects found = ", len(contours))

上面的代码将产生以下输出:

Number of objects found = 18

但是,我们知道前面的图像中没有 18 个对象。 只有 9。因此,在寻找边缘时,我们将在canny方法中处理阈值。

  1. 让我们在 canny 方法中增加边缘发现的阈值。 这使得更难检测到边缘,因此仅使最明显的边缘可见:
canny = cv2.Canny(blurred_image, 50, 150)
showim(canny)

这将产生以下输出:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第17张图片

请注意,在柑橘类水果体内发现的边缘急剧减少,仅清晰可见其轮廓。 我们希望这会在计数时产生较少的对象。

  1. 让我们再次运行以下代码块:
contours, hierarchy= cv2.findContours(canny, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
print("Number of objects found = ", len(contours))

这将产生以下输出:

Number of objects found = 9

这是期望值。 但是,只有在特殊情况下,该值才是准确的。

  1. 最后,让我们尝试概述检测到的对象。 为此,我们绘制了findContours()方法的上一步中确定的轮廓:
_ = cv2.drawContours(image, contours, -1, (0,255,0), 10)
showim(image)

这将产生以下输出:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第18张图片

请注意,我们已经在拍摄的原始图像中非常准确地识别出了九片水果。 我们可以进一步扩展此示例,以在任何图像中找到某些类型的对象。

要了解有关 OpenCV 的更多信息并找到一些可供学习的示例,请访问以下存储库。

现在让我们学习如何处理音频文件。

音频处理

我们已经看到了如何处理图像以及可以从中提取信息。 在本节中,我们将介绍音频文件的处理。 音频或声音是吞没您周围环境的东西。 在许多情况下,您仅能从该区域的音频剪辑中正确预测该区域或环境,而无需实际看到任何视觉提示。 声音或语音是人与人之间交流的一种形式。 安排良好的节奏模式形式的音频称为音乐,可以使用乐器制作。

音频文件的一些流行格式如下:

  • MP3:一种非常流行的格式,广泛用于共享音乐文件。
  • AAC:是对 MP3 格式的改进,AAC 主要用于 Apple 设备。
  • WAV:由 Microsoft 和 IBM 创建,这种格式是无损压缩,即使对于小的音频文件也可能很大。
  • MIDI:乐器数字接口文件实际上不包含音频。 它们包含乐器音符,因此体积小且易于使用。

音频处理是以下技术的增长所必需的:

  • 用于基于语音的界面或助手的语音处理
  • 虚拟助手的语音生成
  • 音乐生成
  • 字幕生成
  • 推荐类似音乐

TensorFlow 团队的 Magenta 是一种非常流行的音频处理工具。

您可以通过这里访问 Magenta 主页。 该工具允许快速生成音频和音频文件的转录。

让我们简要地探讨 Magenta。

Magenta

Magenta 是 Google Brain 团队参与研究的一部分,该团队也参与了 TensorFlow。 它被开发为一种工具,可允许艺术家借助深度学习和强化学习算法来增强其音乐或艺术创作渠道。 这是 Magenta 的徽标:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第19张图片

让我们从以下步骤开始:

  1. 要在系统上安装 Magenta,可以使用 Python 的 pip 存储库:
pip install magenta
  1. 如果缺少任何依赖项,则可以使用以下命令安装它们:
!apt-get update -qq && apt-get install -qq libfluidsynth1 fluid-soundfont-gm build-essential libasound2-dev libjack-dev

!pip install -qU pyfluidsynth pretty_midi
  1. 要将 Magenta 导入项目中,可以使用以下命令:
import magenta

或者,按照流行的惯例,仅加载 Magenta 的音乐部分,可以使用以下命令:

import magenta.music as mm

您可以使用前面的导入在线找到很多样本。

让我们快速创作一些音乐。 我们将创建一些鼓声,然后将其保存到 MIDI 文件:

  1. 我们首先需要创建一个NoteSequence对象。 在 Magenta 中,所有音乐都以音符序列的格式存储,类似于 MIDI 存储音乐的方式:
from magenta.protobuf import music_pb2

drums = music_pb2.NoteSequence()
  1. 创建NoteSequence对象后,该对象为空,因此我们需要向其添加一些注解:
drums.notes.add(pitch=36, start_time=0, end_time=0.125, is_drum=True, instrument=10, velocity=80)
drums.notes.add(pitch=38, start_time=0, end_time=0.125, is_drum=True, instrument=10, velocity=80)
drums.notes.add(pitch=42, start_time=0, end_time=0.125, is_drum=True, instrument=10, velocity=80)
drums.notes.add(pitch=46, start_time=0, end_time=0.125, is_drum=True, instrument=10, velocity=80)
.
.
.
drums.notes.add(pitch=42, start_time=0.75, end_time=0.875, is_drum=True, instrument=10, velocity=80)
drums.notes.add(pitch=45, start_time=0.75, end_time=0.875, is_drum=True, instrument=10, velocity=80)

请注意,在前面的代码中,每个音符都有音高和力度。 再次类似于 MIDI 文件。

  1. 现在让我们为音符添加节奏,并设置音乐播放的总时间:
drums.total_time = 1.375

drums.tempos.add(qpm=60)

完成此操作后,我们现在准备导出 MIDI 文件。

  1. 我们首先需要将 MagentaNoteSequence对象转换为 MIDI 文件:
mm.sequence_proto_to_midi_file(drums, 'drums_sample_output.mid')

前面的代码首先将音符序列转换为 MIDI,然后将它们写入磁盘上的drums_sample_output.mid文件。 您现在可以使用任何合适的音乐播放器播放midi文件。

继续前进,让我们探索如何处理视频。

视频处理

视频处理是多媒体处理的另一个重要部分。 通常,我们需要弄清楚移动场景中发生的事情。 例如,如果我们要生产自动驾驶汽车,则它需要实时处理大量视频才能平稳行驶。 这种情况的另一个实例可以是将手语转换为文本以帮助与语音障碍者互动的设备。 此外,需要视频处理来创建电影和动作效果。

我们将在本节中再次探讨 OpenCV。 但是,我们将演示如何在 OpenCV 中使用实时摄像机供稿来检测面部。

创建一个新的 Python 脚本并执行以下步骤:

  1. 首先,我们需要对脚本进行必要的导入。 这将很简单,因为我们只需要 OpenCV 模块:
import cv2
  1. 现在,让我们将 Haar 级联模型加载到脚本中。 Haar 级联算法是一种用于检测任何给定图像中的对象的算法。 由于视频不过是图像流,因此我们将其分解为一系列帧并检测其中的人脸:
faceCascade = cv2.CascadeClassifier("haarcascade_frontalface_default.xml")

您将不得不从以下位置获取haarcascade_frontalface_default.xml文件。

Haar 级联是一类使用级联函数执行分类的分类器算法。 保罗·维奥拉(Paul Viola)和迈克尔·琼斯(Michael Jones)引入了它们,以试图建立一种对象检测算法,该算法足够快以在低端设备上运行。 级联函数池来自几个较小的分类器。

Haar 级联文件通常以可扩展标记语言XML)的格式找到,并且通常执行一项特定功能,例如面部检测,身体姿势检测, 对象检测等。 您可以在此处阅读有关 Haar 级联的更多信息。

  1. 现在,我们必须实例化摄像机以进行视频捕获。 为此,我们可以使用默认的笔记本电脑摄像头:
video_capture = cv2.VideoCapture(0)
  1. 现在让我们从视频中捕获帧并显示它们:
while True:
    # Capture frames
    ret, frame = video_capture.read()

    ### We'll add code below in future steps

    ### We'll add code above in future steps

    # Display the resulting frame
    cv2.imshow('Webcam Capture', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

这样您就可以在屏幕上显示实时视频供稿。 在运行此程序之前,我们需要释放相机并正确关闭窗户。

  1. 要正确关闭实时捕获,请使用以下命令:
video_capture.release()
cv2.destroyAllWindows()

现在,让我们对脚本进行测试运行。

您应该会看到一个窗口,其中包含您的脸部实时捕捉的图像(如果您不害羞的话)。

  1. 让我们向该视频提要添加面部检测。 由于用于面部检测的 Haar 级联在使用灰度图像时效果更好,因此我们将首先将每个帧转换为灰度,然后对其进行面部检测。 我们需要将此代码添加到while循环中,如以下代码所示:
    ### We'll add code below in future steps

    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    faces = faceCascade.detectMultiScale(
        gray,
        scaleFactor=1.1,
        minNeighbors=5,
        minSize=(30, 30),
        flags=cv2.CASCADE_SCALE_IMAGE
    )

    ### We'll add code above in future steps

这样,我们就可以检测到人脸了,因此让我们在视频供稿中对其进行标记!

  1. 我们将简单地使用 OpenCV 的矩形绘制函数在屏幕上标记面孔:
    minNeighbors=5,
        minSize=(30, 30),
        flags=cv2.CASCADE_SCALE_IMAGE
    )

    for (x, y, w, h) in faces:
        cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)

    ### We'll add code above in future steps

现在让我们再次尝试运行脚本。

转到终端并使用以下命令运行脚本:

python filename.py

在这里,文件名是您保存脚本文件时的名称。

您应该获得类似于以下屏幕截图的输出:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第20张图片

要退出实时网络摄像头捕获,请使用键盘上的Q键(我们已在前面的代码中进行了设置)。

我们已经研究了多媒体处理的三种主要形式的概述。 现在,让我们继续前进,构建基于 LSTM 的模型以生成音频。

开发基于 RNN 的音乐生成模型

在本节中,我们将开发音乐生成模型。 我们将为此使用 RNN,并使用 LSTM 神经元模型。 RNN 与简单的人工神经网络ANN)有很大的不同-允许在层之间重复使用输入。

虽然在 ANN 中,我们希望输入到神经网络的输入值向前移动,然后产生基于错误的反馈,并将其合并到网络权重中,但 RNN 使输入多次循环返回到先前的层。

下图表示 RNN 神经元:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第21张图片

从上图可以看到,通过神经元激活函数后的输入分为两部分。 一部分在网络中向前移动到下一层或输出,而另一部分则反馈到网络中。 在时间序列数据集中,可以相对于给定样本在t的时间标记每个样本,我们可以扩展前面的图,如下所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第22张图片

但是,由于通过激活函数反复暴露值,RNN 趋向于梯度消失,其中 RNN 的值逐梯度小到可以忽略不计(或在梯度爆炸的情况下变大)。 为避免这种情况,引入了 LSTM 单元,该单元通过将信息存储在单元中而允许将信息保留更长的时间。 每个 LSTM 单元由三个门和一个存储单元组成。 三个门(输入,输出和遗忘门)负责确定哪些值存储在存储单元中。

因此,LSTM 单元变得独立于 RNN 其余部分的更新频率,并且每个单元格都有自己的时间来记住它所拥有的值。 就我们而言,与其他信息相比,我们忘记了一些随机信息的时间要晚得多,这更自然地模仿了自然。

您可以在以下链接中找到有关 RNN 和 LSTM 的详细且易于理解的解释。

在开始为项目构建模型之前,我们需要设置项目目录,如以下代码所示:

├── app.py
├── MusicGenerate.ipynb
├── Output/
└── Samples/
    ├── 0.mid
    ├── 1.mid
    ├── 2.mid
    └── 3.mid

请注意,我们已经在Samples文件夹中下载了四个 MIDI 文件样本。 然后,我们创建了要使用的MusicGenerate.ipynb Jupyter 笔记本。 在接下来的几个步骤中,我们将仅在此 Jupyter 笔记本上工作。 app.py脚本当前为空,将来,我们将使用它来托管模型。

现在让我们开始创建基于 LSTM 的用于生成音乐的模型。

创建基于 LSTM 的模型

在本节中,我们将在 Jupyter 笔记本环境中研究MusicGenerate.ipynb笔记本:

  1. 在此笔记本中,我们将需要导入许多模块。 使用以下代码导入它们:
import mido
from mido import MidiFile, MidiTrack, Message
from tensorflow.keras.layers import LSTM, Dense, Activation, Dropout, Flatten
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import MinMaxScaler
import numpy as np

我们使用了mido库。 如果您的系统上未安装它,则可以使用以下命令来安装它:

pip install mido

注意,在前面的代码中,我们还导入了 Keras 模块和子部件。 该项目中使用的 TensorFlow 版本为 2.0。 为了在您的系统上安装相同版本或升级当前的 TensorFlow 安装,可以使用以下命令:

pip install --upgrade pip

pip install --upgrade tensorflow

现在,我们将继续阅读示例文件。

  1. 要将 MIDI 文件读入项目笔记本,请使用以下代码:
notes = []
for msg in MidiFile('Samples/0.mid') :
    try:
        if not msg.is_meta and msg.channel in [0, 1, 2, 3] and msg.type == 'note_on':
            data = msg.bytes()
            notes.append(data[1])
    except:
        pass

这将在notes列表中加载通道0123的所有开头音符。

要了解有关注解,消息和频道的更多信息,请使用以下文档。

  1. 由于音符处于大于 0–1 范围的可变范围内,因此我们将使用以下代码将其缩放以适合公共范围:
scaler = MinMaxScaler(feature_range=(0,1))
scaler.fit(np.array(notes).reshape(-1,1))
notes = list(scaler.transform(np.array(notes).reshape(-1,1)))
  1. 我们基本上拥有的是随时间变化的笔记列表。 我们需要将其转换为时间序列数据集格式。 为此,我们使用以下代码转换列表:
notes = [list(note) for note in notes]

X = []
y = []

n_prev = 20
for i in range(len(notes)-n_prev):
    X.append(notes[i:i+n_prev])
    y.append(notes[i+n_prev])

我们已将其转换为一个集合,其中每个样本都带有未来的 20 个音符,并且在数据集的末尾具有过去的 20 个音符。这可以通过以下方式进行:如果我们有 5 个样本,例如M[1]M[2]M[3]M[4]M[5],然后我们将它们安排在大小为 2 的配对中(类似于我们的 20),如下所示:

  • M[1] M[2]
  • M[2] M[3]
  • M[3] M[4],依此类推
  1. 现在,我们将使用 Keras 创建 LSTM 模型,如以下代码所示:
model = Sequential()
model.add(LSTM(256, input_shape=(n_prev, 1), return_sequences=True))
model.add(Dropout(0.3))
model.add(LSTM(128, input_shape=(n_prev, 1), return_sequences=True))
model.add(Dropout(0.3))
model.add(LSTM(256, input_shape=(n_prev, 1), return_sequences=False))
model.add(Dropout(0.3))
model.add(Dense(1))
model.add(Activation('linear'))
optimizer = Adam(lr=0.001)
model.compile(loss='mse', optimizer=optimizer)

随意使用此 LSTM 模型的超参数。

  1. 最后,我们将训练样本适合模型并保存模型文件:
model.fit(np.array(X), np.array(y), 32, 25, verbose=1)
model.save("model.h5")

这将在我们的项目目录中创建model.h5文件。 每当用户从应用发出生成请求时,我们都会将此文件与其他音乐样本一起使用,以随机生成新的乐曲。

现在,让我们使用 Flask 服务器部署此模型。

使用 Flask 部署模型

对于项目的这一部分,您可以使用本地系统,也可以在其他地方的app.py中部署脚本。 我们将编辑此文件以创建 Flask 服务器,该服务器生成音乐并允许下载生成的 MIDI 文件。

该文件中的某些代码与 Jupyter 笔记本类似,因为每次加载音频样本并将其与我们生成的模型一起使用时,音频样本始终需要进行类似的处理:

  1. 我们使用以下代码将所需的模块导入此脚本:
import mido
from mido import MidiFile, MidiTrack, Message
from tensorflow.keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import random
import time
from flask import send_file
import os

from flask import Flask, jsonify

app = Flask(__name__)

请注意,我们进行的最后四次导入与之前在 Jupyter 笔记本中导入的内容不同。 同样,我们不需要将几个 Keras 组件导入此脚本,因为我们将从已经准备好的模型中加载。

在上一个代码块的最后一行代码中,我们实例化了一个名为app的 Flask 对象。

  1. 在此步骤中,我们将创建函数的第一部分,当在 API 上调用/generate路由时,该函数将生成新的音乐样本:
@app.route('/generate', methods=['GET'])
def generate():

    songnum = random.randint(0, 3)

    ### More code below this
  1. 一旦我们随机决定在音乐生成过程中使用哪个样本文件,我们就需要像 Jupyter 笔记本中的训练样本那样对它进行类似的转换:
def generate():
    .
    .
    .    
    notes = []

    for msg in MidiFile('Samples/%s.mid' % (songnum)):
        try:
            if not msg.is_meta and msg.channel in [0, 1, 2, 3] and msg.type == 'note_on':
                data = msg.bytes()
                notes.append(data[1])
        except:
            pass

    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler.fit(np.array(notes).reshape(-1, 1))
    notes = list(scaler.transform(np.array(notes).reshape(-1, 1)))

    ### More code below this

在前面的代码块中,我们加载了示例文件,并从训练过程中使用的相同通道中提取了其注解。

  1. 现在,我们将像在训练期间一样缩放音符:
def generate():
    .
    .
    .     
    notes = [list(note) for note in notes]

    X = []
    y = []

    n_prev = 20
    for i in range(len(notes) - n_prev):
        X.append(notes[i:i + n_prev])
        y.append(notes[i + n_prev])

    ### More code below this

我们也将这些笔记列表转换为适合模型输入的形状,就像我们在训练过程中对输入所做的一样。

  1. 接下来,我们将使用以下代码来加载 Keras 模型并从该模型创建新的注解列表:
def generate():
    .
    .
    . 
    model = load_model("model.h5")

    xlen = len(X)

    start = random.randint(0, 100)

    stop = start + 200

    prediction = model.predict(np.array(X[start:stop]))
    prediction = np.squeeze(prediction)
    prediction = np.squeeze(scaler.inverse_transform(prediction.reshape(-1, 1)))
    prediction = [int(i) for i in prediction]    

    ### More code below this
  1. 现在,我们可以使用以下代码将此音符列表转换为 MIDI 序列:
def generate():
    .
    .
    . 
    mid = MidiFile()
    track = MidiTrack()
    t = 0
    for note in prediction:
        vol = random.randint(50, 70)
        note = np.asarray([147, note, vol])
        bytes = note.astype(int)
        msg = Message.from_bytes(bytes[0:3])
        t += 1
        msg.time = t
        track.append(msg)
    mid.tracks.append(track)

    ### More code below this
  1. 现在,我们准备将文件保存到磁盘。 它包含从模型随机生成的音乐:
def generate():
    .
    .
    . 
    epoch_time = int(time.time())

    outputfile = 'output_%s.mid' % (epoch_time)
    mid.save("Output/" + outputfile)

    response = {'result': outputfile}

    return jsonify(response)

因此,/generate API 以 JSON 格式返回生成的文件的名称。 然后,我们可以下载并播放此文件。

  1. 要将文件下载到客户端,我们需要使用以下代码:
@app.route('/download/', methods=['GET'])
def download(fname):
    return send_file("Output/"+fname, mimetype="audio/midi", as_attachment=True)

请注意,前面的函数在/download/filename路由上起作用,在该路由上,客户端根据上一代 API 调用的输出提供文件名。 下载的文件的 MIME 类型为audio/midi,它告诉客户端它是 MIDI 文件。

  1. 最后,我们可以添加将执行此服务器的代码:
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=8000)

完成此操作后,我们可以在终端中使用以下命令来运行服务器:

python app.py

如果代码中产生任何警告,您将从控制台获得一些调试信息。 完成此操作后,我们准备在下一节中为我们的 API 构建 Flutter 应用客户端。

在 Android 和 iOS 上部署音频生成 API

成功创建和部署模型后,现在开始构建移动应用。 该应用将用于获取和播放由先前创建的模型生成的音乐。

它将具有三个按钮:

  • 生成音乐:生成新的音频文件
  • 播放:播放新生成的文件
  • 停止:停止正在播放的音乐

另外,它的底部将显示一些文本,以显示应用的当前状态。

该应用将显示如下:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第23张图片

该应用的小部件树如下所示:

现在开始构建应用的 UI。

创建 UI

我们首先创建一个新的 Dart 文件play_music.dart和一个有状态的小部件PlayMusic。 如前所述,在该文件中,我们将创建三个按钮来执行基本功能。 以下步骤描述了如何创建 UI:

  1. 定义buildGenerateButton()方法以创建RaisedButton变量,该变量将用于生成新的音乐文件:
 Widget buildGenerateButton() {
   return Padding(
     padding: EdgeInsets.only(left: 16, right: 16, top: 16),
     child: RaisedButton(
       child: Text("Generate Music"),
       color: Colors.blue,
       textColor: Colors.white,
     ),
   );
 }

在前面定义的函数中,我们创建一个RaisedButton,并添加Generate Music文本作为子元素。 color属性的Colors.blue值用于为按钮赋予蓝色。 另外,我们将textColor修改为Colors.white,以使按钮内的文本为白色。 使用EdgeInsets.only()给按钮提供左,右和顶部填充。 在后面的部分中,我们将在按钮上添加onPressed属性,以便每次按下按钮时都可以从托管模型中获取新的音乐文件。

  1. 定义buildPlayButton()方法以播放新生成的音频文件:
Widget buildPlayButton() {
   return Padding(
   padding: EdgeInsets.only(left: 16, right: 16, top: 16),
   child: RaisedButton(
     child: Text("Play"),
     onPressed: () {
       play();
     },
     color: Colors.blue,
     textColor: Colors.white,
     ),
   );
 }

在前面定义的函数中,我们创建一个RaisedButton,并添加"Play"文本作为子元素。 color属性的Colors.blue值用于为按钮赋予蓝色。 另外,我们将textColor修改为Colors.white,以使按钮内的文本为白色。 使用EdgeInsets.only()给按钮提供左,右和顶部填充。 在后面的部分中,我们将在按钮上添加onPressed属性,以在每次按下按钮时播放新生成的音乐文件。

  1. 定义buildStopButton()方法以停止当前正在播放的音频:
Widget buildStopButton() {
   return Padding(
     padding: EdgeInsets.only(left: 16, right: 16, top: 16),
     child: RaisedButton(
       child: Text("Stop"),
       onPressed: (){
         stop();
       },
       color: Colors.blue,
       textColor: Colors.white,
     )
   );
 }

在前面定义的函数中,我们创建一个RaisedButton,并添加"Stop"文本作为子元素。 color属性的Colors.blue值用于为按钮赋予蓝色。 另外,我们将textColor修改为Colors.white,以使按钮内的文本为白色。 使用EdgeInsets.only()给按钮提供左,右和顶部填充。 在下一节中,我们将向按钮添加onPressed属性,以在按下按钮时停止当前播放的音频。

  1. 覆盖PlayMusicState中的build()方法,以创建先前创建的按钮的Column
 @override
 Widget build(BuildContext context) {
   return Scaffold(
     appBar: AppBar(
       title: Text("Generate Play Music"),
     ),
     body: Column(
       crossAxisAlignment: CrossAxisAlignment.stretch,
       children: <Widget>[
         buildGenerateButton(),
         buildPlayButton(),
         buildStopButton(),
       ],
     )
   );
 }

在前面的代码片段中,我们返回Scaffold。 它包含一个AppBar,其中具有[Generate Play Music]作为titleScaffold的主体是Column。 列的子级是我们在上一步中创建的按钮。 通过调用相应方法将按钮添加到该列中。 此外,crossAxisAlignment属性设置为CrossAxisAlignment.stretch,以便按钮占据父容器(即列)的总宽度。

此时,该应用如下所示:

在下一节中,我们将添加一种在应用中播放音频文件的机制。

添加音频播放器

创建应用的用户界面后,我们现在将音频播放器添加到应用中以播放音频文件。 我们将使用audioplayer插件添加音频播放器,如下所示:

  1. 我们首先将依赖项添加到pubspec.yaml文件中:
audioplayers: 0.13.2

现在,通过运行flutter pub get获得包。

  1. 接下来,我们将插件导入play_music.dart
import 'package:audioplayers/audioplayers.dart';
  1. 然后,在PlayMusicState内创建AudioPlayer的实例:
AudioPlayer audioPlayer = AudioPlayer();
  1. 现在,让我们定义一个play()方法来播放远程可用的音频文件,如下所示:
play() async {
   var url = 'http://34.70.80.18:8000/download/output_1573917221.mid';
   int result = await audioPlayer.play(url);
   if (result == 1) {
     print('Success');
     }
 }

最初,我们将使用存储在url变量中的样本音频文件。 通过传递url中的值,使用audioPlayer.play()播放音频文件。 另外,如果从url变量成功访问和播放了音频文件,则结果将存储在结果变量中,其值将为1

  1. 现在,将onPressed属性添加到buildPlayButton内置的播放按钮中,以便每当按下该按钮时就播放音频文件:
Widget buildPlayButton() {
   return Padding(
   padding: EdgeInsets.only(left: 16, right: 16, top: 16),
   child: RaisedButton(
     ....
     onPressed: () {
       play();
     },
     ....
     ),
   );
 }

在前面的代码片段中,我们添加onPressed属性并调用play()方法,以便每当按下按钮时就播放音频文件。

  1. 现在,我们将定义stop()以停止正在播放的音乐:
void stop() {
   audioPlayer.stop();
 }

stop()方法内部,我们只需调用audioPlayer.stop()即可停止正在播放的音乐。

  1. 最后,我们为buildStopButton()中内置的停止按钮添加onPressed属性:
 Widget buildStopButton() {
   return Padding(
     padding: EdgeInsets.only(left: 16, right: 16, top: 16),
     child: RaisedButton(
       ....
       onPressed: (){
         stop();
       },
       ....
     )
   );
 }

在前面的代码片段中,我们向onPressed中的stop()添加了一个调用,以便一旦按下停止按钮就停止音频。

现在开始使用 Flutter 应用部署模型。

部署模型

在为应用成功添加基本的播放和停止功能之后,现在让我们访问托管模型以每次生成,获取和播放新的音频文件。 以下步骤详细讨论了如何在应用内部访问模型:

  1. 首先,我们定义fetchResponse()方法来生成和获取新的音频文件:
void fetchResponse() async {
   final response =
     await http.get('http://35.225.134.65:8000/generate');
   if (response.statusCode == 200) {
     var v = json.decode(response.body);
     fileName = v["result"] ;
   } else {
     throw Exception('Failed to load');
   }
 }

我们首先使用http.get()从 API 获取响应,然后传入托管模型的 URL。 get()方法的响应存储在response变量中。 get()操作完成后,我们使用response.statusCode检查状态码。 如果状态值为200,则获取成功。 接下来,我们使用json.decode()将响应的主体从原始 JSON 转换为Map,以便可以轻松访问响应主体中包含的键值对。 我们使用v["result"]访问新音频文件的值,并将其存储在全局fileName变量中。 如果responseCode不是200,我们只会抛出一个错误。

  1. 现在让我们定义load()以对fetchResponse()进行适当的调用:
void load() {
   fetchResponse();
 }

在前面的代码行中,我们仅定义一个load()方法,该方法用于调用fetchResponse()来获取新生成的音频文件的值。

  1. 现在,我们将修改buildGenerateButton()中的onPressed属性,以每次生成新的音频文件:
Widget buildGenerateButton() {
   return Padding(
     ....
     child: RaisedButton(
       ....
       onPressed: () {
         load();
       },
       ....
     ),
   );
 }

根据应用的功能,每当按下生成按钮时,都应生成一个新的音频文件。 这直接意味着无论何时按下“生成”按钮,我们都需要调用 API 以获取新生成的音频文件的名称。 因此,我们修改buildGenerateButton()以添加onPressed属性,以便每当按下按钮时,它都会调用load(),该调用随后将调用fetchResponse()并将新音频文件的名称存储在输出中。

  1. 托管的音频文件有两个部分,baseUrlfileNamebaseUrl对于所有调用均保持不变。 因此,我们声明一个存储baseUrl的全局字符串变量:
String baseUrl = 'http://34.70.80.18:8000/download/';

回想一下,我们已经在“步骤 1”中将新音频文件的名称存储在fileName中。

  1. 现在,让我们修改play()以播放新生成的文件:
play() async {
   var url = baseUrl + fileName;
   AudioPlayer.logEnabled = true;
   int result = await audioPlayer.play(url);
   if (result == 1) {
     print('Success');
     }
 }

在前面的代码片段中,我们修改了前面定义的play()方法。 我们通过附加baseUrlfileName创建一个新的 URL,以便url中的值始终与新生成的音频文件相对应。 我们在调用audioPlayer.play()时传递 URL 的值。 这样可以确保每次按下播放按钮时,都会播放最新生成的音频文件。

  1. 此外,我们添加了Text小部件以反映文件生成状态:
 Widget buildLoadingText() {
   return Center(
     child: Padding(
       padding: EdgeInsets.only(top: 16),
       child: Text(loadText)
     )
   );
 }

在前面定义的函数中,我们创建了一个简单的Text小部件,以反映提取操作正在运行以及何时完成的事实。 Text小部件具有顶部填充,并与Center对齐。 loadText值用于创建窗口小部件。

全局声明该变量,其初始值为'Generate Music'

String loadText = 'Generate Music';
  1. 更新build()方法以添加新的Text小部件:
@override
 Widget build(BuildContext context) {
   return Scaffold(
     ....
     body: Column(
       ....
       children: <Widget>[
         buildGenerateButton(),
         ....
         buildLoadingText()
       ],
     )
   );
 }

现在,我们更新build()方法以添加新创建的Text小部件。 该窗口小部件只是作为先前创建的Column的子级添加的。

  1. 当用户想要生成一个新的文本文件时,并且在进行提取操作时,我们需要更改文本:
void load() {
   setState(() {
    loadText = 'Generating...';
   });
   fetchResponse();
 }

在前面的代码段中,loadText值设置为'Generating...',以反映正在进行get()操作的事实。

  1. 最后,获取完成后,我们将更新文本:
void fetchResponse() async {
   final response =
     await http.get('http://35.225.134.65:8000/generate').whenComplete((){
       setState(() {
        loadText = 'Generation Complete';
       });
     });
   ....
 }

提取完成后,我们将loadText的值更新为'Generation Complete'。 这表示应用现在可以播放新生成的文件了。

可以在此处查看play_music.dart的整个代码。

在使应用的所有部分正常工作之后,现在让我们通过创建最终的材质应用将所有内容放在一起。

创建最终的材质应用

现在创建main.dart文件。 该文件包含无状态窗口小部件MyApp。 我们重写build()方法并将PlayMusic设置为其子级:

 @override
 Widget build(BuildContext context) {
   return MaterialApp(
     title: 'Flutter Demo',
     theme: ThemeData(
       primarySwatch: Colors.blue,
     ),
     home: PlayMusic(),
   );
 }

在覆盖的build()方法中,我们简单地将home创建为PlayMusic()MaterialApp

整个项目可以在这里查看。

总结

在本章中,我们通过将多媒体处理分解为图像,音频和视频处理的核心组件来进行研究,并讨论了一些最常用的处理工具。 我们看到了使用 OpenCV 执行图像或视频处理变得多么容易。 另外,我们看到了一个使用 Magenta 生成鼓音乐的简单示例。 在本章的下半部分,我们介绍了 LSTM 如何与时间序列数据一起使用,并构建了一个 API,该 API 可以从提供的样本文件生成器乐。 最后,我们将此 API 与 Flutter 应用结合使用,该应用是跨平台的,可以同时部署在 Android,iOS 和 Web 上。

在下一章中,我们将研究如何使用深度强化学习DRL)来创建可以玩棋盘游戏(例如国际象棋)的智能体。

八、基于强化神经网络的国际象棋引擎

在几个在线应用商店以及几乎每个软件商店中,游戏都提供了自己的完整版块。 游戏的重要性和热情不容忽视,这就是为什么全世界的开发人员都在不断尝试开发出更好,更吸引人的游戏的原因。

在流行的棋盘游戏世界中,国际象棋是全世界最有竞争力和最复杂的游戏之一。 已经尝试了一些强大的自动化程序来下棋和与人类竞争。 本章将讨论 DeepMind 的开发人员所使用的方法,他们创建了 Alpha Zero,这是一种自学算法,可以自学下棋,从而能够以一个单打击败市场上当时最好的国际象棋 AI,Stockfish 8。 在短短 24 小时的训练中得分较高。

在本章中,我们将介绍您需要理解的概念,以便构建这种深度强化学习算法,然后构建示例项目。 请注意,该项目将要求您具有 Python 和机器学习的丰富知识。

我们将在本章介绍以下主题:

  • 强化学习导论
  • 手机游戏中的强化学习
  • 探索 Google 的 DeepMind
  • 适用于 Connect 4 的 Alpha 类零 AI
  • 基础项目架构
  • 为国际象棋引擎开发 GCP 托管的 REST API
  • 在 Android 上创建简单的国际象棋 UI
  • 将国际象棋引擎 API 与 UI 集成

让我们从讨论增强学习智能体在手机游戏中的用法和普及程度开始。

强化学习导论

在过去的几年中,强化学习已成为机器学习研究人员中一个重要的研究领域。 人们越来越多地使用它来构建能够在任何给定环境中表现更好的智能体,以寻求对他们所执行行为的更好回报。 简而言之,这为我们提供了强化学习的定义–在人工智能领域,这是一种算法,旨在创建虚拟的智能体,它可在任何给定条件下,在环境中执行动作,在执行一系列动作后,取得最佳的奖励

让我们尝试通过定义与通用强化学习算法关联的变量来赋予此定义更多的结构:

  • 智能体:执行动作的虚拟实体。 是替换游戏/软件的指定用户的实体。
  • 操作a):智能体可以执行的可能操作。
  • 环境e):在软件/游戏中可用的一组场景。
  • 状态S):所有方案的集合,以及其中可用的配置。
  • 奖励R):对于智能体执行的任何操作返回的值,然后智能体尝试将其最大化。
  • 策略π):智能体用来确定接下来必须执行哪些操作的策略。
  • V):R是短期每动作奖励,而值是在一组动作结束时预期的总奖励。V[π](s)通过遵循状态S下的策略π来定义预期的总回报。

下图显示了该算法的流程:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第24张图片

尽管我们在前面的定义列表中没有提到观察者,但必须有观察者或评估者才能产生奖励。 有时,观察者本身可能是一个复杂的软件,但是通常,这是一个简单的评估函数或指标。

要获得关于强化学习的更详细的想法,您可以阅读这个页面上的 Wikipedia 文章。 有关正在使用的强化学习智能体的快速样本,请阅读以下 DataCamp 文章。

在下一部分中,我们将学习强化学习在手机游戏中的地位。

手机游戏中的强化学习

出于各种原因而希望构建具有游戏性的 AI 的开发人员中,强化学习已变得越来越流行-只需检查 AI 的功能,建立可以帮助专业人士改善游戏水平的训练智能体等等。 从研究人员的角度来看,游戏为强化学习智能体提供了最佳的测试环境,可以根据经验做出决策并学习在任何给定环境中的生存/成就。 这是因为可以使用简单而精确的规则设计游戏,从而可以准确预测环境对特定动作的反应。 这使得更容易评估强化学习智能体的表现,从而为 AI 提供良好的训练基础。 考虑到在玩游戏的 AI 方面的突破,也有人表示,我们向通用 AI 的发展速度比预期的要快。 但是强化学习概念如何映射到游戏?

让我们考虑一个简单的游戏,例如井字棋。 另外,如果您觉得古怪,只需使用 Google 搜索井字棋,您就会在搜索结果中看到一个游戏!

考虑您正在用计算机玩井字棋。 这里的计算机是智能体。 在这种情况下,环境是什么? 您猜对了–井字棋板以及在环境中管理游戏的一组规则。 井字棋盘上已经放置的标记可以确定环境所在的状态。座席可以在棋盘上放置的XO是他们可以执行的动作,即输掉,赢得比赛或平局。 或朝着损失,胜利或平局前进是他们执行任何行动后回馈给智能体的奖励。 智能体赢得比赛所遵循的策略是遵循的策略。

因此,从该示例可以得出结论,强化学习智能体非常适合构建学习玩任何游戏的 AI。 这导致许多开发人员想出了象围棋,跳棋,反恐精英等国际象棋以外的几种流行游戏的游戏 AI。 甚至 Chrome Dino 之类的游戏也发现开发人员试图使用 AI 进行游戏。

在下一部分中,我们将简要概述 Google 的 DeepMind,它是游戏 AI 制造商领域中最受欢迎的公司之一。

探索 Google 的 DeepMind

当您谈论自学习人工智能的发展时,DeepMind 可能是最著名的名称之一,这是由于它们在该领域的开创性研究和成就。 自 2015 年 Google 重组以来,DeepMind 在 2014 年被 Google 收购,目前是 Alphabet 的全资子公司。DeepMind 最著名的作品包括 AlphaGo 及其继任者 Alpha Zero。 让我们更深入地讨论这些项目,并尝试了解是什么使它们在当今如此重要。

AlphaGo

2015 年,AlphaGo 成为第一个在19x19棋盘上击败职业围棋选手 Lee Sedol 的计算机软件。 突破被记录下来并作为纪录片发行。 击败李·塞多尔的影响如此之大,以至于韩国 Baduk 协会授予了荣誉 9 丹证书,这实际上意味着围棋选手的游戏技能与神性息息相关。 这是围棋历史上第一次提供 9 荣誉荣誉证书,因此提供给 AlphaGo 的证书编号为 001。ELO 等级为 3,739。

AlphaGo Master 的继任者 AlphaGo Master 在三场比赛中击败了当时统治世界的游戏冠军 Ke Jie。 为了表彰这一壮举,它获得了中国围棋协会颁发的 9 丹证书。 该软件当时的 ELO 等级为 4,858。

但是,这两款软件都被其继任者 AlphaGo Zero 压倒了,后者在 3 天的自学式学习中,能够在 21 分之后以 100:0 的游戏得分击败 AlphaGo,在 89:11 的游戏得分下击败 AlphaGo Master。 天的训练。 40 天后,它的 ELO 评分达到了 5,185,超过了以前所有 Go AI 的技能。

AlphaGo 基于蒙特卡洛树搜索算法,并采用了对生成的和人类玩家游戏日志进行的深度学习。 该模型的初始训练是通过人类游戏进行的。 然后,计算机将与自己对战并尝试改善其游戏性。 树搜索将被设置为一定的深度,以避免巨大的计算开销,在这种开销下,计算机将尝试达到所有可能的动作,然后再进行任何动作。

总而言之,遵循以下过程:

  1. 最初,该模型将在人类游戏日志上进行训练。
  2. 一旦在基线上进行了训练,计算机将使用在先前步骤中训练过的模型与自己竞争,并使用有上限的蒙特卡洛树搜索来确保进行移动而不会长时间停滞该软件。 这些游戏的日志已生成。
  3. 然后对生成的游戏进行了训练,从而改善了整体模型。

现在,让我们讨论 Alpha Zero。

Alpha Zero

Alpha Zero 是 AlphaGo Zero 的后继产品,它是对算法进行泛化的尝试,以便也可以用于其他棋盘游戏。 Alpha Zero 经过训练可以下棋,将棋(类似于棋的日式游戏)和围棋,其表现与相应游戏的现有 AI 相当。 经过 34 小时的训练,Alpha Zero for Go 击败了经过 3 天训练的 AlphaGo Zero,得分为 60:40。 这导致 ELO 等级为 4,430。

经过约 9 个小时的训练,Alpha Zero 击败了 TCEC 竞赛 2016 年冠军的 Stockfish 8。 因此,它仍然是迄今为止最强大的国际象棋 AI,尽管有人声称最新版本的 Stockfish 将能够击败它。

AlphaGo Zero 和 Alpha Zero 变体之间的主要区别如下:

  • 出现平局的可能性:在围棋中,保证有一名选手获胜,而对于象棋则不是这样。 因此,对 Alpha Zero 进行了修改,以允许并列游戏。
  • 对称性:AlphaGo Zero 利用了电路板的对称性。 但是,由于国际象棋不是非对称游戏,因此必须对 Alpha Zero 进行修改以使其工作。
  • 硬编码的超参数搜索:Alpha Zero 具有用于超参数搜索的硬编码规则。
  • 在 Alpha Zero 的情况下,神经网络会不断更新。

此时,您可能会想,“什么是蒙特卡罗树搜索?”。 让我们尝试回答这个问题!

蒙特卡洛树搜索

当我们谈论象棋,围棋或井字棋等基于当前场景的战略游戏时,我们所谈论的是大量可能的场景和可以在任何情况下在其中的给定点执行的动作。 尽管对于井字棋等较小的游戏,可能的状态和动作的数量在现代计算机可以计算的范围内,但对于游戏可以生成的状态数量,更复杂的游戏(如国际象棋和围棋)呈指数增长。

蒙特卡洛树搜索尝试找到在给定环境下赢得任何游戏或获得更好奖励所需要的正确动作序列。 之所以将其称为树搜索是因为它创建了游戏中所有可能状态的树,并通过创建每个状态的分支来实现其中的所有可能动作。 表示为树中的节点。

让我们考虑以下简单的游戏示例。 假设您正在玩一个游戏,要求您猜一个三位数的数字,每个猜中都有一个相关的奖励。 可能的数字范围是 1 到 5,您可以猜测的次数是 3。 如果您做出准确的猜测,即正确猜测任意给定位置的数字,则将获得 5 分。但是,如果您做出错误的猜测,将得到正确数字两边的线性差值的分数。

例如,如果要猜测的数字是 2,则可能获得以下奖励分数:

  • 如果您猜 1,则得分为 4
  • 如果您猜 2,则得分为 5
  • 如果您猜 3,则得分为 4
  • 如果您猜 4,则得分为 3
  • 如果您猜 5,则得分为 2

因此,游戏中的最佳总得分为 15,即每个正确的猜测为 5 分。 鉴于此,您可以在每个步骤中的五个选项中进行选择,游戏中可能的状态总数为5 * 5 * 5 = 125,只有一个状态会给出最佳分数。

让我们尝试在树上描绘前面的游戏。 假设您要猜测的数字是 413。在第一步中,您将具有以下树:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第25张图片

做出选择后,您将获得奖励,再次有五个选项可供选择-换句话说,每个节点中有五个分支可以遍历。 在最佳游戏玩法中,将获得以下树:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第26张图片

现在,让我们考虑以下事实:围棋游戏共有3^361个可能状态。 在 AI 采取行动之前尝试计算每种可能性变得不切实际。 这是蒙特卡罗树搜索与上限可信度算法相结合的地方,它比其他方法更具优势,因为它可以终止到任何搜索深度,并且可以产生趋向于最佳分数的结果。 因此,算法不需要遍历树的每个分支。 一旦树形搜索算法意识到任何特定分支的表现不佳,就可以停止沿该路径前进,而专注于表现更好的路径。 而且,它可以尽早终止任何路径并在该点返回预期的回报,从而可以调整 AI 采取任何行动所需的时间。

更确切地说,蒙特卡罗树搜索遵循以下步骤:

  1. 选择:从树的当前节点中选择最佳回报分支。 例如,在前面的游戏树中,选择除 4 以外的任何分支将产生较低的分数,因此选择了 4。

  2. 扩展:一旦选择了最佳回报节点,该节点下的树将进一步扩展,从而创建具有该节点可用的所有可能选项(分支)的节点。 这可以理解为从游戏的任何位置布局 AI 的未来动作。

  3. 模拟:现在,由于事先不知道在扩展阶段创建的哪个未来选项最有回报,因此我们使用强化学习逐个模拟游戏的每个选项。 请注意,与上限可信度上限算法结合使用时,直到结束游戏才算重要。 计算任何n个步骤的奖励也是一种不错的方法。

  4. 更新:最后,更新节点和父节点的奖励分数。 尽管不可能回到游戏中,并且由于任何节点的值都已减小,但如果在以后的游戏中的那个阶段找到了更好的替代方案,那么 AI 将不会遵循这条路径,从而通过多次迭代来改善其游戏玩法。

接下来,我们将构建一个系统,该系统的工作原理类似于 Alpha Zero,并尝试学习玩 Connect 4 游戏,该游戏比 Tic-Tac-Toe 游戏要复杂得多,但对我们来说足够大,来解释如何构建类似的国际象棋引擎。

适用于 Connect 4 的类似 Alpha Zero 的 AI

在开始研究可玩 Connect4 的 AI 之前,让我们简要了解一下游戏及其动态。 Connect 4,有时也称为连续四人,连续四人,四人以上,等等,是全世界儿童中最受欢迎的棋盘游戏之一。 我们也可以将它理解为井字棋的更高级版本,在其中您必须水平,垂直或对角放置三个相同类型的标记。 棋盘通常是一个6x7的网格,两个玩家各自玩一个标记。

Connect 4 的规则可能会有所不同,因此让我们为 AI 将学习的规则版本制定一些具体规则:

  • 该游戏被模拟为在具有七个空心列和六行的垂直板上玩。 每列在板的顶部都有一个开口,可以在其中插入片段。可以查看已放入板的片段。
  • 两位玩家都有 21 个形状像不同颜色硬币的硬币。
  • 将硬币放在板上构成一个动作。
  • 碎片从顶部的开口下降到最后一行,或者堆积在该列的最后一块。
  • 第一个以任意方向连接其任意四枚硬币的玩家,因此彼此之间不会存在任何间隙或其他玩家的硬币获胜。

现在,让我们分解将 Connect 4 播放式自学 AI 分解为子问题的问题:

  1. 首先,我们需要创建棋盘的虚拟表示。
  2. 接下来,我们必须创建允许根据游戏规则移动的函数。
  3. 然后,为了保存游戏状态,我们需要一个状态管理系统。
  4. 接下来,我们将简化游戏玩法,其中将提示用户进行移动并宣布游戏终止。
  5. 之后,我们必须创建一个脚本,该脚本可以生成示例游戏玩法,供系统学习。
  6. 然后,我们必须创建训练函数来训练系统。
  7. 接下来,我们需要蒙特卡洛树搜索MCTS)实现。
  8. 最后,我们需要一个神经网络的实现。
  9. 除了前面的具体步骤之外,我们还需要为系统创建许多驱动脚本以使其更加可用。

让我们依次移至前面的要点,一次覆盖系统的每个部分。 但是,首先,我们将快速浏览该项目中存在的目录结构和文件,这在本书的 GitHub 存储库中也可以找到。 让我们来看看:

  • command/
  • __init__.py:此文件使我们可以将此文件夹用作模块。
  • arena.py:此文件获取并解析用于运行游戏的命令。
  • generate.py:此文件接受并分析自玩招式生成系统的命令。
  • newmodel.py:此文件用于为智能体创建新的空白模型。
  • train.py:此文件用于训练基于增强学习的神经网络如何玩游戏。
  • util/
  • __init__.py:此文件使我们可以将此文件夹用作模块。
  • arena.py:此文件创建并维护玩家之间进行的比赛的记录,并允许我们在轮到谁之间切换。
  • compat.py:此文件是用于使程序与 Python 2 和 Python 3 兼容的便捷工具。如果您确定正在开发的版本并希望在其上运行,则可以跳过此文件。
  • generate.py:此文件播放一些随机移动的游戏,再加上 MCTS 移动,以生成可用于训练目的的游戏日志。 该文件存储每个游戏的获胜者以及玩家做出的动作。
  • internal.py:此文件创建棋盘的虚拟表示并定义与棋盘相关的函数,例如将棋子放置在棋盘上,寻找获胜者或只是创建新棋盘。
  • keras_model.py:此文件定义充当智能体大脑的模型。 在本项目的后面,我们将更深入地讨论该文件。
  • mcts.py:此文件提供 MCTS 类,该类实质上是蒙特卡罗树搜索的实现。
  • nn.py:此文件提供 NN 类,它是神经网络的实现,以及与神经网络相关的函数,例如拟合,预测,保存等。
  • player.py:此文件为两种类型的播放器提供了类-MCTS 播放器和人工播放器。 MCTS 玩家是我们将训练的智能体,以玩游戏。
  • state.py:这是internal.py文件的包装,提供了用于访问电路板和与电路板相关的函数的类。
  • trainer.py:这使我们可以训练模型。 这与nn.py中提供的内容不同,因为它更专注于涵盖游戏的训练过程,而nn.py中的内容主要是围绕此功能的包装。

接下来,我们将继续探索这些文件中每个文件的一些重要部分,同时遵循我们先前为构建 AI 制定的步骤。

创建棋盘的虚拟表示

您将如何代表 Connect 4 棋盘? 代表 Connect 4 棋盘的两种常用方法以及游戏状态。 让我们来看看:

  • 人类可读的长格式:在这种形式中,木板的行和列分别显示在 x 和 y 轴上,并且两个玩家的标记都显示为xo, 分别(或任何其他合适的字符)。 可能如下所示:
  |1 2 3 4 5 6 7 
--+--------------
 1|. . . . . . .
 2|. . . . . . .
 3|. . . . . . .
 4|. . . . o x .
 5|x o x . o o .
 6|o x x o x x o

但是,这种形式有点冗长并且在计算上不是很友好。

  • 计算有效的形式:在此形式中,我们将板存储为 2D NumPy 数组:
array([[1, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 0]], dtype=int8)

以这种方式创建该数组,当将其展平为一维数组时,板位置按顺序排列,就好像该数组实际上是一维数组一样。 前两个位置分别编号为 0 和 1,而第 5 个位置位于第 5 行和第 5 列,编号为 32。通过将前一个代码块中的矩阵与给定的表进行映射,可以轻松理解此条件。 在下图中:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第27张图片

这种形式适合于进行计算,但不适合玩家在游戏过程中观看,因为对于玩家而言很难解密。

  • 一旦决定了如何表示电路板及其部件,就可以开始在util/internal.py文件中编写代码,如下所示:
BOARD_SIZE_W = 7
BOARD_SIZE_H = 6
KEY_SIZE = BOARD_SIZE_W * BOARD_SIZE_H

前几行设置了板子的常数,在这种情况下,是板子上的行数和列数。 我们还通过将它们相乘来计算板上的按键或位置的数量。

  • 现在,让我们准备在板上生成获胜位置的代码,如下所示:
LIST4 = []
LIST4 += [[(y, x), (y + 1, x + 1), (y + 2, x + 2), (y + 3, x + 3)] for y in range(BOARD_SIZE_H - 3) for x in range(BOARD_SIZE_W - 3)]
LIST4 += [[(y, x + 3), (y + 1, x + 2), (y + 2, x + 1), (y + 3, x)] for y in range(BOARD_SIZE_H - 3) for x in range(BOARD_SIZE_W - 3)]
LIST4 += [[(y, x), (y, x + 1), (y, x + 2), (y, x + 3)] for y in range(BOARD_SIZE_H) for x in range(BOARD_SIZE_W - 3)]
NO_HORIZONTAL = len(LIST4)
LIST4 += [[(y, x), (y + 1, x), (y + 2, x), (y + 3, x)] for y in range(BOARD_SIZE_H - 3) for x in range(BOARD_SIZE_W)]

LIST4变量存储任何玩家赢得比赛时可以实现的可能组合。

我们不会在此文件中讨论整个代码。 但是,重要的是要了解以下函数及其作用:

  • get_start_board():此函数以 NumPy 数组的形式返回电路板的空白 2D 数组表示形式。
  • clone_board(board):此函数用于按板级克隆整个 NumPy 数组。
  • get_action(board):此函数返回播放器已修改的数组中的位置。
  • action_to_string(action):此函数将玩家执行的动作的内部数字表示形式转换为可以以易于理解的形式显示给用户的字符串。 例如place_at(board, pos,
  • player):执行为任何给定玩家在板上放置一块棋子的动作。 它还会更新板。
  • def get_winner(board):此函数确定棋盘当前状态下的游戏是否有赢家。 如果是,则返回获胜玩家的标识符,该标识符将为 1 或 -1。
  • def to_string(board):此函数将板的 NumPy 数组表示形式转换为字符串,该字符串为人类可读的格式。

接下来,我们将研究如何对 AI 进行编程,使其根据游戏规则进行并仅接受有效的动作。

允许根据游戏规则移动

为了确定玩家(无论是人还是机器)做出的动作的有效性,我们需要建立一种机制,在机器的情况下,该机制连续不断地只生成有效的动作,或者不断验证任何人类玩家的输入。 让我们开始吧:

  1. 可以在util/generator.py文件的_selfplay(self, state, args)函数中找到一个这样的实例,如以下代码所示:
turn = 0
hard_random_turn = args['hard_random'] if 'hard_random' in args else 0
soft_random_turn = (args['soft_random'] if 'soft_random' in args else 30) + hard_random_turn
history = []

首先,我们将移动切换设置为0,指示游戏开始时尚未进行任何移动。 我们还考虑了用户在其 AI 自行生成的游戏中想要的硬性和软性随机回合的数量。 然后,我们将移动的历史记录设置为空白。

  1. 现在,我们可以开始为 AI 生成动作,如下所示:
while state.getWinner() == None:
    if turn < hard_random_turn:
        # random action
        action_list = state.getAction()
        index = np.random.choice(len(action_list))
        (action, key) = action_list[index]

前面的代码说,直到没有游戏的获胜者,都必须生成招式。 在前面的案例中,我们可以看到,只要进行一次随机随机转弯的可能性为真,AI 就会选择一个完全随机的位置来放置其棋子。

  1. 通过在前面的if语句中添加else块,我们告诉 AI,只要它需要进行柔和转弯,它就可以检查是否有任何随机位置将其放置在其中,但只能在 MCTS 算法所建议的移动范围内,如下所示:
else:
    action_list = self.mcts.getActionInfo(state, args['simulation'])
    if turn < soft_random_turn:
        # random action by visited count
        visited = [1.0 * a.visited for a in action_list]
        sum_visited = sum(visited)
        assert(sum_visited > 0)
        p = [v / sum_visited for v in visited]
        index = np.random.choice(len(action_list), p = p)
    else:
        # select most visited count
        index = np.argmax([a.visited for a in action_list])

请注意,如果既不进行硬转弯也不进行软转弯,则坐席会在游戏的那一刻进行最常用的动作,这有望使它朝着胜利迈进。

因此,在非人类玩家的情况下,智能体只能在任何给定阶段在一组填充的有效动作之间进行选择。 对于人类玩家而言,情况并非如此,根据他们的创造力,他有可能尝试做出无效的举动。 因此,当人类玩家做出动作时,需要对其进行验证。

  1. 可以在util/player.py文件的getNextAction(self, state)函数中找到验证人类玩家移动的方法,如下所示:
action = state.getAction()
available_x = []
for i in range(len(action)):
    a, k = action[i]
    x = a % util.BOARD_SIZE_W + 1
    y = a // util.BOARD_SIZE_W + 1
    print('{} - {},{}'.format(x, x, y))
    available_x.append(x)
  1. 首先,我们现在计算人类玩家可能采取的合法行动,并将其显示给用户。 然后,我们提示用户输入一个动作,直到他们做出有效的动作为止,如下所示:
while True:
    try:
        x = int(compat_input('enter x: '))
        if x in available_x:
            for i in range(len(action)):
                if available_x[i] == x:
                    select = i
                    break
            break
    except ValueError:
        pass

因此,我们根据填充的一组有效动作来验证用户所做的动作。 我们还可以选择向用户显示错误。

接下来,我们将研究程序的状态管理系统,您肯定已经注意到,到目前为止,我们一直在看该代码。

状态管理系统

游戏的状态管理系统是整个程序中最重要的部分之一,因为它控制着所有的游戏玩法,并在 AI 的自学习过程中促进了游戏玩法。 这样可以确保向玩家展示棋盘,并在进行有效的移动。 它还存储了几个与状态有关的变量,这些变量对于游戏进行很有用。 让我们来看看:

  1. 让我们讨论util/state.py文件中提供的State类中最重要的特性和函数:
import .internal as util

此类使用util/internal.py文件中定义的名称为util的变量和函数。

  1. __init__(self, prototype = None):此类在启动时,会继承现有状态或创建新状态。 该函数的定义如下:
def __init__(self, prototype = None):
    if prototype == None:
        self.board = util.get_start_board()
        self.currentPlayer = 1
        self.winner = None
    else:
        self.board = util.clone_board(prototype.board)
        self.currentPlayer = prototype.currentPlayer
        self.winner = prototype.winner

在这里,您可以看到该类可以使用游戏的现有状态启动,并作为参数传递给该类的构造器; 否则,该类将创建一个新的游戏状态。

  1. getRepresentativeString(self):此函数返回可以由人类玩家读取的游戏状态的格式正确的字符串表示形式。 其定义如下:
def getRepresentativeString(self):
        return ('x|' if self.currentPlayer > 0 else 'o|') + util.to_oneline(self.board)

状态类中的许多其他重要方法如下:

  • getCurrentPlayer(self):此方法返回游戏的当前玩家; 也就是说,应该采取行动的玩家。
  • getWinner(self):如果游戏结束,则此方法返回游戏获胜者的标识符。
  • getAction(self):此方法检查游戏是否结束。 如果没有,它将在任何给定状态下返回一组下一个可能的动作。
  • getNextState(self, action):此方法返回游戏的下一个状态; 也就是说,在将当前正在移动的棋子放在棋盘上并评估游戏是否结束之后,它将执行从一种状态到另一种状态的切换。
  • getNnInput(self):此方法返回玩家到目前为止在游戏中执行的动作,并为每个玩家的动作使用不同的标记。

现在,让我们看一下如何改善程序的游戏玩法。

实现游戏玩法

负责控制程序中游戏玩法的文件是util/arena.py文件。

它在Arena类中定义了以下两种方法:

def fight(self, state, p1, p2, count):
    stats = [0, 0, 0]
    for i in range(count):
        print('==== EPS #{} ===='.format(i + 1))
        winner = self._fight(state, p1, p2)
        stats[winner + 1] += 1
        print('stats', stats[::-1])
        winner = self._fight(state, p2, p1)
        stats[winner * -1 + 1] += 1
        print('stats', stats[::-1])

前面的fight()函数管理玩家的胜利/损失或平局的状态。 它确保在每个回合中进行两场比赛,其中每位玩家只能先玩一次。

此类中定义的另一个_fight()函数如下:

def _fight(self, state, p1, p2):
    while state.getWinner() == None:
        print(state)
        if state.getCurrentPlayer() > 0:
            action = p1.getNextAction(state)
        else:
            action = p2.getNextAction(state)
        state = state.getNextState(action)
    print(state)
    return state.getWinner()

此函数负责切换棋盘上的玩家,直到找到赢家为止。

现在,让我们看一下如何生成随机的游戏玩法以使智能体自学。

生成示例游戏

到目前为止,我们已经讨论了util/gameplay.py文件,以演示该文件中与移动规则相关的代码-特别是该文件的自播放函数。 现在,我们来看看这些自玩游戏如何在迭代中运行以生成完整的游戏玩法日志。 让我们开始吧:

  1. 请考虑此文件提供的Generator类的generate()方法的代码:
def generate(self, state, nn, cb, args):
    self.mcts = MCTS(nn)

    iterator = range(args['selfplay'])
    if args['progress']:
        from tqdm import tqdm
        iterator = tqdm(iterator, ncols = 50)

    # self play
    for pi in iterator:
        result = self._selfplay(state, args)
        if cb != None:
            cb(result)

本质上,此函数负责运行该类的_selfplay()函数,并确定一旦完成自播放后必须执行的操作。 在大多数情况下,您会将输出保存到文件中,然后将其用于训练。

  1. 这已在command/generate.py文件中定义。 该脚本可以作为具有以下签名的命令运行:
usage: run.py generate [-h]
             [--model, default='latest.h5', help='model filename']
             [--number, default=1000000, help='number of generated states']
             [--simulation, default=100, help='number of simulations per move']
             [--hard, default=0, help='number of random moves']
             [--soft, default=1000, help='number of random moves that depends on visited node count']
             [--progress, help='show progress bar']
             [--gpu, help='gpu memory fraction']
             [--file, help='save to a file']
             [--network, help='save to remote server']
  1. 该命令的示例调用如下:
python run.py generate --model model.h5 --simulation 100 -n 5000 --file selfplay.txt --progress

现在,让我们看一下一旦生成自播放日志就可以训练模型的函数。

系统训练

要训​​练智能体,我们需要创建util/trainer.py文件,该文件提供train()函数。 让我们来看看:

  1. 签名如下:
train(state, nn, filename, args = {})

该函数接受State类,神经网络类和其他参数。 它还接受文件名,该文件名是包含生成的游戏玩法的文件的路径。 训练后,我们可以选择将输出保存到另一个模型文件中,如command/train.py文件的train()函数所提供的。

  1. 此命令具有以下签名:
usage: run.py train [-h]
              [--progress, help='show progress bar']
              [--epoch EPOCH, help='training epochs']
              [--batch BATCH, help='batch size']
              [--block BLOCK, help='block size']
              [--gpu GPU, help='gpu memory fraction']
              history, help='history file'
              input, help='input model file name'
              output, help='output model file name'

历史参数是存储生成的游戏玩法的文件。 输入文件是当前保存的模型文件,而输出文件是将新训练的模型保存到的文件。

  1. 该命令的示例调用如下:
python run.py train selfplay.txt model.h5 newmodel.h5 --epoch 3 --progress

现在我们已经有了一个训练系统,我们需要创建 MCTS 和神经网络实现。

实现蒙特卡罗树搜索

util/mcts.py文件中提供了完整的 MCTS 算法实现。 该文件提供了 MCTS 类,该类具有以下重要函数:

  • getMostVisitedAction:此函数返回将状态传递给访问次数最多的操作。
  • getActionInfo:执行任何操作后,此函数返回状态信息。
  • _simulation:此函数执行单个游戏模拟,并返回有关在模拟过程中玩过的游戏的信息。

最后,我们需要创建一个神经网络实现。

实现神经网络

在最后一节中,我们将了解为智能体进行训练而创建的神经网络。 我们将探索util/nn.py文件,该文件提供NN类以及以下重要方法:

  • __init__(self, filename):如果磁盘上不存在此函数,则使用util/keras_model.py函数创建新模型。 否则,它将模型文件加载到程序中。
  • util/keras_model.py文件中定义的模型是残差 CNN,它与 MCTS 和 UCT 结合使用,表现得像深度强化学习神经网络。 形成的模型具有以下配置:
input_dim: (2, util.BOARD_SIZE_H, util.BOARD_SIZE_W),
policy_dim: util.KEY_SIZE,
res_layer_num: 5,
cnn_filter_num: 64,
cnn_filter_size: 5,
l2_reg: 1e-4,
learning_rate: 0.003,
momentum: 0.9

默认情况下,模型具有五个残差卷积层块。 我们先前在util/internal.py文件中定义了BOARD_SIZE_HBOARD_SIZE_WKEY_SIZE常量:

  • save(self, filename):此函数将模型保存到提供的文件名中。
  • predict(self, x):提供了板状态以及已经进行的移动,此函数输出可以下一步进行的单个移动。
  • fit(self, x, policy, value, batch_size = 256, epochs = 1):此函数用于将新样本拟合到模型并更新权重。

除了上述脚本之外,我们还需要一些驱动脚本。 您可以在该项目的存储库中查找它们,以了解它们的用法。

要运行已完成的项目,您需要执行以下步骤:

  1. 使用以下命令创建新模型:
python run.py newmodel model.h5

这将创建一个新模型并打印出其摘要。

  1. 生成示例游戏日志:
python run.py generate --model model.h5 --simulation 100 -n 5000 --file selfplay.txt --progress

在仿真过程中,上一行为 MCTS 生成了 5,000 个示例游戏,深度为 100。

  1. 训练模型:
python run.py train selfplay.txt model.h5 newmodel.h5 --epoch 3 --progress

前面的命令在游戏文件上训练模型三个时间,并将训练后的模型另存为newmodel.h5

  1. 与 AI 对抗:
python run.py arena human mcts,newmodel.h5,100

前面的命令开始与 AI 进行游戏。 在这里,您将在终端中看到一个面板和游戏选项,如下所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第28张图片

现在,我们已经成功创建了一个基于 Alpha Zero 的程序来学习玩棋盘游戏,现在我们可以将其推论到国际象棋 AI 上了。 但是,在这样做之前,我们将简要地介绍项目架构。

基础项目架构

为了创建国际象棋引擎,将其作为 REST API 托管在 GCP 上,我们将遵循常规项目架构:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第29张图片

虽然上图提供了该项目的非常简化的概述,但它可以用于更复杂的系统,这些系统可以产生更好的自学习象棋引擎。

GCP 上托管的模型将放置在 EC2 VM 实例中,并将包装在基于 Flask 的 REST API 中。

为国际象棋引擎开发 GCP 托管的 REST API

现在我们已经看到了如何继续进行此项目,我们还需要讨论如何将 Connect 4 的游戏映射到国际象棋,以及如何将国际象棋 RL 引擎部署为 API。

您可以在这个页面上找到我们为该象棋引擎创建的文件。 在将这些文件与 Connect 4 项目中的文件映射之前,让我们快速了解一些最重要的文件:

  • src/chess_zero/agent/
  • player_chess.py:此文件描述ChessPlayer类,该类保存有关在任何时间点玩游戏的玩家的信息。 它为与使用蒙特卡洛树搜索来搜索新动作,更改玩家状态以及每个用户在玩游戏期间所需的其他功能的相关方法提供了包装。
  • model_chess.py:此文件描述了此系统中使用的剩余 CNN。
  • src/chess_zero/config/
  • mini.py:此文件定义国际象棋引擎学习或玩的配置。 您将需要在此处有时调整这些参数,以降低在低端计算机上进行训练期间的批量大小或虚拟 RAM 消耗。
  • src/chess_zero/env/
  • chess_env.py:此文件描述棋盘的设置,游戏规则以及执行游戏操作所需的函数。 它还包含检查游戏状态和验证移动的方法。
  • src/chess_zero/worker/
  • evaluate.py:此文件负责与当前最佳模型和下一代模型玩游戏。 如果下一代模型的表现优于 100 款游戏,则它将替代以前的模型。
  • optimize.py:此文件加载当前最佳模型,并在其上执行更多监督的基于学习的训练。
  • self.py:引擎与自己对战并学习新的游戏玩法。
  • sl.py:监督学习的缩写,此文件将来自其他玩家的游戏的 PGN 文件作为输入,并对其进行监督学习。
  • src/chess_zero/play_game/
  • uci.py:此文件提供了通用国际象棋界面UCI)标准环境,可以与引擎进行交互。
  • flask_server.py:该文件创建一个 Flask 服务器,该服务器使用国际象棋游戏的 UCI 表示法与引擎进行通信。

现在我们知道每个文件的作用,让我们建立这些文件与 Connect 4 游戏中文件的映射。

还记得我们在讨论 Connect 4 AI 时制定的步骤吗? 让我们看看国际象棋项目是否也遵循相同的步骤:

  1. 创建棋盘的虚拟代表。 这是在src/chess_zero/env/chess_env.py文件中完成的。
  2. 创建允许根据游戏规则进行移动的函数。 这也可以在src/chess_zero/env/chess_env.py文件中完成。
  3. 原地的状态管理系统:此功能在许多文件上维护,例如src/chess_zero/agent/player_chess.pysrc/chess_zero/env/chess_env.py
  4. 简化游戏:这是通过src/chess_zero/play_game/uci.py文件完成的。
  5. 创建一个可以生成示例游戏玩法的脚本,以供系统学习。 尽管此系统未将生成的游戏玩法明确地存储为磁盘上的文件,但该任务由src/chess_zero/worker/self_play.py执行。
  6. 创建训练函数来训练系统。 这些训练函数位于src/chess_zero/worker/sl.pysrc/chess_zero/worker/self.py处。
  7. 现在,我们需要一个 MCTS 实现。 可以在src/chess_zero/agent/player_chess.py的文件的移动搜索方法中找到该项目的 MCTS 实现。
  8. 神经网络的实现src/chess_zero/agent/model_chess.py中定义了项目的神经网络。

除了前面的映射之外,我们还需要讨论 Universal Chess Interface 和 Flask 服务器脚本,这两个都是游戏性和 API 部署所必需的。

了解通用国际象棋界面

/src/chess_zero/play_game/uci.py上的文件为引擎创建了通用国际象棋界面。 但是,UCI 到底是什么?

UCI 是 Rudolf Huber 和 Stefan Meyer-Kahlen 引入的一种通信标准,它允许在任何控制台环境中使用国际象棋引擎进行游戏。 该标准使用一小组命令来调用国际象棋引擎,以搜索并输出板子任何给定位置的最佳动作。

通过 UCI 进行的通信与标准输入/输出发生,并且与平台无关。 在我们程序的 UCI 脚本中可用的命令如下:

  • uci:打印正在运行的引擎的详细信息。
  • isready:这查询引擎是否准备好进行对抗。
  • ucinewgame:这将启动带有引擎的新游戏。
  • position [fen | startpos] moves:此设置板的位置。 如果用户从非起始位置开始,则用户需要提供 FEN 字符串来设置板。
  • go:这要求引擎进行搜索并提出最佳建议。
  • quit:这将结束游戏并退出界面。

以下代码显示了带有 UCI 引擎的示例游戏玩法:

> uci
id name ChessZero
id author ChessZero
uciok

> isready
readyok

> ucinewgame

> position startpos moves e2e4

> go
bestmove e7e5

> position rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 1 moves g1f3

> go
bestmove b8c6

> quit

要快速生成任何板位置的 FEN 字符串,可以使用板编辑器。

现在,让我们讨论一下 Flask 服务器脚本以及如何在 GCP 实例上部署它。

在 GCP 上部署

该国际象棋引擎程序需要存在 GPU。 因此,我们必须遵循其他步骤,才能在 GCP 实例上部署脚本。

大致的工作流程如下:

  1. 请求增加帐户可用的 GPU 实例的配额。
  2. 创建基于 GPU 的计算引擎实例。
  3. 部署脚本。

我们将在以下各节中详细介绍这些步骤。

请求增加 GPU 实例的配额

第一步将是请求增加 GPU 实例的配额。 默认情况下,您的 GCP 帐户上可拥有的 GPU 实例数为 0。此限制由您的帐户的配额配置设置,您需要请求增加。 这样做,请按照下列步骤操作:

  1. 通过这里打开 Goog​​le Cloud Platform 控制台。
  2. 在左侧菜单上,单击“IAM&Admin | 配额”,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第30张图片

  1. 单击Metrics过滤器,然后键入 GPU 以找到读取 GPU(所有区域)的条目,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第31张图片

  1. 选择条目,然后单击“编辑配额”。
  2. 系统将要求您提供身份证明,包括您的电话号码。 填写详细信息,然后单击“下一步”。
  3. 输入您想要将 GPU 配额设置为的限制(最好是1,以避免滥用)。 另外,请提供您提出要求的理由,例如学术研究,机器学习探索或任何适合您的东西!
  4. 单击“提交”。

提出要求后,大约需要 10 到 15 分钟才能将您的配额增加/设置为您指定的数量。 您将收到一封电子邮件,通知您有关此更新。 现在,您准备创建一个 GPU 实例。

创建一个 GPU 实例

下一步是创建 GPU 实例。 创建 GPU 实例的过程与创建非 GPU 实例的过程非常相似,但是需要额外的步骤。 让我们快速完成所有这些步骤:

  1. 在您的 Google Cloud Platform 仪表板上,单击左侧导航菜单中的“Compute Engine | VM 实例”。
  2. 单击“创建实例”。
  3. 单击“计算机类型选择”部分正下方的 CPU 平台和 GPU,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第32张图片

  1. 单击“添加 GPU”(大加号(+)按钮)。 选择要附加到此 VM 的 GPU 类型和 GPU 数量。
  2. 将启动盘操作系统更改为 Ubuntu 版本 10.10。
  3. 在“防火墙”部分中,检查 HTTP 和 HTTPS 通信权限,如以下屏幕截图所示:

  1. 单击表单底部的“创建”。

几秒钟后,您的实例将成功创建。 如果遇到任何错误,例如超出了区域资源限制,请尝试更改要在其中创建实例的区域/区域。这通常是一个临时问题。

现在,我们可以部署 Flask 服务器脚本。

部署脚本

现在,我们将部署 Flask 服务器脚本。 但是在我们这样做之前,让我们先看一下该脚本的作用:

  1. 脚本的前几行导入了必要的模块,脚本才能正常工作:
from flask import Flask, request, jsonify
import os
import sys
import multiprocessing as mp
from logging import getLogger

from chess_zero.agent.player_chess import ChessPlayer
from chess_zero.config import Config, PlayWithHumanConfig
from chess_zero.env.chess_env import ChessEnv

from chess_zero.agent.model_chess import ChessModel
from chess_zero.lib.model_helper import load_best_model_weight

logger = getLogger(__name__)
  1. 其余代码放入start()函数中,该函数由config对象实例化:
def start(config: Config):
    ## rest of the code
  1. 以下几行创建了引擎和人类玩家的实例,并在脚本开始运行时重置了游戏环境:
def start(config: Config):
    ...
    PlayWithHumanConfig().update_play_config(config.play)

    me_player = None
    env = ChessEnv().reset()
    ...
  1. 将创建模型,并使用以下代码将模型的最佳权重加载到其中:
def start(config: Config):
    ...
    model = ChessModel(config)

        if not load_best_model_weight(model):
            raise RuntimeError("Best model not found!")

    player = ChessPlayer(config, model.get_pipes(config.play.search_threads))
    ...
  1. 前面代码中的最后一行创建具有指定配置和模型知识的国际象棋引擎玩家实例:
def start(config: Config):
    ...
    app = Flask(__name__)

        @app.route('/play', methods=["GET", "POST"])
        def play():
            data = request.get_json()
            print(data["position"])
            env.update(data["position"])
            env.step(data["moves"], False)
            bestmove = player.action(env, False)
            return jsonify(bestmove) 
    ...

前面的代码创建了 Flask 服务器应用的实例。 定义/play路由,使其可以接受位置并移动参数,这与我们先前在 UCI 游戏中使用的命令相同。

  1. 游戏状态将更新,并且要求象棋引擎计算下一个最佳移动。 这以 JSON 格式返回给用户:
def start(config: Config):
    ...
    app.run(host="0.0.0.0", port="8080")

脚本的最后一行在主机0.0.0.0处启动 Flask 服务器,这意味着脚本将监听其运行所在设备的所有打开的 IP。 指定的端口为8080

  1. 最后,我们将脚本部署到我们创建的 VM 实例。 为此,请执行以下步骤:

  2. 打开 GCP 控制台的 VM 实例页面。

  3. 输入在上一节中创建的 VM 后,单击SSH按钮。

  4. SSH 会话激活后,通过运行以下命令来更新系统上的存储库:

sudo apt update
  1. 接下来,使用以下命令克隆存储库:
git clone https://github.com/PacktPublishing/Mobile-Deep-Learning-Projects.git
  1. 将当前工作目录更改为chess文件夹,如下所示:
cd Mobile-Deep-Learning-Projects/Chapter8/chess
  1. 为 Python3 安装 PIP:
sudo apt install python3-pip
  1. 安装项目所需的所有模块:
pip3 install -r requirements.txt
  1. 为最初的监督学习提供训练 PGN。 您可以从这里下载示例 PGN。 ficsgamesdb2017.pgn文件包含 5,000 个已存储的游戏。 您需要将此文件上传到data/play_data/文件夹。
  2. 运行监督学习命令:
python3 src/chess_zero/run.py sl
  1. 运行自学习命令:
python3 src/chess_zero/run.py self

当您对程序可以自行播放的时间感到满意时,请使用Ctrl + C/Z停止脚本。

  1. 运行以下命令以启动服务器:
python3 src/chess_zero/run.py server

现在,您应该能够将职位和移动发送到服务器并获得响应。 让我们快速测试一下。 使用 Postman 或其他任何用于 API 测试的工具,我们将使用 FEN 字符串向 API 发出请求,以设置位置和正在进行的移动。

假设您的 VM 实例正在公共 IP 地址上运行(在 VM 实例仪表板的实例条目上可见)1.2.3.4。 在这里,我们发送以下POST请求:

endpoint: http://1.2.3.4:8080/play
Content-type: JSON
Request body:
{
  "position": "r1bqk2r/ppp2ppp/2np1n2/2b1p3/2B1P3/2N2N2/PPPPQPPP/R1B1K2R w KQkq - 0 1",
  "moves": "f3g5"
}

先前代码的输出为"h7h6"。 让我们直观地了解这种交互。 FEN 中定义的板看起来如下:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第33张图片

我们告诉服务器这是怀特的举动,而怀特玩家的举动是f3g5,这意味着将怀特骑士移动到板上的 G5 位置。 我们传递给 API 的棋盘 FEN 字符串中的'w'表示白人玩家将进行下一回合。

引擎通过将 H7 处的棋子移动到 H6 进行响应,威胁到马的前进,如以下屏幕快照所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第34张图片

现在,我们可以将此 API 与 Flutter 应用集成!

在 Android 上创建简单的国际象棋 UI

现在,我们了解了强化学习以及如何使用它来开发可部署到 GCP 的国际象棋引擎,让我们为游戏创建 Flutter 应用。 该应用将具有两个播放器–用户和服务器。 用户是玩游戏的人,而服务器是我们在 GCP 上托管的国际象棋引擎。 首先,用户采取行动。 记录此移动并将其以 POST 请求的形式发送到国际象棋引擎。 然后,国际象棋引擎以自己的动作进行响应,然后在屏幕上进行更新。

我们将创建一个简单的单屏应用,将棋盘和棋子放置在中间。 该应用将显示如下:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第35张图片

该应用的小部件树如下所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LwjjJFDo-1681785128422)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/6f6a925e-3b98-4194-aa42-ee318f257593.png)]

让我们开始编写应用代码。

将依赖项添加到pubspec.yaml

首先,将chess_vectors_flutter包添加到pubspec.yaml文件中,以便在将要构建的棋盘上显示实际的棋子。 将以下行添加到pubspec.yaml的依赖项部分:

chess_vectors_flutter: ">=1.0.6 <2.0.0"

运行flutter pub get安装包。

将棋子放置在正确的位置可能会有些棘手。 让我们了解将所有片段放置在正确位置的约定。

了解映射结构

我们将首先创建一个名为chess_game.dart的新 dart 文件。 这将包含所有游戏逻辑。 在文件内部,我们声明一个名为ChessGame的有状态小部件:

  1. 要将棋子映射到棋盘的正方形,我们将使用与构建模型时相同的符号,以便每个正方形均由字母和数字表示。 我们将在ChessGameState内创建一个列表squareList,以便我们可以存储所有索引的正方形,如下所示:
var squareList = [ 
 ["a8","b8","c8","d8","e8","f8","g8","h8"],
 ["a7","b7","c7","d7","e7","f7","g7","h7"],
 ["a6","b6","c6","d6","e6","f6","g6","h6"],
 ["a5","b5","c5","d5","e5","f5","g5","h5"],
 ["a4","b4","c4","d4","e4","f4","g4","h4"],
 ["a3","b3","c3","d3","e3","f3","g3","h3"],
 ["a2","b2","c2","d2","e2","f2","g2","h2"],
 ["a1","b1","c1","d1","e1","f1","g1","h1"],
 ];
  1. 为了将正确的棋子存储在正确的正方形中并根据玩家的移动来更新它们,我们将创建一个名为boardHashMap
HashMap board = new HashMap<String, String>();

HashMap的键将包含正方形的索引,而值将是正方形将保留的片段。 我们将使用一个字符串来表示一块特定的作品,该字符串将根据作品的名称包含一个字母。 例如,K代表王,B代表相。 我们通过使用大写和小写字母来区分白色和黑色部分。 大写字母代表白色,小写字母代表黑色。 例如,K代表白王,b代表黑相。 board['e7'] = "P"表示索引为'e7'的盒子当前有一个白色棋子。

  1. 现在,让我们将它们放置在初始位置。 为此,我们需要定义initializeBoard()方法,如下所示:
 void initializeBoard() {
   setState(() {
     for(int i = 8; i >= 1; i--) {
       for(int j = 97; j <= 104; j++) {
         String ch = String.fromCharCode(j)+'$i';
         board[ch] = " ";
       }
     }

   //Placing White Pieces
   board['a1'] = board['h1']= "R";
   board['b1'] = board['g1'] = "N";
   board['c1'] = board['f1'] = "B";
   board['d1'] = "Q";
   board['e1'] = "K";
   board['a2'] = board['b2'] = board['c2'] = board['d2'] =
   board['e2'] = board['f2'] = board['g2'] = board['h2'] = "P";

   //Placing Black Pieces
   board['a8'] = board['h8']= "r";
   board['b8'] = board['g8'] = "n";
   board['c8'] = board['f8'] = "b";
   board['d8'] = "q";
   board['e8'] = "k";
   board['a7'] = board['b7'] = board['c7'] = board['d7'] =
   board['e7'] = board['f7'] = board['g7'] = board['h7'] = "p";
   });
 }

在前面的方法中,我们使用一个简单的嵌套循环通过从ah的所有行以及从 1 到 8 的所有列进行遍历,使用空白字符串初始化哈希映射板的所有索引。 如“步骤 2”中所述,将其放置在其初始位置上。 为了确保在初始化棋盘时重新绘制 UI,我们将整个分配放在setState()中。

  1. 屏幕启动后,板将被初始化。 为了确保这一点,我们需要覆盖initState()并从那里调用initializeBoard()
 @override
 void initState() {
   super.initState();
   initializeBoard();
 } 

现在我们对映射棋子有了更好的了解,让我们开始在屏幕上放置棋子的实际图像。

放置实际片段的图像

将片段映射到其初始位置后,我们可以开始放置实际的图像向量:

  1. 我们首先定义一个名为mapImages()的函数,该函数采用正方形的索引(即哈希图板的键值)并返回图像:
Widget mapImages(String squareName) {
   board.putIfAbsent(squareName, () => " ");
   String p = board[squareName];
   var size = 6.0;
   Widget imageToDisplay = Container();
   switch (p) {
     case "P":
       imageToDisplay = WhitePawn(size: size);
       break;
     case "R":
       imageToDisplay = WhiteRook(size: size);
       break;
     case "N":
       imageToDisplay = WhiteKnight(size: size);
       break;
     case "B":
       imageToDisplay = WhiteBishop(size: size);
       break;
     case "Q":
       imageToDisplay = WhiteQueen(size: size);
       break;
     case "K":
       imageToDisplay = WhiteKing(size: size);
       break;
     case "p":
       imageToDisplay = BlackPawn(size: size);
       break;
     case "r":
       imageToDisplay = BlackRook(size: size);
       break;
     case "n":
       imageToDisplay = BlackKnight(size: size);
       break;
     case "b":
       imageToDisplay = BlackBishop(size: size);
       break;
     case "q":
       imageToDisplay = BlackQueen(size: size);
       break;
     case "k":
       imageToDisplay = BlackKing(size: size);
       break;
     case "p":
       imageToDisplay = BlackPawn(size: size);
       break;
   }
   return imageToDisplay;
 }

在前面的函数中,我们构建一个与矩形中所含件名相对应的开关盒块。 我们使用哈希图在特定的正方形中找到片段,然后返回相应的图像。 例如,如果将a1的值传递到squareName中,并且哈希图板具有与键值a1对应的值P,则白兵的图像将存储在变量imageToDisplay中。

请注意,在 64 个棋盘格正方形中,只有 32 个包含棋子。 其余将为空白。 因此,在哈希表board中,将存在没有值的键。 如果squareName没有片段,则将其传递给imageToDisplay变量,该变量将只有一个空容器。

  1. 在上一步中,我们构建了对应于棋盘上每个正方形的小部件(图像或空容器)。 现在,让我们将所有小部件排列成行和列。 squareName中的特定元素(例如[a1,b1,....,g1])包含应并排放置的正方形。 因此,我们将它们包装成一行并将这些行中的每一个包装成列。

  2. 让我们从定义buildRow()方法开始,该方法包含一个列表。 这本质上是sqaureName中的元素列表,并构建完整的行。 该方法如下所示:

 Widget buildRow(List<String> children) {
    return Expanded(
      flex: 1,
      child: Row(
        children: children.map((squareName) => getImage(squareName)).toList()
      ),
    );
  }

在前面的代码片段中,我们迭代使用map()方法传递的列表的每个元素。 这会调用getImage()以获取对应于正方形的适当图像。 然后,我们将所有这些返回的图像添加为一行的子级。 该行将一个子代添加到展开的窗口小部件并返回。

  1. getImage()方法定义如下:
 Widget getImage(String squareName) {
   return Expanded(
     child: mapImages(squareName),
   );
 }

只需输入squareName的值,然后返回一个扩展的小部件,其中将包含我们先前定义的mapImages返回的图像。 我们稍后将修改此方法,以确保玩家可以拖动每个图像,以便它们可以在棋盘上移动。

  1. 现在,我们需要构建将包含已构建行的列。 为此,我们需要定义buildChessBoard()方法,如下所示:
  Widget buildChessBoard() {
    return Container(
      height: 350,
      child: Column(
            children: widget.squareList.map((row) {
                return buildRow(row,);
                }).toList()   
      )
    );
  }

在前面的代码中,我们迭代了squareList内部的每一行,这些行表示为一个列表。 我们通过调用buildRow()来构建行,并将它们作为子级添加到列中。 此列作为子级添加到容器中并返回。

  1. 现在,让我们将所有片段以及实际的棋盘图像放到屏幕上。 我们将覆盖build()方法,以构建由棋盘图像及其碎片组成的小部件栈:
@override
 Widget build(BuildContext context) {
   return Container(
       child: Stack(
         children: <Widget>[
           Container(
             child: new Center(child: Image.asset("assets/chess_board.png", fit: BoxFit.cover,)),
           ),
           Center(
             child: Container(
               child: buildChessBoard(),
             ),
           )
         ],
     )
   );
 }

前面的方法使用容器来构建栈,该容器添加存储在assets文件夹中的棋盘图像。 栈的下一个子项是居中对齐的容器,其中所有片段图像都通过对buildChessBoard()的调用以小部件的形式添加为行和列包装。 整个栈作为子级添加到容器中并返回,以便出现在屏幕上。

此时,应用显示棋盘,以及所有放置在其初始位置的棋子。 如下所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第36张图片

现在,让我们使这些棋子变得可移动,以便我们可以玩一个真实的游戏。

使片段移动

在本节中,我们将用可拖动的工具包装每块棋子,以便用户能够将棋子拖动到所需位置。 让我们详细看一下实现:

  1. 回想一下,我们声明了一个哈希图来存储片段的位置。 移动将包括从一个盒子中移出一块并将其放在另一个盒子中。 假设我们有两个变量'from''to',它们存储用于移动片段的盒子的索引。 进行移动后,我们拿起'from'处的片段并将其放入'to'中。 因此,'from'的框变为空。 按照相同的逻辑,我们将定义refreshBoard()方法,该方法在每次移动时都会调用:
void refreshBoard(String from, String to) {
   setState(() {
     board[to] = board[from];
     board[from] = " ";
   });
 }

fromto变量存储源和目标正方形的索引。 这些值在board HasMhap 中用作键。 进行移动时,from处的棋子会移至to.。此后,from处的方块应该变空。 它包含在setState()中,以确保每次移动后都更新 UI。

  1. 现在,让我们将其拖曳。 为此,我们将拖动项附加到getPieceImage()方法返回的木板的每个图像小部件上。 我们通过修改方法来做到这一点:
Widget getImage(String squareName) {
   return Expanded(
     child: DragTarget<List>(builder: (context, accepted, rejected) {
             return Draggable<List>(
                 child: mapImages(squareName),
                 feedback: mapImages(squareName),
                 onDragCompleted: () {},
                 data: [
                   squareName,
                 ],
               );
       }, onWillAccept: (willAccept) {
         return true;
       }, onAccept: (List moveInfo) {
         String from = moveInfo[0];
         String to = squareName;
         refreshBoard(from, to);
       })
     );
 }

在前面的函数中,我们首先将特定正方形的图像包装在Draggable中。 此类用于感测和跟随屏幕上的拖动手势。 child属性用于指定要拖动的窗口小部件,而反馈内部的窗口小部件用于跟踪手指在屏幕上的移动。 当拖动完成并且用户抬起手指时,目标将有机会接受所携带的数据。 由于我们正在源和目标之间移动,因此我们将添加Draggable作为DragTarget的子代,以便可以在源和目标之间移动小部件。 onWillAccept设置为true,以便可以进行所有移动。

可以修改此属性,以使其具有可以区分合法象棋动作并且不允许拖动非法动作的功能。 放下片段并完成拖动后,将调用onAcceptmoveInfo列表保存有关拖动源的信息。 在这里,我们调用refreshBoard(),并传入fromto的值,以便屏幕可以反映运动。 至此,我们完成了向用户显示初始棋盘的操作,并使棋子可以在盒子之间移动。

在下一节中,我们将通过对托管的国际象棋服务器进行 API 调用来增加应用的交互性。 这些将使游戏栩栩如生。

将国际象棋引擎 API 与 UI 集成

托管的棋牌服务器将作为对手玩家添加到应用中。 用户将是白色的一面,而服务器将是黑色的一面。 这里要实现的游戏逻辑非常简单。 第一步是提供给应用用户。 用户进行移动时,他们将棋盘的状态从状态 X 更改为状态 Y。棋盘的状态由 FEN 字符串表示。 同样,他们将一块from移到一个特定的正方形to移到一个特定的正方形,这有助于他们的移动。 当用户完成移动时,状态 X 的 FEN 字符串及其当前移动(通过将fromto正方形连接在一起而获得)以POST请求的形式发送到服务器。 作为回报,服务器从其侧面进行下一步移动,然后将其反映在 UI 上。

让我们看一下此逻辑的代码:

  1. 首先,我们定义一个名为getPositionString()的方法来为应用的特定状态生成 FEN 字符串:
String getPositionString(String move) {
    String s = "";
    for(int i = 8; i >= 1; i--) {
        int count = 0;
        for(int j = 97; j <= 104; j++) {
            String ch = String.fromCharCode(j)+'$i';
            if(board[ch] == " ") {
                count += 1;
                if(j == 104) 
                    s = s + "$count";
            } else {
                if(count > 0) 
                    s = s + "$count";
                s = s + board[ch];count = 0;
            }
        }
    s = s + "/";
    }
    String position = s.substring(0, s.length-1) + " w KQkq - 0 1";
    var json = jsonEncode({"position": position, "moves": move});
}

在前面的方法中,我们将move作为参数,它是fromto变量的连接。 接下来,我们为棋盘的当前状态创建 FEN 字符串。 创建 FEN 字符串背后的逻辑是,我们遍历电路板的每一行并为该行创建一个字符串。 然后将生成的字符串连接到最终字符串。

让我们借助示例更好地理解这一点。 考虑一个rnbqkbnr/pp1ppppp/8/1p6/8/3P4/PPP1PPPP/RNBQKBNR w KQkq - 0 1的 FEN 字符串。 在此,每行可以用八个或更少的字符表示。 特定行的状态通过使用分隔符“/”与另一行分开。 对于特定的行,每件作品均以其指定的符号表示,其中P表示白兵,b表示黑相。 每个占用的正方形均由件符号明确表示。 例如,PpkB指示板上的前四个正方形被白色棋子,黑色棋子,黑色国王和白色主教占据。 对于空盒子,使用整数,该数字表示可传染的空盒子的数量。 注意示例 FEN 字符串中的8。 这表示该行的所有 8 个正方形均为空。 3P4表示前三个正方形为空,第四个方框被白色棋子占据,并且四个正方形为空。

getPositionString()方法中,我们迭代从 8 到 1 的每一行,并为每行生成一个状态字符串。 对于每个非空框,我们只需在's'变量中添加一个表示该块的字符。 对于每个空框,当找到非空框或到达行末时,我们将count的值增加 1 并将其连接到's'字符串。 遍历每一行后,我们添加“/”以分隔两行。 最后,我们通过将生成的's'字符串与w KQkq - 0 1连接来生成位置字符串。 然后,我们通过将jsonEncode()与键值对结合使用来生成所需的 JSON 对象

  1. 我们使用“步骤 1”的“步骤 1”中的fromto变量来保存用户的当前移动。 我们可以通过在refreshBoard()方法中添加两行来实现:
void refreshBoard(String from, String to) {
    String move= from + to;
    getPositionString(move);
    .....
}

在前面的代码片段中,我们将fromto的值连接起来,并将它们存储在名为move的字符串变量中。 然后,我们调用getPositionString(),并将move的值传递给参数。

  1. 接下来,我们使用在上一步中makePOSTRequest()方法中生成的JSON向服务器发出POST请求:
void makePOSTRequest(var json) async{
    var url = 'http://35.200.253.0:8080/play';
    var response = await http.post(url, headers: {"Content-Type": "application/json"} ,body: json);
    String rsp = response.body;
    String from = rsp.substring(0,3);
    String to = rsp.substring(3);
}

首先,将国际象棋服务器的 IP 地址存储在url变量中。 然后,我们使用http.post()发出HTTP POST请求,并为 URL,标头和正文传递正确的值。 POST 请求的响应包含服务器端的下一个动作,并存储在变量响应中。 我们解析响应的主体并将其存储在名为rsp的字符串变量中。 响应基本上是一个字符串,是服务器端的源方和目标方的连接。 例如,响应字符串f4a3表示国际象棋引擎希望将棋子以f4正方形移动到a3正方形。 我们使用substring()分隔源和目标,并将值存储在fromto变量中。

  1. 现在,通过将调用添加到makePOSTrequest()来从getPositionString()发出 POST 请求:
String getPositionString(String move) {
    .....
    makePOSTRequest(json);
}

在 FEN 字符串生成板的给定状态之后,对makePOSTrequest()的调用添加在函数的最后。

  1. 最后,我们使用refreshBoardFromServer()方法刷新板以反映服务器在板上的移动:
void refreshBoardFromServer(String from, String to) {
    setState(() {    
        board[to] = board[from];
        board[from] = " ";
    });
}

前述方法中的逻辑非常简单。 首先,我们将映射到from索引正方形的片段移动到to索引正方形,然后清空from索引正方形。

  1. 最后,我们调用适当的方法以用最新的动作更新 UI:
void makePOSTRequest(var json) async{
    ......
    refreshBoardFromServer(from, to);
    buildChessBoard();
}

发布请求成功完成后,我们收到了服务器的响应,我们将调用refreshBoardFromServer()以更新板上的映射。 最后,我们调用buildChessBoard()以在应用屏幕上反映国际象棋引擎所做的最新动作。

以下屏幕快照显示了国际象棋引擎进行移动后的更新的用户界面:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第37张图片

请注意,黑色的块在白色的块之后移动。 这就是代码的工作方式。 首先,用户采取行动。 它以板的初始状态发送到服务器。 然后,服务器以其移动进行响应,更新 UI。 作为练习,您可以尝试实现一些逻辑以区分有效动作和无效动作。

可以在这个页面中找到此代码。

现在,让我们通过创建材质应用来包装应用。

创建材质应用

现在,我们将在main.dart中创建最终的材质应用。 让我们从以下步骤开始:

  1. 首先,我们创建无状态窗口小部件MyApp,并覆盖其build()方法,如下所示:
class MyApp extends StatelessWidget {
    @override
    Widget build(BuildContext context) {
        return MaterialApp(
            title: 'Chess',
            theme: ThemeData(primarySwatch: Colors.blue,),
            home: MyHomePage(title: 'Chess'),
        );
    }
}
  1. 我们创建一个单独的StatefulWidget,称为MyHomePage,以便将 UI 放置在屏幕中央。 MyHomePagebuild()方法如下所示:
@override
Widget build(BuildContext context) {
    return Scaffold(
        appBar: AppBar(title: Text('Chess'),),
        body: Center(
            child: Column(
                mainAxisAlignment: MainAxisAlignment.center,
                children: <Widget>[ChessGame()
                ],
            ),
        ),
    );
}
  1. 最后,我们通过在main.dart中添加以下行来执行整个代码:
void main() => runApp(MyApp());

而已! 现在,我们有一个交互式的国际象棋游戏应用,您可以与聪明的对手一起玩。 希望你赢!

整个文件的代码可以在这个页面中找到。

总结

在此项目中,我们介绍了强化学习的概念以及为什么强化学习在创建游戏性 AI 的开发人员中很受欢迎。 我们讨论了 Google DeepMind 的 AlphaGo 及其兄弟项目,并深入研究了它们的工作算法。 接下来,我们创建了一个类似的程序来玩 Connect 4,然后下棋。 我们将基于 AI 的国际象棋引擎作为 API 部署到 GPU 实例的 GCP 上,并将其与基于 Flutter 的应用集成。 我们还了解了如何使用 UCI 促进国际象棋的无状态游戏。 完成此项目后,您将对如何将游戏转换为强化学习环境,如何以编程方式定义游戏规则以及如何创建用于玩这些游戏的自学智能体有很好的了解。

在下一章中,我们将创建一个应用,该应用可以使低分辨率图像变成非常高分辨率的图像。 我们将在 AI 的帮助下进行此操作。

八、深度神经网络

在本章中,我们将回顾机器学习,深度神经网络中最先进的技术,也是研究最多的领域之一。

深度神经网络定义

这是一个新闻技术领域蓬勃发展的领域,每天我们都听到成功地将 DNN 用于解决新问题的实验,例如计算机视觉,自动驾驶,语音和文本理解等。

在前几章中,我们使用了与 DNN 相关的技术,尤其是在涉及卷积神经网络的技术中。

出于实际原因,我们将指深度学习和深度神经网络,即其中层数明显优于几个相似层的架构,我们将指代具有数十个层的神经网络架构,或者复杂结构的组合。

穿越时空的深度网络架构

在本节中,我们将回顾从 LeNet5 开始在整个深度学习历史中出现的里程碑架构。

LeNet 5

在 1980 年代和 1990 年代,神经网络领域一直保持沉默。 尽管付出了一些努力,但是架构非常简单,并且需要大的(通常是不可用的)机器力量来尝试更复杂的方法。

1998 年左右,在贝尔实验室中,在围绕手写校验数字分类的研究中,Ian LeCun 开始了一种新趋势,该趋势实现了所谓的“深度学习——卷积神经网络”的基础,我们已经在第 5 章,简单的前馈神经网络中对其进行了研究。

在那些年里,SVM 和其他更严格定义的技术被用来解决这类问题,但是有关 CNN 的基础论文表明,与当时的现有方法相比,神经网络的表现可以与之媲美或更好。

Alexnet

经过几年的中断(即使 LeCun 继续将其网络应用到其他任务,例如人脸和物体识别),可用结构化数据和原始处理能力的指数增长,使团队得以增长和调整模型, 在某种程度上被认为是不可能的,因此可以增加模型的复杂性,而无需等待数月的训练。

来自许多技术公司和大学的计算机研究团队开始竞争一些非常艰巨的任务,包括图像识别。 对于以下挑战之一,即 Imagenet 分类挑战,开发了 Alexnet 架构:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lULpcW1A-1681785128423)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00125.jpg)]

Alexnet 架构

主要功能

从其第一层具有卷积运算的意义上讲,Alexnet 可以看作是增强的 LeNet5。 但要添加未使用过的最大池化层,然后添加一系列密集的连接层,以建立最后的输出类别概率层。 视觉几何组(VGG)模型

图像分类挑战的其他主要竞争者之一是牛津大学的 VGG。

VGG 网络架构的主要特征是它们将卷积滤波器的大小减小到一个简单的3x3,并按顺序组合它们。

微小的卷积内核的想法破坏了 LeNet 及其后继者 Alexnet 的最初想法,后者最初使用的过滤器高达11x11过滤器,但复杂得多且表现低下。 过滤器大小的这种变化是当前趋势的开始:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-r8DheOZh-1681785128423)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00126.jpg)]

VGG 中每层的参数编号摘要

然而,使用一系列小的卷积权重的积极变化,总的设置是相当数量的参数(数以百万计的数量级),因此它必须受到许多措施的限制。

原始的初始模型

在由 Alexnet 和 VGG 主导的两个主要研究周期之后,Google 凭借非常强大的架构 Inception 打破了挑战,该架构具有多次迭代。

这些迭代的第一个迭代是从其自己的基于卷积神经网络层的架构版本(称为 GoogLeNet)开始的,该架构的名称让人想起了始于网络的方法。

GoogLenet(InceptionV1)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ofFPZuno-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00127.jpg)]

InceptionV1

GoogLeNet 是这项工作的第一个迭代,如下图所示,它具有非常深的架构,但是它具有九个链式初始模块的令人毛骨悚然的总和,几乎没有或根本没有修改:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EPjLvndu-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00128.jpg)]

盗梦空间原始架构

与两年前发布的 Alexnet 相比,它是如此复杂,但它设法减少了所需的参数数量并提高了准确率。

但是,由于几乎所有结构都由相同原始结构层构建块的确定排列和重复组成,因此提高了此复杂架构的理解和可伸缩性。

批量归一化初始化(V2)

2015 年最先进的神经网络在提高迭代效率的同时,还存在训练不稳定的问题。

为了理解问题的构成,首先我们将记住在前面的示例中应用的简单正则化步骤。 它主要包括将这些值以零为中心,然后除以最大值或标准偏差,以便为反向传播的梯度提供良好的基线。

在训练非常大的数据集的过程中,发生的事情是,经过大量训练示例之后,不同的值振荡开始放大平均参数值,就像在共振现象中一样。 我们非常简单地描述的被称为协方差平移。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sz6uJZfR-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00129.jpg)]

有和没有批量归一化的表现比较

这是开发批归一化技术的主要原因。

再次简化了过程描述,它不仅包括对原始输入值进行归一化,还对每一层上的输出值进行了归一化,避免了在层之间出现不稳定性之前就开始影响或漂移这些值。

这是 Google 在 2015 年 2 月发布的改进版 GoogLeNet 实现中提供的主要功能,也称为 InceptionV2。

InceptionV3

快进到 2015 年 12 月,Inception 架构有了新的迭代。 两次发行之间月份的不同使我们对新迭代的开发速度有了一个想法。

此架构的基本修改如下:

  • 将卷积数减少到最大3x3
  • 增加网络的总体深度
  • 在每一层使用宽度扩展技术来改善特征组合

下图说明了如何解释改进的启动模块:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6JVkvOHu-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00130.jpg)]

InceptionV3 基本模块

这是整个 V3 架构的表示形式,其中包含通用构建模块的许多实例:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-doHC5UCK-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00131.jpg)]

InceptionV3 总体图

残差网络(ResNet)

残差网络架构于 2015 年 12 月出现(与 InceptionV3 几乎同时出现),它带来了一个简单而新颖的想法:不仅使用每个构成层的输出,还将该层的输出与原始输入结合。

在下图中,我们观察到 ResNet 模块之一的简化​​视图。 它清楚地显示了卷积层栈末尾的求和运算,以及最终的 relu 运算:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lrxuW1RM-1681785128425)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00132.jpg)]

ResNet 一般架构

模块的卷积部分包括将特征从 256 个值减少到 64 个值,一个保留特征数的3x3过滤层以及一个从 64 x 256 个值增加1x1层的特征。 在最近的发展中,ResNet 的使用深度还不到 30 层,分布广泛。

其他深度神经网络架构

最近开发了很多神经网络架构。 实际上,这个领域是如此活跃,以至于我们每年或多或少都有新的杰出架构外观。 最有前途的神经网络架构的列表是:

  • SqueezeNet:此架构旨在减少 Alexnet 的参数数量和复杂性,声称减少了 50 倍的参数数量
  • 高效神经网络(Enet):旨在构建更简单,低延迟的浮点运算数量,具有实时结果的神经网络
  • Fractalnet:它的主要特征是非常深的网络的实现,不需要残留的架构,将结构布局组织为截断的分形

示例 – 风格绘画 – VGG 风格迁移

在此示例中,我们将配合 Leon Gatys 的论文《艺术风格的神经算法》的实现。

注意

此练习的原始代码由 Anish Athalye 提供。

我们必须注意,此练习没有训练内容。 我们将仅加载由 VLFeat 提供的预训练系数矩阵,该矩阵是预训练模型的数据库,可用于处理模型,从而避免了通常需要大量计算的训练:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tTvYAxhb-1681785128425)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00133.jpg)]

风格迁移主要概念

有用的库和方法

  • 使用scipy.io.loadmat加载参数文件
    • 我们将使用的第一个有用的库是scipy.io模块,用于加载系数数据,该数据另存为 matlab 的 MAT 格式。
  • 上一个参数的用法:
scipy.io.loadmat(file_name, mdict=None, appendmat=True, **kwargs) 

  • 返回前一个参数:

    mat_dict : dict :dictionary,变量名作为键,加载的矩阵作为值。 如果填充了mdict参数,则将结果分配给它。

数据集说明和加载

为了解决这个问题,我们将使用预训练的数据集,即 VGG 神经网络的再训练系数和 Imagenet 数据集。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ht3GTlIo-1681785128425)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00134.jpg)]

数据集预处理

假设系数是在加载的参数矩阵中给出的,那么关于初始数据集的工作就不多了。

模型架构

模型架构主要分为两部分:风格和内容。

为了生成最终图像,使用了没有最终完全连接层的 VGG 网络。

损失函数

该架构定义了两个不同的损失函数来优化最终图像的两个不同方面,一个用于内容,另一个用于风格。

内容损失函数

content_loss函数的代码如下:

 # content loss 
        content_loss = content_weight * (2 * tf.nn.l2_loss( 
                net[CONTENT_LAYER] - content_features[CONTENT_LAYER]) / 
                content_features[CONTENT_LAYER].size) 

风格损失函数

损失优化循环

损耗优化循环的代码如下:

        best_loss = float('inf') 
        best = None 
        with tf.Session() as sess: 
            sess.run(tf.initialize_all_variables()) 
            for i in range(iterations): 
                last_step = (i == iterations - 1) 
                print_progress(i, last=last_step) 
                train_step.run() 

                if (checkpoint_iterations and i % checkpoint_iterations == 0) or last_step: 
                    this_loss = loss.eval() 
                    if this_loss < best_loss: 
                        best_loss = this_loss 
                        best = image.eval() 
                    yield ( 
                        (None if last_step else i), 
                        vgg.unprocess(best.reshape(shape[1:]), mean_pixel) 
                    ) 

收敛性测试

在此示例中,我们将仅检查指示的迭代次数(迭代参数)。

程序执行

为了以良好的迭代次数(大约 1000 个)执行该程序,我们建议至少有 8GB 的 RAM 内存可用:

python neural_style.py --content examples/2-content.jpg --styles examples/2-style1.jpg  --checkpoint-iterations=100 --iterations=1000 --checkpoint-output=out%s.jpg --output=outfinal

前面命令的结果如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IISmOKsl-1681785128425)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00135.jpg)]

风格迁移步骤

控制台输出如下:

Iteration 1/1000
Iteration 2/1000
Iteration 3/1000
Iteration 4/1000
...
Iteration 999/1000
Iteration 1000/1000
  content loss: 908786
    style loss: 261789
       tv loss: 25639.9
    total loss: 1.19621e+06

完整源代码

neural_style.py的代码如下:

import os 

import numpy as np 
import scipy.misc 

from stylize import stylize 

import math 
from argparse import ArgumentParser 

# default arguments 
CONTENT_WEIGHT = 5e0 
STYLE_WEIGHT = 1e2 
TV_WEIGHT = 1e2 
LEARNING_RATE = 1e1 
STYLE_SCALE = 1.0 
ITERATIONS = 100 
VGG_PATH = 'imagenet-vgg-verydeep-19.mat' 

def build_parser(): 
    parser = ArgumentParser() 
    parser.add_argument('--content', 
            dest='content', help='content image', 
            metavar='CONTENT', required=True) 
    parser.add_argument('--styles', 
            dest='styles', 
            nargs='+', help='one or more style images', 
            metavar='STYLE', required=True) 
    parser.add_argument('--output', 
            dest='output', help='output path', 
            metavar='OUTPUT', required=True) 
    parser.add_argument('--checkpoint-output', 
            dest='checkpoint_output', help='checkpoint output format', 
            metavar='OUTPUT') 
    parser.add_argument('--iterations', type=int, 
            dest='iterations', help='iterations (default %(default)s)', 
            metavar='ITERATIONS', default=ITERATIONS) 
    parser.add_argument('--width', type=int, 
            dest='width', help='output width', 
            metavar='WIDTH') 
    parser.add_argument('--style-scales', type=float, 
            dest='style_scales', 
            nargs='+', help='one or more style scales', 
            metavar='STYLE_SCALE') 
    parser.add_argument('--network', 
            dest='network', help='path to network parameters (default %(default)s)', 
            metavar='VGG_PATH', default=VGG_PATH) 
    parser.add_argument('--content-weight', type=float, 
            dest='content_weight', help='content weight (default %(default)s)', 
            metavar='CONTENT_WEIGHT', default=CONTENT_WEIGHT) 
    parser.add_argument('--style-weight', type=float, 
            dest='style_weight', help='style weight (default %(default)s)', 
            metavar='STYLE_WEIGHT', default=STYLE_WEIGHT) 
    parser.add_argument('--style-blend-weights', type=float, 
            dest='style_blend_weights', help='style blending weights', 
            nargs='+', metavar='STYLE_BLEND_WEIGHT') 
    parser.add_argument('--tv-weight', type=float, 
            dest='tv_weight', help='total variation regularization weight (default %(default)s)', 
            metavar='TV_WEIGHT', default=TV_WEIGHT) 
    parser.add_argument('--learning-rate', type=float, 
            dest='learning_rate', help='learning rate (default %(default)s)', 
            metavar='LEARNING_RATE', default=LEARNING_RATE) 
    parser.add_argument('--initial', 
            dest='initial', help='initial image', 
            metavar='INITIAL') 
    parser.add_argument('--print-iterations', type=int, 
            dest='print_iterations', help='statistics printing frequency', 
            metavar='PRINT_ITERATIONS') 
    parser.add_argument('--checkpoint-iterations', type=int, 
            dest='checkpoint_iterations', help='checkpoint frequency', 
            metavar='CHECKPOINT_ITERATIONS') 
    return parser 

def main(): 
    parser = build_parser() 
    options = parser.parse_args() 

    if not os.path.isfile(options.network): 
        parser.error("Network %s does not exist. (Did you forget to download it?)" % options.network) 

    content_image = imread(options.content) 
    style_images = [imread(style) for style in options.styles] 

    width = options.width 
    if width is not None: 
        new_shape = (int(math.floor(float(content_image.shape[0]) / 
                content_image.shape[1] * width)), width) 
        content_image = scipy.misc.imresize(content_image, new_shape) 
    target_shape = content_image.shape 
    for i in range(len(style_images)): 
        style_scale = STYLE_SCALE 
        if options.style_scales is not None: 
            style_scale = options.style_scales[i] 
        style_images[i] = scipy.misc.imresize(style_images[i], style_scale * 
                target_shape[1] / style_images[i].shape[1]) 

    style_blend_weights = options.style_blend_weights 
    if style_blend_weights is None: 
        # default is equal weights 
        style_blend_weights = [1.0/len(style_images) for _ in style_images] 
    else: 
        total_blend_weight = sum(style_blend_weights) 
        style_blend_weights = [weight/total_blend_weight 
                               for weight in style_blend_weights] 

    initial = options.initial 
    if initial is not None: 
        initial = scipy.misc.imresize(imread(initial), content_image.shape[:2]) 

    if options.checkpoint_output and "%s" not in options.checkpoint_output: 
        parser.error("To save intermediate images, the checkpoint output " 
                     "parameter must contain `%s` (e.g. `foo%s.jpg`)") 

    for iteration, image in stylize( 
        network=options.network, 
        initial=initial, 
        content=content_image, 
        styles=style_images, 
        iterations=options.iterations, 
        content_weight=options.content_weight, 
        style_weight=options.style_weight, 
        style_blend_weights=style_blend_weights, 
        tv_weight=options.tv_weight, 
        learning_rate=options.learning_rate, 
        print_iterations=options.print_iterations, 
        checkpoint_iterations=options.checkpoint_iterations 
    ): 
        output_file = None 
        if iteration is not None: 
            if options.checkpoint_output: 
                output_file = options.checkpoint_output % iteration 
        else: 
            output_file = options.output 
        if output_file: 
            imsave(output_file, image) 

def imread(path): 
    return scipy.misc.imread(path).astype(np.float) 

def imsave(path, img): 
    img = np.clip(img, 0, 255).astype(np.uint8) 
    scipy.misc.imsave(path, img) 

if __name__ == '__main__': 
    main() 

Stilize.py的代码如下:

import vgg 

import tensorflow as tf 
import numpy as np 

from sys import stderr 

CONTENT_LAYER = 'relu4_2' 
STYLE_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1') 

try: 
    reduce 
except NameError: 
    from functools import reduce 

def stylize(network, initial, content, styles, iterations, 
        content_weight, style_weight, style_blend_weights, tv_weight, 
        learning_rate, print_iterations=None, checkpoint_iterations=None): 
    """ 
    Stylize images. 

    This function yields tuples (iteration, image); `iteration` is None 
    if this is the final image (the last iteration).  Other tuples are yielded 
    every `checkpoint_iterations` iterations. 

    :rtype: iterator[tuple[int|None,image]] 
    """ 
    shape = (1,) + content.shape 
    style_shapes = [(1,) + style.shape for style in styles] 
    content_features = {} 
    style_features = [{} for _ in styles] 

    # compute content features in feedforward mode 
    g = tf.Graph() 
    with g.as_default(), g.device('/cpu:0'), tf.Session() as sess: 
        image = tf.placeholder('float', shape=shape) 
        net, mean_pixel = vgg.net(network, image) 
        content_pre = np.array([vgg.preprocess(content, mean_pixel)]) 
        content_features[CONTENT_LAYER] = net[CONTENT_LAYER].eval( 
                feed_dict={image: content_pre}) 

    # compute style features in feedforward mode 
    for i in range(len(styles)): 
        g = tf.Graph() 
        with g.as_default(), g.device('/cpu:0'), tf.Session() as sess: 
            image = tf.placeholder('float', shape=style_shapes[i]) 
            net, _ = vgg.net(network, image) 
            style_pre = np.array([vgg.preprocess(styles[i], mean_pixel)]) 
            for layer in STYLE_LAYERS: 
                features = net[layer].eval(feed_dict={image: style_pre}) 
                features = np.reshape(features, (-1, features.shape[3])) 
                gram = np.matmul(features.T, features) / features.size 
                style_features[i][layer] = gram 

    # make stylized image using backpropogation 
    with tf.Graph().as_default(): 
        if initial is None: 
            noise = np.random.normal(size=shape, scale=np.std(content) * 0.1) 
            initial = tf.random_normal(shape) * 0.256 
        else: 
            initial = np.array([vgg.preprocess(initial, mean_pixel)]) 
            initial = initial.astype('float32') 
        image = tf.Variable(initial) 
        net, _ = vgg.net(network, image) 

        # content loss 
        content_loss = content_weight * (2 * tf.nn.l2_loss( 
                net[CONTENT_LAYER] - content_features[CONTENT_LAYER]) / 
                content_features[CONTENT_LAYER].size) 
        # style loss 
        style_loss = 0 
        for i in range(len(styles)): 
            style_losses = [] 
            for style_layer in STYLE_LAYERS: 
                layer = net[style_layer] 
                _, height, width, number = map(lambda i: i.value, layer.get_shape()) 
                size = height * width * number 
                feats = tf.reshape(layer, (-1, number)) 
                gram = tf.matmul(tf.transpose(feats), feats) / size 
                style_gram = style_features[i][style_layer] 
                style_losses.append(2 * tf.nn.l2_loss(gram - style_gram) / style_gram.size) 
            style_loss += style_weight * style_blend_weights[i] * reduce(tf.add, style_losses) 
        # total variation denoising 
        tv_y_size = _tensor_size(image[:,1:,:,:]) 
        tv_x_size = _tensor_size(image[:,:,1:,:]) 
        tv_loss = tv_weight * 2 * ( 
                (tf.nn.l2_loss(image[:,1:,:,:] - image[:,:shape[1]-1,:,:]) / 
                    tv_y_size) + 
                (tf.nn.l2_loss(image[:,:,1:,:] - image[:,:,:shape[2]-1,:]) / 
                    tv_x_size)) 
        # overall loss 
        loss = content_loss + style_loss + tv_loss 

        # optimizer setup 
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss) 

        def print_progress(i, last=False): 
            stderr.write('Iteration %d/%d\n' % (i + 1, iterations)) 
            if last or (print_iterations and i % print_iterations == 0): 
                stderr.write('  content loss: %g\n' % content_loss.eval()) 
                stderr.write('    style loss: %g\n' % style_loss.eval()) 
                stderr.write('       tv loss: %g\n' % tv_loss.eval()) 
                stderr.write('    total loss: %g\n' % loss.eval()) 

        # optimization 
        best_loss = float('inf') 
        best = None 
        with tf.Session() as sess: 
            sess.run(tf.initialize_all_variables()) 
            for i in range(iterations): 
                last_step = (i == iterations - 1) 
                print_progress(i, last=last_step) 
                train_step.run() 

                if (checkpoint_iterations and i % checkpoint_iterations == 0) or last_step: 
                    this_loss = loss.eval() 
                    if this_loss < best_loss: 
                        best_loss = this_loss 
                        best = image.eval() 
                    yield ( 
                        (None if last_step else i), 
                        vgg.unprocess(best.reshape(shape[1:]), mean_pixel) 
                    ) 

def _tensor_size(tensor): 
    from operator import mul 
    return reduce(mul, (d.value for d in tensor.get_shape()), 1) 
 vgg.py 
import tensorflow as tf 
import numpy as np 
import scipy.io 

def net(data_path, input_image): 
    layers = ( 
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 

        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 

        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 
        'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 

        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 
        'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 

        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 
        'relu5_3', 'conv5_4', 'relu5_4' 
    ) 

    data = scipy.io.loadmat(data_path) 
    mean = data['normalization'][0][0][0] 
    mean_pixel = np.mean(mean, axis=(0, 1)) 
    weights = data['layers'][0] 

    net = {} 
    current = input_image 
    for i, name in enumerate(layers): 
        kind = name[:4] 
        if kind == 'conv': 
            kernels, bias = weights[i][0][0][0][0] 
            # matconvnet: weights are [width, height, in_channels, out_channels] 
            # tensorflow: weights are [height, width, in_channels, out_channels] 
            kernels = np.transpose(kernels, (1, 0, 2, 3)) 
            bias = bias.reshape(-1) 
            current = _conv_layer(current, kernels, bias) 
        elif kind == 'relu': 
            current = tf.nn.relu(current) 
        elif kind == 'pool': 
            current = _pool_layer(current) 
        net[name] = current 

    assert len(net) == len(layers) 
    return net, mean_pixel 

def _conv_layer(input, weights, bias): 
    conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1), 
            padding='SAME') 
    return tf.nn.bias_add(conv, bias) 

def _pool_layer(input): 
    return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), 
            padding='SAME') 

def preprocess(image, mean_pixel): 
    return image - mean_pixel 

def unprocess(image, mean_pixel): 
    return image + mean_pixel 

总结

在本章中,我们一直在学习不同的深度神经网络架构。

我们了解了如何构建近年来最著名的架构之一 VGG,以及如何使用它来生成可转换艺术风格的图像。

在下一章中,我们将使用机器学习中最有用的技术之一:图形处理单元。 我们将回顾安装具有 GPU 支持的 TensorFlow 所需的步骤并对其进行训练,并将执行时间与唯一运行的模型 CPU 进行比较。

九、构建图像超分辨率应用

还记得上次和亲人一起旅行并拍了一些漂亮的照片作为记忆,但是当您回到家并刷过它们时,您发现它们非常模糊且质量低下吗? 现在,您剩下的所有美好时光就是您自己的心理记忆和那些模糊的照片。 如果可以使您的照片清晰透明并且可以看到其中的每个细节,那不是很好吗?

超分辨率是基于像素信息的近似将低分辨率图像转换为高分辨率图像的过程。 虽然今天可能还不完全是神奇的,但当技术发展到足以成为通用 AI 应用时,它肯定会在将来挽救生命。

在此项目中,我们将构建一个应用,该应用使用托管在 DigitalOcean Droplet 上的深度学习模型,该模型可以同时比较低分辨率和高分辨率图像,从而使我们更好地了解今天的技术。 我们将使用生成对抗网络GAN)生成超分辨率图像。

在本章中,我们将介绍以下主题:

  • 基本项目架构
  • 了解 GAN
  • 了解图像超分辨率的工作原理
  • 创建 TensorFlow 模型以实现超分辨率
  • 构建应用的 UI
  • 从设备的本地存储中获取图片
  • 在 DigitalOcean 上托管 TensorFlow 模型
  • 在 Flutter 上集成托管的自定义模型
  • 创建材质应用

让我们从了解项目的架构开始。

基本项目架构

让我们从了解项目的架构开始。

我们将在本章中构建的项目主要分为两个部分:

  • Jupyter 笔记本,它创建执行超分辨率的模型。
  • 使用该模型的 Flutter 应用,在 Jupyter 笔记本上接受训练后,将托管在 DigitalOcean 中的 Droplet 中。

从鸟瞰图可以用下图描述该项目:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第38张图片

将低分辨率图像放入模型中,该模型是从 Firebase 上托管的 ML Kit 实例中获取的,并放入 Flutter 应用中。 生成输出并将其作为高分辨率图像显示给用户。 该模型缓存在设备上,并且仅在开发人员更新模型时才更新,因此可以通过减少网络延迟来加快预测速度。

现在,让我们尝试更深入地了解 GAN。

了解 GAN

Ian Goodfellow,Yoshua Bengio 和其他人在 NeurIPS 2014 中引入的 GAN 席卷全球。 可以应用于各种领域的 GAN 会根据模型对实际数据样本的学习近似,生成新的内容或序列。 GAN 已被大量用于生成音乐和艺术的新样本,例如下图所示的面孔,而训练数据集中不存在这些面孔:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第39张图片

经过 60 个周期的训练后,GAN 生成的面孔。 该图像取自这里。

前面面孔中呈现的大量真实感证明了 GAN 的力量–在为他们提供良好的训练样本量之后,他们几乎可以学习生成任何类型的模式。

GAN 的核心概念围绕两个玩家玩游戏的想法。 在这个游戏中,一个人说出一个随机句子,另一个人仅仅考虑第一人称使用的单词就指出它是事实还是假。 第二个人唯一可以使用的知识是假句子和实句中常用的单词(以及如何使用)。 这可以描述为由 minimax 算法玩的两人游戏,其中每个玩家都试图以其最大能力抵消另一位玩家所做的移动。 在 GAN 中,第一个玩家是生成器G),第二个玩家是判别器D)。 GD都是常规 GAN 中的神经网络。 生成器从训练数据集中给出的样本中学习,并基于其认为当观察者查看时可以作为真实样本传播的样本来生成新样本。

判别器从训练样本(正样本)和生成器生成的样本(负样本)中学习,并尝试对哪些图像存在于数据集中以及哪些图像进行分类。 它从G获取生成的图像,并尝试将其分类为真实图像(存在于训练样本中)或生成图像(不存在于数据库中)。

通过反向传播,GAN 尝试不断减少判别器能够对生成器正确生成的图像进行分类的次数。 一段时间后,我们希望达到识别器在识别生成的图像时开始表现不佳的阶段。 这是 GAN 停止学习的地方,然后可以使用生成器生成所需数量的新样本。 因此,训练 GAN 意味着训练生成器以从随机输入产生输出,从而使判别器无法将其识别为生成的图像。

判别器将传递给它的所有图像分为两类:

  • 真实图像:数据集中存在的图像或使用相机拍摄的图像
  • 伪图像:使用某软件生成的图像

生成器欺骗判别器的能力越好,当向其提供任何随机输入序列时,生成的输出将越真实。

让我们以图表形式总结前面关于 GAN 进行的讨论:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第40张图片

GAN 具有许多不同的变体,所有变体都取决于它们正在执行的任务。 其中一些如下:

  • 渐进式 GAN:在 ICLR 2018 上的一篇论文中介绍,渐进式 GAN 的生成器和判别器均以低分辨率图像开始,并随着图像层的增加而逐渐受到训练,从而使系统能够生成高分辨率图像。 例如,在第一次迭代中生成的图像为10x10像素,在第二代中它变为20x20,依此类推,直到获得非常高分辨率的图像为止。 生成器和判别器都在深度上一起增长。
  • 条件 GAN:假设您有一个 GAN 可以生成 10 个不同类别的样本,但是在某个时候,您希望它在给定类别或一组类别内生成样本。 这是有条件 GAN 起作用的时候。有条件 GAN 使我们可以生成 GAN 中经过训练可以生成的所有标签中任何给定标签的样本。 在图像到图像的翻译领域中,已经完成了条件 GAN 的一种非常流行的应用,其中将一个图像生成为相似或相同域的另一个更逼真的图像。 您可以通过这个页面上的演示来尝试涂鸦一些猫,并获得涂鸦的真实感版本。
  • 栈式 GAN:栈式 GAN 的最流行的应用是基于文本描述生成图像。 在第一阶段,GAN 生成描述项的概述,在第二阶段,根据描述添加颜色。 然后,后续层中的 GAN 将更多细节添加到图像中,以生成图像的真实感版本,如描述中所述。 通过观察堆叠 GAN 的第一次迭代中的图像已经处于将要生成最终输出的尺寸,可以将栈式 GAN 与渐进式 GAN 区别开来。但是,与渐进式 GAN 相似,在第一次迭代中, 图像是最小的,并且需要进一步的层才能将其馈送到判别器。

在此项目中,我们将讨论 GAN 的另一种形式,称为超分辨率 GANSRGAN)。 我们将在下一部分中了解有关此变体的更多信息。

了解图像超分辨率的工作原理

几十年来,人们一直在追求并希望能够使低分辨率图像更加精细,以及使高分辨率图像化。 超分辨率是用于将低分辨率图像转换为超高分辨率图像的技术的集合,是图像处理工程师和研究人员最激动人心的工作领域之一。 已经建立了几种方法和方法来实现图像的超分辨率,并且它们都朝着自己的目标取得了不同程度的成功。 然而,近来,随着 SRGAN 的发展,关于使用任何低分辨率图像可以实现的超分辨率的量有了显着的改进。

但是在讨论 SRGAN 之前,让我们了解一些与图像超分辨率有关的概念。

了解图像分辨率

用质量术语来说,图像的分辨率取决于其清晰度。 分辨率可以归类为以下之一:

  • 像素分辨率
  • 空间分辨率
  • 时间分辨率
  • 光谱分辨率
  • 辐射分辨率

让我们来看看每个。

像素分辨率

指定分辨率的最流行格式之一,像素分辨率最通常是指形成图像时涉及的像素数量。 单个像素是可以在任何给定查看设备上显示的最小单个单元。 可以将几个像素组合在一起以形成图像。 在本书的前面,我们讨论了图像处理,并将像素称为存储在矩阵中的颜色信息的单个单元,它代表图像。 像素分辨率定义了形成数字图像所需的像素元素总数,该总数可能与图像上可见的有效像素数不同。

标记图像像素分辨率的一种非常常见的表示法是以百万像素表示。 给定NxM像素分辨率的图像,其分辨率可以写为(NxM / 1000000)百万像素。 因此,尺寸为2,000x3,000的图像将具有 6,000,000 像素,其分辨率可以表示为 6 兆像素。

空间分辨率

这是观察图像的人可以分辨图像中紧密排列的线条的程度的度量。 在这里,严格说来,图像的像素越多,清晰度越好。 这是由于具有较高像素数量的图像的空间分辨率较低。 因此,需要良好的空间分辨率以及具有良好的像素分辨率以使图像以良好的质量呈现。

它也可以定义为像素一侧所代表的距离量。

时间分辨率

分辨率也可能取决于时间。 例如,卫星或使用无人飞行器UAV)无人机拍摄的同一区域的图像可能会随时间变化。 重新捕获相同区域的图像所需的时间称为时间分辨率。

时间分辨率主要取决于捕获图像的设备。 如在图像捕捉的情况下,这可以是变型,例如当在路边的速度陷阱照相机中触发特定传感器时执行图像捕捉。 它也可以是常数。 例如,在配置为每x间隔拍照的相机中。

光谱分辨率

光谱分辨率是指图像捕获设备可以记录的波段数。 也可以将其定义为波段的宽度或每个波段的波长范围。 在数字成像方面,光谱分辨率类似于图像中的通道数。 理解光谱分辨率的另一种方法是在任何给定图像或频带记录中可区分的频带数。

黑白图像中的波段数为 1,而彩色(RGB)图像中的波段数为 3。可以捕获数百个波段的图像,其中其他波段可提供有关图像的不同种类的信息。 图片。

辐射分辨率

辐射分辨率是捕获设备表示在任何频带/通道上接收到的强度的能力。 辐射分辨率越高,设备可以更准确地捕获其通道上的强度,并且图像越真实。

辐射分辨率类似于图像每个像素的位数。 虽然 8 位图像像素可以表示 256 个不同的强度,但是 256 位图像像素可以表示2 ^ 256个不同的强度。 黑白图像的辐射分辨率为 1 位,这意味着每个像素只能有两个不同的值,即 0 和 1。

现在,让我们尝试了解 SRGAN。

了解 SRGAN

SRGAN 是一类 GAN,主要致力于从低分辨率图像创建超分辨率图像。

SRGAN 算法的功能描述如下:该算法从数据集中选取高分辨率图像,然后将其采样为低分辨率图像。 然后,生成器神经网络尝试从低分辨率图像生成高分辨率图像。 从现在开始,我们将其称为超分辨率图像。 将超分辨率图像发送到鉴别神经网络,该神经网络已经在高分辨率图像和一些基本的超分辨率图像的样本上进行了训练,以便可以对它们进行分类。

判别器将由生成器发送给它的超分辨率图像分类为有效的高分辨率图像,伪高分辨率图像或超分辨率图像。 如果将图像分类为超分辨率图像,则 GAN 损失会通过生成器网络反向传播,以便下次产生更好的伪造图像。 随着时间的流逝,生成器将学习如何创建更好的伪造品,并且判别器开始无法正确识别超分辨率图像。 GAN 在这里停止学习,被列为受过训练的人。

可以用下图来总结:

现在,让我们开始创建用于超分辨率的 SRGAN 模型。

创建 TensorFlow 模型来实现超分辨率

现在,我们将开始构建在图像上执行超分辨率的 GAN 模型。 在深入研究代码之前,我们需要了解如何组织项目目录。

项目目录结构

本章中包含以下文件和文件夹:

  • api/
  • model /
  • __init __.py:此文件指示此文件的父文件夹可以像模块一样导入。
  • common.py:包含任何 GAN 模型所需的常用函数。
  • srgan.py:其中包含开发 SRGAN 模型所需的函数。
  • weights/
  • gan_generator.h5:模型的预训练权重文件。 随意使用它来快速运行并查看项目的工作方式。
  • data.py:用于在 DIV2K 数据集中下载,提取和加载图像的工具函数。
  • flask_app.py:我们将使用此文件来创建将在 DigitalOcean 上部署的服务器。
  • train.py:模型训练文件。 我们将在本节中更深入地讨论该文件。

您可以在这个页面中找到项目此部分的源代码。

多样 2KDIV2K)数据集由图像恢复和增强的新趋势NTIRE)2017 单张图像超分辨率挑战赛引入,也用于挑战赛的 2018 版本中。

在下一节中,我们将构建 SRGAN 模型脚本。

创建用于超分辨率的 SRGAN 模型

首先,我们将从处理train.py文件开始:

  1. 让我们从将必要的模块导入项目开始:
import os

from data import DIV2K
from model.srgan import generator, discriminator
from train import SrganTrainer, SrganGeneratorTrainer

前面的导入引入了一些现成的类,例如SrganTrainerSrganGeneratorTrainer等。 在完成此文件的工作后,我们将详细讨论它们。

  1. 现在,让我们为权重创建一个目录。 我们还将使用此目录来存储中间模型:
weights_dir = 'weights'
weights_file = lambda filename: os.path.join(weights_dir, filename)

os.makedirs(weights_dir, exist_ok=True)
  1. 接下来,我们将从 DIV2K 数据集中下载并加载图像。 我们将分别下载训练和验证图像。 对于这两组图像,可以分为两对:高分辨率和低分辨率。 但是,这些是单独下载的:
div2k_train = DIV2K(scale=4, subset='train', downgrade='bicubic')
div2k_valid = DIV2K(scale=4, subset='valid', downgrade='bicubic')
  1. 将数据集下载并加载到变量后,我们需要将训练图像和验证图像都转换为 TensorFlow 数据集对象。 此步骤还将两个数据集中的高分辨率和低分辨率图像结合在一起:
train_ds = div2k_train.dataset(batch_size=16, random_transform=True)
valid_ds = div2k_valid.dataset(batch_size=16, random_transform=True, repeat_count=1)
  1. 现在,回想一下我们在“了解 GAN”部分中提供的 GAN 的定义。 为了使生成器开始产生判别器可以评估的伪造品,它需要学习创建基本的伪造品。 为此,我们将快速训练神经网络,以便它可以生成基本的超分辨率图像。 我们将其命名为预训练器。 然后,我们将预训练器的权重迁移到实际的 SRGAN,以便它可以通过使用判别器来学习更多。 让我们构建并运行预训练器
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir=f'.ckpt/pre_generator')
pre_trainer.train(train_ds,
                  valid_ds.take(10),
                  steps=1000000, 
                  evaluate_every=1000, 
                  save_best_only=False)

pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

现在,我们已经训练了一个基本模型并保存了权重。 我们可以随时更改 SRGAN 并通过加载其权重从基础训练中重新开始。

  1. 现在,让我们将预训练器权重加载到 SRGAN 对象中,并执行训练迭代:
gan_generator = generator()
gan_generator.load_weights(weights_file('pre_generator.h5'))

gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

请注意,在具有 8 GB RAM 和 Intel i7 处理器的普通计算机上,上述代码中的训练操作可能会花费大量时间。 建议在具有图形处理器GPU)的基于云的虚拟机中执行此训练。

  1. 现在,让我们保存 GAN 生成器和判别器的权重:
gan_trainer.generator.save_weights(weights_file('gan_generator.h5'))
gan_trainer.discriminator.save_weights(weights_file('gan_discriminator.h5'))

现在,我们准备继续进行下一部分,在该部分中将构建将使用此模型的 Flutter 应用的 UI。

构建应用的 UI

现在,我们了解了图像超分辨率模型的基本功能并为其创建了一个模型,让我们深入研究构建 Flutter 应用。 在本节中,我们将构建应用的 UI。

该应用的用户界面非常简单:它将包含两个图像小部件和按钮小部件。 当用户单击按钮小部件时,他们将能够从设备的库中选择图像。 相同的图像将作为输入发送到托管模型的服务器。 服务器将返回增强的图像。 屏幕上将放置的两个图像小部件将用于显示服务器的输入和服务器的输出。

下图说明了应用的基本结构和最终流程:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mNqyudKm-1681785128426)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/86a43bbe-4673-4dcb-8a8c-591d7c952df0.png)]

该应用的三个主要小部件可以简单地排列在一列中。 该应用的小部件树如下所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第41张图片

现在,让我们编写代码以构建主屏幕。 以下步骤讨论了该应用小部件的创建和放置:

  1. 首先,我们创建一个名为image_super_resolution.dart的新文件。 这将包含一个名为ImageSuperResolution的无状态窗口小部件。 该小部件将包含应用主屏幕的代码。
  2. 接下来,我们将定义一个名为buildImageInput()的函数,该函数返回一个小部件,该小部件负责显示用户选择的图像:
Widget buildImage1() {
     return Expanded(
         child: Container(
             width: 200, 
             height: 200, 
             child: img1
         )
     );
 }

此函数返回带有Container作为其child.Expanded小部件。Containerwidthheight200Container的子元素最初是存储在资产文件夹中的占位符图像,可以通过img1变量进行访问,如下所示:

 var img1 = Image.asset('assets/place_holder_image.png');

我们还将在pubspec.yaml文件中添加图像的路径,如下所示:

flutter:
    assets:
        - assets/place_holder_image.png
  1. 现在,我们将创建另一个函数buildImageOutput(),该函数返回一个小部件,该小部件负责显示模型返回的增强图像:
Widget buildImageOutput() {
     return Expanded(
         child: Container(
             width: 200, 
             height: 200, 
             child: imageOutput
         )
     );
 }

此函数返回一个以其Container作为其子元素的Expanded小部件。 Container的宽度和高度设置为200Container的子级是名为imageOutput的小部件。 最初,imageOutput还将包含一个占位符图像,如下所示:

Widget imageOutput = Image.asset('assets/place_holder_image.png');

将模型集成到应用中后,我们将更新imageOutput

  1. 现在,我们将定义第三个函数buildPickImageButton(),该函数返回一个Widget,我们可以使用它从设备的图库中选择图像:
Widget buildPickImageButton() {
     return Container(
         margin: EdgeInsets.all(8),
         child: FloatingActionButton(
             elevation: 8,
             child: Icon(Icons.camera_alt),
             onPressed: () => {},
        )
     );
 }

此函数返回以FloatingActionButton作为其子元素的Container。 按钮的elevation属性控制其下方阴影的大小,并设置为8。 为了反映该按钮用于选择图像,通过Icon类为它提供了摄像机的图标。 当前,我们已经将按钮的onPressed属性设置为空白。 我们将在下一部分中定义一个函数,使用户可以在按下按钮时从设备的图库中选择图像。

  1. 最后,我们将覆盖build方法以返回应用的Scaffold
@override
 Widget build(BuildContext context) {
     return Scaffold(
         appBar: AppBar(title: Text('Image Super Resolution')),
         body: Container(
             child: Column(
                 crossAxisAlignment: CrossAxisAlignment.center,
                 children: <Widget>[
                     buildImageInput(),
                     buildImageOutput(),
                     buildPickImageButton()
                 ]
             )
         )
     );
 }

Scaffold包含一个appBar,其标题设置为“图像超分辨率”。 Scaffold的主体为Container,其子代为Column。 该列的子级是我们在先前步骤中构建的三个小部件。 另外,我们将ColumncrossAxisAlignment属性设置为CrossAxisAlignment.center,以确保该列位于屏幕的中央。

至此,我们已经成功构建了应用的初始状态。 以下屏幕截图显示了该应用现在的外观:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第42张图片

尽管屏幕看起来很完美,但目前无法正常工作。 接下来,我们将向应用添加功能。 我们将添加让用户从图库中选择图像的功能。

从设备的本地存储中获取图片

在本节中,我们将添加FloatingActionButton的功能,以使用户可以从设备的图库中选择图像。 这最终将被发送到服务器,以便我们能够收到响应。

以下步骤描述了如何启动图库并让用户选择图像:

  1. 为了允许用户从设备的图库中选择图像,我们将使用image_picker库。 这将启动图库并存储用户选择的图像文件。 我们将从在pubspec.yaml文件中添加依赖项开始:
image_picker: 0.4.12+1

另外,我们通过在终端上运行flutter pub get来获取库。

  1. 接下来,我们将库导入image_super_resolution.dart文件中:
import 'package:image_picker/image_picker.dart';
  1. 现在,让我们定义pickImage()函数,该函数使用户可以从图库中选择图像:
void pickImage() async {
     File pickedImg = await ImagePicker.pickImage(source: ImageSource.gallery);
 }
  1. 从函数内部,我们只需调用ImagePicker.pickImage()并将source指定为ImageSource.gallery即可。 该库本身处理启动设备图库的复杂性。 用户选择的图像文件最终由该函数返回。 我们将函数返回的文件存储在File类型的pickedImg变量中。
  2. 接下来,我们定义loadImage()函数,以便在屏幕上显示用户选择的图像:
void loadImage(File file) {
     setState(() {
         img1 = Image.file(file);
     });
 }

此函数将用户选择的图像文件作为输入。 在函数内部,我们将先前声明的img1变量的值设置为Image.file(file),这将返回从'file'构建的Image小部件。 回想一下,最初,img1被设置为占位符图像。 为了重新渲染屏幕并显示用户选择的图像,我们将img1的新分配放在setState()中。

  1. 现在,将pickImage()添加到builtPickImageButton()内的FloatingActionButtononPressed属性中:
 Widget buildPickImageButton() {
     return Container(
        ....
        child: FloatingActionButton(
            ....
            onPressed: () => pickImage(),
         )
    );
 }

前面的补充内容确保单击按钮时,会启动图库,以便可以选择图像。

  1. 最后,我们将从pickImage()loadImage()添加一个调用:
void pickImage() async {
     ....
     loadImage(pickedImg);
 }

loadImage()内部,我们传入用户选择的图像,该图像存储在pickedImage变量中,以便可以在应用的屏幕上查看该图像。

完成上述所有步骤后,该应用将如下所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第43张图片

至此,我们已经构建了应用的用户界面。 我们还添加了一些功能,使用户可以从设备的图库中选择图像并将其显示在屏幕上。

在下一部分中,我们将学习如何托管在“为超分辨率创建 TensorFlow 模型”中创建的模型作为 API,以便我们可以使用它执行超分辨率。

在 DigitalOcean 上托管 TensorFlow 模型

DigitalOcean 是一个了不起的低成本云解决方案平台,非常易于上手,并提供了应用开发人员为立即可用的应用后端提供动力所需的几乎所有功能。 该界面非常易于使用,并且 DigitalOcean 拥有一些最广泛的文档,这些文档围绕着如何在云上设置不同类型的应用服务器提供入门。

在这个项目中,我们将使用 DigitalOcean 的 Droplet 部署我们的超分辨率 API。 DigitalOcean 中的 Droplet 只是通常在共享硬件空间上运行的虚拟机。

首先,我们将在项目目录中创建flask_app.py文件,并添加服务器工作所需的代码。

创建一个 Flask 服务器脚本

在本节中,我们将处理flask_app.py文件,该文件将作为服务器在云虚拟机上运行。 让我们开始吧:

  1. 首先,我们将对文件进行必要的导入:
from flask import Flask, request, jsonify, send_file
import os
import time

from matplotlib.image import imsave

from model.srgan import generator

from model import resolve_single
  1. 现在,我们将定义weights目录并将生成器权重加载到文件中:
weights_dir = 'weights'
weights_file = lambda filename: os.path.join(weights_dir, filename)

gan_generator = generator()
gan_generator.load_weights(weights_file('gan_generator.h5'))
  1. 接下来,我们将使用以下代码行实例化Flask应用:
app = Flask(__name__)
  1. 现在,我们准备构建服务器将监听的路由。 首先,我们将创建/generate路由,该路由将图像作为输入,生成其超分辨率版本,并将所生成的高分辨率图像的文件名返回给用户:
@app.route('/generate', methods=["GET", "POST"])
def generate():

    global gan_generator
    imgData = request.get_data()
    with open("input.png", 'wb') as output:
        output.write(imgData)

    lr = load_image("input.png")
    gan_sr = resolve_single(gan_generator, lr)
    epoch_time = int(time.time())
    outputfile = 'output_%s.png' % (epoch_time)
    imsave(outputfile, gan_sr.numpy())
    response = {'result': outputfile}

    return jsonify(response)

让我们尝试了解前面的代码块中发生的情况。 /generate路由已设置为仅监听 HTTP 请求的 GET 和 POST 方法。 首先,该方法获取 API 请求中提供给它的图像,将其转换为 NumPy 数组,然后将其提供给 SRGAN 模型。 SRGAN 模型返回超分辨率图像,然后为其分配一个唯一的名称并存储在服务器上。 用户显示文件名,他们可以使用该文件名调用另一个端点来下载文件。 让我们现在构建此端点。

  1. 为了创建端点以便下载生成的文件,我们可以使用以下代码:
@app.route('/download/', methods=['GET'])
def download(fname):
    return send_file(fname)

在这里,我们创建了一个名为/download的端点,该端点附加了文件名后,将其提取并发送回给用户。

  1. 最后,我们可以编写执行该脚本并设置服务器的代码:
app.run(host="0.0.0.0", port="8080")

保存此文件。 确保此时将您的存储库推送到 GitHub/GitLab 存储库。

现在,我们准备将该脚本部署到DigitalOcean Droplet。

将 Flask 脚本部署到 DigitalOcean Droplet

要将 Flask 脚本部署到 DigitalOcean Droplet,您必须创建一个 DigitalOcean 帐户并创建一个 Droplet。 请按照以下步骤操作:

  1. 在您喜欢的 Web 浏览器中转到 digitalocean.com 。

如果您希望在添加帐单详细信息时获得 100 美元的赠金,也可以转到这里。 我们稍后再做。

  1. 在 DigitalOcean 的注册表格中填写您的详细信息,然后提交表格继续进行下一步。

  2. 系统将要求您验证电子邮件并为 DigitalOcean 帐户添加结算方式。

  3. 在下一步中,系统将提示您创建第一个项目。 输入所需的详细信息并提交表单以创建您的项目:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第44张图片

  1. 创建项目后,您将被带到 DigitalOcean 仪表板。 您将能够看到创建 Droplet 的提示,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第45张图片

  1. 单击“提示”以弹出 Droplet 创建表单。 选择下表中描述的选项:

    字段 说明 要使用的值
    选择一张图片 Droplet 将在其上运行的操作系统。 Ubuntu 18.04(或最新可用版本)
    选择一个计划 选择 Droplet 的配置。 4 GB RAM 或更高
    添加块存储 Droplet 的其他持久性,可拆卸存储容量。 保留默认值
    选择数据中心区域 投放 Droplet 的区域。 根据您的喜好选择任何一个
    选择其他选项 选择将与您的 Droplet 一起使用的所有其他功能。 保留默认值
    认证方式 选择虚拟机的认证方法。 一次性密码
    完成并创建 Droplet 的一些其他设置和选项。 保留默认值
  2. 单击“创建 Droplet”,然后等待 DigitalOcean 设置您的 Droplet。

  3. 创建 Droplet 后,单击其名称以打开 Droplet 管理控制台,该控制台应如下所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第46张图片

  1. 现在,我们可以使用上一幅截图所示的 Droplet 控制台左侧导航菜单上的 Access 选项卡登录到 Droplet。 单击“访问”,然后启动控制台。

  2. 将打开一个新的浏览器窗口,显示您的 Droplet 的 VNC 视图。 系统将要求您输入 Droplet 的用户名和密码。 您必须在此处使用的用户名是root。 可以在您已注册的电子邮件收件箱中找到该密码。

  3. 首次登录时,系统会要求您更改 Droplet 密码。 确保您选择一个强密码。

  4. 登录 Droplet 后,将在 VNC 终端上看到一些 Ubuntu 欢迎文本,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第47张图片

  1. 现在,按照本书的“附录”中的说明,执行在云 VM 上设置深度学习环境的步骤。
  2. 接下来,将项目存储库克隆到您的 Droplet,并使用以下命令将工作目录更改为存储库的api文件夹:
git clone https://github.com/yourusername/yourrepo.git
cd yourrepo/api
  1. 使用以下命令运行服务器:
python3 flask_app.py

除了来自 TensorFlow 的一些警告消息之外,在终端输出的末尾,您还应该看到以下几行指示服务器已成功启动:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第48张图片

现在,如 Droplet 控制台所示,您的服务器已启动并在 Droplet 的 IP 上运行。

在下一部分中,我们将学习如何使用 Flutter 应用向服务器发出 POST 请求,并在屏幕上显示服务器的响应。

在 Flutter 上集成托管的自定义模型

在本节中,我们将向托管模型发出 POST 请求,并将其传递给用户选择的图像。 服务器将以 PNG 格式响应NetworkImage。 然后,我们将更新之前添加的图像小部件,以显示模型返回的增强图像。

让我们开始将托管模型集成到应用中:

  1. 首先,我们将需要两个以上的外部库来发出成功的 POST 请求。 因此,我们将以下库作为依赖项添加到pubspec.yaml文件:
dependencies:
     flutter:
          http: 0.12.0+4
          mime: 0.9.6+3

http依赖项包含一组类和函数,这些类和函数使使用 HTTP 资源非常方便。 mime依赖性用于处理 MIME 多部分媒体类型的流。

现在,我们需要运行flutter pub get以确保所有依赖项均已正确安装到我们的项目中。

  1. 接下来,我们将所有新添加的依赖项导入image_super_resolution.dart文件:
import 'package:http/http.dart' as http;
import 'package:mime/mime.dart';
  1. 现在,我们需要定义fetchResponse(),它接受所选的图像文件并向服务器创建 POST 请求:
void fetchResponse(File image) async {

    final mimeTypeData =
        lookupMimeType(image.path, headerBytes: [0xFF, 0xD8]).split('/');

    final imageUploadRequest = http.MultipartRequest('POST', Uri.parse("http://x.x.x.x:8080/generate"));

    final file = await http.MultipartFile.fromPath('image', image.path,
        contentType: MediaType(mimeTypeData[0], mimeTypeData[1]));

    imageUploadRequest.fields['ext'] = mimeTypeData[1];
    imageUploadRequest.files.add(file);
    try {
      final streamedResponse = await imageUploadRequest.send();
      final response = await http.Response.fromStream(streamedResponse);
      final Map<String, dynamic> responseData = json.decode(response.body);      
      String outputFile = responseData['result'];
    } catch (e) {
      print(e);
      return null;
    }
  }

在前面的方法中,我们通过使用lookupMimeType函数并使用文件的路径及其头来查找所选文件的 MIME 类型。 然后,按照托管模型的服务器的预期,初始化一个多部分请求。 我们使用 HTTP 执行此操作。 我们使用MultipartFile.fromPath并将image的值设置为作为POST参数附加的路径。 由于image_picker存在一些错误,因此我们将图片的扩展名明确传递给请求主体。 因此,它将图像扩展名与文件名(例如filenamejpeg)混合在一起,这在管理或验证文件扩展名时在服务器端造成了问题。 然后,来自服务器的响应将存储在response变量中。 响应为 JSON 格式,因此我们需要使用json.decode()对其进行解码。 该函数接收响应的主体,可以使用response.body进行访问。 我们将解码后的 JSON 存储在responseData变量中。 最后,使用responseDate['result']访问服务器的输出并将其存储在outputFile变量中。

  1. 接下来,我们定义displayResponseImage()函数,该函数接受服务器在outputFile参数内返回的 PNG 文件的名称:
void displayResponseImage(String outputFile) {
     print("Updating Image");
     outputFile = 'http://x.x.x.x:8080/download/' + outputFile;
     setState(() {        
        imageOutput = Image(image: NetworkImage(outputFile));
    });
 }

根据服务器的自定义,我们需要在文件名之前附加一个字符串以将其显示在屏幕上。 该字符串应包含服务器正在运行的端口地址,后跟'/download/'。 然后,我们将outputFile的最终值用作url值,将imageOutput小部件的值设置为NetworkImage。 另外,我们将其封装在[H​​TG5]中,以便在正确获取响应后可以刷新屏幕。

  1. 接下来,我们在fetchResponse()的最后调用displayResponseImage(),并传入从托管模型收到的outputFile
void fetchResponse(File image) async {
    ....   
    displayResponseImage(outputFile);
}
  1. 最后,通过传入用户最初选择的图像,将调用从pickImage()添加到fetchResponse()
void pickImage() async {
     ....
     fetchResponse(pickedImg);
 }

在前面的步骤中,我们首先向托管模型的服务器发出 POST 请求。 然后,我们解码响应并添加代码以在屏幕上显示它。 在pickImage()末尾添加fetchResponse()可确保仅在用户选择图像后才发出 POST 请求。 另外,为了确保在成功解码来自服务器的输出之后已经尝试显示响应图像,在fetchResponse()的末尾调用displayImageResponse()。 以下屏幕快照显示了屏幕的最终预期状态:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YGKKzKRo-1681785128428)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/14248601-46e6-421b-aa88-3353ff56bd4d.png)]

因此,我们已经完成了应用的构建,以便可以显示模型的输出。 我们将两个图像保存在屏幕上,以便我们可以看到它们之间的差异。

可以在这个页面上访问image_super_resolution.dart文件的代码。

创建材质应用

现在,我们将添加main.dart以创建最终的 Material 应用。 我们将创建一个名为MyApp的无状态小部件,并覆盖build()方法:

class MyApp extends StatelessWidget {
     @override
     Widget build(BuildContext context) {
         return MaterialApp(
             title: 'Flutter Demo',
             theme: ThemeData(
                 primarySwatch: Colors.blue,
             ),
             home: ImageSuperResolution(),
         );
     }
}

最后,我们执行代码,如下所示:

void main() => runApp(MyApp());

至此,我们完成了一个应用的创建,该应用允许用户选择图像并修改其分辨率。

总结

在本章中,我们研究了超分辨率图像以及如何使用 SRGAN 应用它们。 我们还研究了其他类型的 GAN 以及 GAN 的总体工作方式。 然后,我们讨论了如何创建一个 Flutter 应用,该应用可以与 DigitalOcean Droplet 上托管的 API 集成在一起,以便当从图库中拾取图像时可以执行图像超分辨率。 接下来,我们介绍了如何使用 DigitalOcean Droplet,以及由于其低成本和易于使用的界面而成为托管应用后端的理想选择。

在下一章中,我们将讨论一些流行的应用,这些应用通过将深度学习集成到其功能中而获得了很大的改进。 我们还将探索手机深度学习中的一些热门研究领域,并简要讨论已在其上进行的最新工作。

十、前方的路

旅程中最重要的部分是知道结束后要去哪里。 到目前为止,在本系列项目中,我们已经介绍了一些与 Flutter 应用相关的独特且功能强大的深度学习DL)应用,但重要的是,您必须知道在哪里可以找到更多这样的项目,灵感和知识来构建自己的出色项目。 在本章中,我们将简要介绍当今在移动应用上使用 DL 的最流行的应用,当前趋势以及将来在该领域中将会出现的情况。

在本章中,我们将介绍以下主题:

  • 了解移动应用中 DL 的最新趋势
  • 探索移动设备上 DL 的最新发展
  • 探索移动应用中 DL 的当前研究领域

让我们开始研究 DL 移动应用世界中的一些趋势。

了解移动应用中 DL 的最新趋势

特别是 DL,随着最新技术和硬件的发展,人工智能AI)变得越来越移动。 组织一直在使用智能算法来提供个性化的用户体验并提高应用参与度。 借助人脸检测,图像处理,文本识别,对象识别和语言翻译等技术,移动应用已不仅仅是提供静态信息的媒介。 它们能够适应用户的个人偏好和选择以及当前和过去的环境状况,以提供无缝的用户体验。

让我们看一下一些流行的应用及其部署的方法,以提供良好的用户体验,同时增加应用的参与度。

数学求解器

数学求解器应用由微软于 2020 年 1 月 16 日启动,可通过简单地单击智能手机上有问题的图片来帮助学生完成数学作业。 该应用为基本和高级数学问题提供支持,涵盖了广泛的主题,包括基本算术,二次方程,微积分和统计。 以下屏幕截图显示了该应用的工作方式:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第49张图片

用户可以在其智能手机上单击手写或打印问题的图片,或直接在设备上涂鸦或键入图片。 该应用利用 AI 来识别问题并准确解决。 此外,它还可以提供分步说明,并提供其他学习资料,例如与问题有关的工作表和视频教程。

Netflix

Netflix 的推荐系统是在移动应用上使用 DL 的最大成功案例之一。 Netflix 利用多种算法来了解用户的偏好,并提供了他们可能感兴趣的推荐列表。所有内容都标记有标签,这些标签提供了可以从中学习算法的初始数据集。 此外,该系统监视着超过 1 亿个用户个人资料,以分析人们观看的内容,以后可能观看的内容,以前观看的内容,一年前观看的内容,等等。 将收集的所有数据汇总在一起,以了解用户可能感兴趣的内容类型。

然后,将使用标签和用户行为收集的数据汇总在一起,并输入到复杂的 ML 算法中。 这些数据有助于解释可能最重要的因素-例如,如果用户一年前观看的电影与上周观看的系列相比应被计数两次。 该算法还可以从用户行为中学习,例如用户喜欢还是不喜欢特定的内容,或者用户在 2 个晚上观看和观看的节目。 将所有因素汇总在一起并进行仔细分析,从而得出用户可能最感兴趣的建议列表。

谷歌地图

Google Maps 已帮助通勤者前往新地方,探索新城市并监控每日流量。 在 2019 年 6 月上旬,谷歌地图发布了一项新功能,使用户可以监控印度 10 个主要城市的巴士旅行时间,以及从印度铁路局获得实时更新。 该功能位于班加罗尔,钦奈,哥印拜陀,德里,海得拉巴,勒克瑙,孟买,浦那和苏拉特,它利用 Google 的实时交通数据和公交时刻表来计算准确的出行时间和延误。 支持该功能的算法可从总线位置随时间的顺序中学习。 该数据还与通勤时公交车上的汽车速度结合在一起。 数据还用于捕获特定街道的独特属性。 研究人员还模拟了围绕某个区域弹出查询的可能性,以使该模型更加健壮和准确。

Tinder

作为结识新朋友的全球最受欢迎的应用,Tinder 部署了许多学习模型,以增加喜欢特定个人资料的人数。 智能照片功能增加了用户找到正确匹配项的可能性。 该功能随机排序特定用户的图片并将其显示给其他人。 支持该功能的算法分析了向左或向右滑动图片的频率。 它使用该知识根据图片的受欢迎程度对其重新排序。 随着越来越多的数据收集,该算法的准确率一直在不断提高。

Snapchat

Snapchat 使用的过滤器是在图片和视频的顶部添加的设计叠加层,可以跟踪面部移动。 这些过滤器是通过计算机视觉实现的。 应用使用的算法的第一步是检测图像中存在的面部。 它输出包围检测到的面部的框。 然后,它为检测到的每个脸部标记面部标志(例如眼睛,鼻子和嘴唇)。 这里的输出通常是一个包含x-坐标和y-坐标的二维点。 正确检测到面部和面部特征后,它将使用图像处理功能在整个面部上正确放置或应用过滤器。 该算法使用 Active Shape Model 进一步分析了关键的面部特征。 在通过手动标记关键面部特征的边界进行训练后,该模型将创建与屏幕上出现的面部对齐的平均面部。 该模型将创建一个网格,以正确放置过滤器并跟踪其运动。

现在,我们来看看 DL 领域的研究领域。

探索移动设备上 DL 的最新发展

随着 DL 和 AI 的复杂性与移动应用的结合,正在不断进行软件和硬件优化,以在设备上高效运行模型。 让我们看看其中的一些。

谷歌的 MobileNet

Google 的 MobileNet 于 2017 年推出。它是基于 TensorFlow 的一组移动优先计算机视觉模型,经过精心优化以在受限的移动环境中高效运行。 它充当复杂神经网络结构的准确率与移动运行时性能约束之间的桥梁。 由于这些模型具有在设备本身上本地运行的能力,因此 MobileNet 具有安全性,隐私性和灵活的可访问性的优点。 MobileNet 的两个最重要的目标是在处理计算机视觉模型时减小尺寸并降低复杂性。 MobileNet 的第一个版本提供了低延迟模型,该模型能够在受限资源下正常工作。 它们可用于分类,检测,嵌入和分段,支持各种用例。

于 2018 年发布的 MobileNetV2 是对第一个版本的重大增强。 它可以用于语义分割,对象检测和分类。 作为 TensorFlow-Slim 图像分类库的一部分启动的 MobileNetV2,可以从 Colaboratory 直接访问。 也可以在本地下载,使用 Jupyter 进行浏览,也可以从 TF-Hub 和 GitHub 访问。 添加到架构中的两个最重要的功能是层之间的线性瓶颈和瓶颈之间的快捷连接。 瓶颈对中间的输入和输出进行编码,并且内层支持从较低级别的概念转换为较高级别的描述符的功能。 传统的剩余连接和快捷方式有助于减少训练时间并提高准确率。 与第一个版本相比,MobileNetV2 更快,更准确,并且所需的操作和参数更少。 它非常有效地用于对象检测和分割以提取特征。

您可以在此处阅读有关此研究工作的更多信息。

阿里巴巴移动神经网络

阿里巴巴移动神经网络MNN)是开源的轻量级 DL 推理引擎。 阿里巴巴工程副总裁贾阳清说:“与 TensorFlow 和 Caffe2 等通用框架相比,它既涵盖训练又包括推理,MNN 专注于推理的加速和优化,并解决了模型部署过程中的效率问题。 因此可以在移动端更高效地实现模型背后的服务,这实际上与 TensorRT 等服务器端推理引擎中的思想相符在大型机器学习应用中,推理的计算量通常是 10 倍以上,因此,进行推理的优化尤为重要。”

MNN 的主要关注领域是深度神经网络DNN)模型的运行和推断。 它专注于模型的优化,转换和推断。 MNN 已被成功用于阿里巴巴公司的许多移动应用中,例如 Mobile Tmall,Mobile Taobao,Fliggy,UC,Qianuu 和 Juhuasuan。 它涵盖了搜索推荐,短视频捕获,直播,资产分配,安全风险控制,交互式营销,按图像搜索产品以及许多其他实际场景。 菜鸟呼叫机柜等物联网IoT)设备也越来越多地使用技术。 MNN 具有很高的稳定性,每天可以运行超过 1 亿次。

MNN 具有高度的通用性,并为市场上大多数流行的框架提供支持,例如 TensorFlow,Caffe 和开放式神经网络交换ONNX)。 它与卷积神经网络CNN)和循环神经网络RNN)等通用神经网络兼容。 MNN 轻巧且针对移动设备进行了高度优化,并且没有依赖关系。 它可以轻松部署到移动设备和各种嵌入式设备。 它还通过便携式操作系统接口POSIX)支持主要的 Android 和 iOS 移动操作系统以及嵌入式设备。 MNN 不受任何外部库的影响,可提供非常高的性能。 它的核心操作通过大量的手写汇编代码来实现,以充分利用高级 RISC 机器ARM)CPU 的优势。 借助高效的图像处理模块IPM),无需 libyuv 或 OpenCV 即可加速仿射变换和色彩空间变换,MNN 易于使用。

在积极开发和研究这些产品的同时,现在让我们看一下将来有望变得越来越重要的一些领域。

探索移动应用中 DL 的当前研究领域

活跃的研究人员社区要投入时间和精力,对于任何研究领域的健康发展至关重要。 幸运的是,DL 在移动设备上的应用引起了全球开发人员和研究人员的强烈关注,许多手机制造商(例如三星,苹果,Realme 和 Xiaomi)将 DL 直接集成到了系统用户界面中 (UI)为所有设备生成。 这极大地提高了模型的运行速度,并且通过系统更新定期提高模型的准确率。

让我们看一下该领域中一些最受欢迎的研究领域,以及它们是如何发展的。

DeepFashion

在 2019 年,DeepFashion2 数据集由葛玉英,张瑞茂等提出。 该数据集是对 DeepFashion 数据集的改进,包括来自卖方和消费者的 491,000 张图像。 数据集可识别 801,000 件服装。 数据集中的每个项目都标有比例,遮挡,放大,视点,类别,样式,边界框,密集的界标和每个像素的蒙版。

数据集在训练集中有 391,000 张图像,在验证集中有 34,000 张图像,在测试集中有 67,000 张图像。 该数据集提供了提出更好的模型的可能性,该模型能够从图像中识别时装和不同的服装。 可以轻松想象此数据集可能会导致的应用范围-包括在线商店根据消费者经常穿的衣服推荐要购买的产品,以及首选品牌和产品的预期价格范围。 仅通过识别他们所穿的服装和品牌,也有可能识别任何人可能从事的职业及其财务,宗教和地理细节。

您可以在此处阅读有关 DeepFashion2 数据集的更多信息。

自我注意生成对抗网络

我们在“第 9 章”,“构建图像超分辨率应用”中讨论了生成对抗网络GAN)的应用,其中我们从低分辨率图像中生成高分辨率图像。 GAN 在学习模仿艺术和图案方面做得相当不错。 但是,在需要记住更长的序列的情况下,以及在序列的多个部分对于生成生成的输出很重要的情况下,它们无法很好地执行。 因此,我们期待 Ian Goodfellow 及其团队推出的自我注意力 GANSAGAN),它们是对图像生成任务应用注意力驱动的远程依赖建模的 GAN 系统。 该系统在 ImageNet 数据集上具有更好的性能,并有望在将来被广泛采用。

Jason Antic 的 DeOldify 项目是使用 SAGANs 完成的工作的衍生产品。 该项目旨在将色彩带入旧的图像和视频中,从而使它们似乎从来没有缺少色彩。 以下屏幕快照显示了 DeOldify 项目的示例:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第50张图片

Dorothea Lange(1936)的《移民母亲》。 图像取自 DeOldify GitHub 存储库。 该项目可通过这里进行测试和演示。 您可以在这个页面上了解有关 SAGAN 的更多信息。

图片动画

Facebook 是一个流行的社交媒体平台,具有用于多个平台的专用应用,一直在致力于创建工具,使您可以使用普通的相机生成 3D 图像,否则这些相机只会生成 2D 图像。 图像动画是一项类似的技术,可让我们将动画带入静态图像。 可以想象这种技术非常令人兴奋的用法,人们拍摄自拍照,然后从运动库中进行选择以对其图像进行动画处理,就好像他们自己在进行这些运动一样。

图像动画虽然还处于起步阶段,但可以成为流行和有趣的应用,考虑到采用 Deepfake 技术的类似应用已成功地成为一项业务-例如,中国的 Zao 应用。

您可以在此处阅读图像动画研究论文。

总结

在本章中,我们讨论了一些最流行的移动应用,这些应用因其在业务产品中最前沿地使用 DL 而著称,还讨论了 DL 影响其增长的方式。 我们还讨论了移动应用 DL 领域的最新发展。 最后,我们讨论了该领域的一些令人兴奋的研究领域,以及它们将来如何发展成潜在的流行应用。 我们相信,到目前为止,您将对如何在移动应用上部署 DL 以及如何使用 Flutter 来构建可在所有流行的移动平台上运行的跨平台移动应用有一个很好的了解。

我们在本章结束时希望,您将充分利用本项目系列中介绍的思想和知识,并构建出令人敬畏的东西,从而在此技术领域带来一场革命。

十一、附录

计算机科学领域令人兴奋的是,它允许多个软件组件组合在一起并致力于构建新的东西。 在这个简短的附录中,我们介绍了在移动设备上进行深度学习之前需要设置的工具,软件和在线服务。

在本章中,我们将介绍以下主题:

  • 在 Cloud VM 上设置深度学习环境
  • 安装 Dart SDK
  • 安装 Flutter SDK
  • 配置 Firebase
  • 设置 Visual StudioVS)代码

在 Cloud VM 上设置深度学习环境

在本节中,我们将提供有关如何在 Google Cloud PlatformGCP)计算引擎虚拟机(VM)实例以执行深度学习。 您也可以轻松地将此处描述的方法扩展到其他云平台。

我们将以快速指南开始,介绍如何创建您的 GCP 帐户并为其启用结算功能。

创建 GCP 帐户并启用结算

要创建 GCP 帐户,您需要一个 Google 帐户。 如果您有一个以@gmail.com结尾的电子邮件地址或 G Suite 上的帐户,则您已经有一个 Google 帐户。 否则,您可以通过访问这里创建一个 Google 帐户。 登录到 Google 帐户后,请执行以下步骤:

  1. 在浏览器上访问这里。
  2. 接受在弹出窗口中显示给您的所有条款。
  3. 您将能够查看 GCP 控制台信息中心。 您可以通过阅读这个页面上的支持文档来快速使用此仪表板。
  4. 在左侧导航菜单上,单击“计费”以打开计费管理仪表板。 系统将提示您添加一个计费帐户,如以下屏幕截图所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qtGYRynW-1681785128429)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/4246b346-b4a9-4e99-998f-165343832c4e.png)]

  1. 点击“添加结算帐户”。 如果有资格,您将被重定向到GCP Free Trial注册页面。 您可以在这个页面上了解有关免费试用的更多信息。 您应该看到类似于以下屏幕截图的屏幕:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第51张图片

  1. 根据需要填写表格。 创建完帐单后,请返回 GCP 控制台信息中心。

您已成功创建 GCP 帐户并为其启用了结算功能。 接下来,您将能够在 GCP 控制台中创建一个项目并将资源分配给该项目。 我们将在接下来的部分中对此进行演示。

创建一个项目和 GCP Compute Engine 实例

在本部分中,您将在 GCP 帐户上创建一个项目。 GCP 中的所有资源都封装在项目下。 项目可能属于或不属于组织。 一个组织下可以有多个项目,而一个项目中可能有多个资源。 让我们开始创建项目,如以下步骤所示:

  1. 在屏幕的左上方,单击“选择项目”下拉菜单。

  2. 在出现的对话框中,单击对话框右上方的“新建项目”。

  3. 您将看到新的项目创建表单,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第52张图片

  1. 填写必要的详细信息后,单击CREATE完成创建项目。 创建项目后,将带您到项目的仪表板。 在这里,您将能够查看与当前所选项目相关的一些基本日志记录和监视。 您可以在这个页面上了解有关 GCP 资源组织方式的更多信息。
  2. 在左侧导航窗格中,单击Compute Engine。 系统将提示您创建一个 VM 实例。
  3. 点击“创建”以显示 Compute Engine 实例创建表单。 根据需要填写表格。 我们假设您在创建实例时选择了 Ubuntu 18.04 LTS 发行版。
  4. 确保在防火墙设置中启用对 VM 实例的 HTTP 和 HTTPS 连接的访问​​,如以下屏幕快照所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第53张图片

  1. 单击“创建”。 GCP 开始为您配置 VM 实例。 您将被带到 VM 实例管理页面。 您应该在此页面上看到您的 VM,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第54张图片

现在,您准备开始配置此 VM 实例以执行深度学习。 我们将在下一部分中对此进行介绍。

配置您的 VM 实例来执行深度学习

在本节中,我们将指导您如何安装包和模块,以在创建的 VM 实例上执行深度学习。 这些包和模块的安装说明在您选择的任何云服务提供商中都是相似的。

您还可以在本地系统上使用类似的命令,以设置本地深度学习环境。

首先调用 VM 的终端:

  1. 单击 VM 实例页面上的SSH按钮,以启动到 VM 的终端会话。

  2. 您应该看到终端会话开始,其中包含一些与系统有关的常规信息以及上次登录的详细信息,如以下屏幕截图所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8qw0xuwj-1681785128430)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/b79b0eac-cebe-4895-9616-87a90fd5d3da.png)]

  1. 现在,让我们对该新创建的实例的包存储库执行更新:
sudo apt update
  1. 接下来,我们将在此 VM 上安装 Anaconda。 Anaconda 是一个受欢迎的包集合,用于使用 Python 执行深度学习和与数据科学相关的任务。 它带有conda包管理器打包在一起,这使得管理系统上安装的 Python 包的不同版本非常容易。 要安装它,我们首先需要获取 Anaconda 安装程序下载链接。 前往这里。 您将转到一个页面,为您提供要安装的 Anaconda 版本的选择,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第55张图片

  1. 建议您选择 Python 3.7 版本。 右键单击“下载”按钮,然后在菜单中找到允许您复制链接地址的选项。
  2. 切换到您的 VM 实例的终端会话。 使用以下命令将占位符文本粘贴到命令中,从而将其替换为您复制的链接,如下所示:
curl -O <link_you_have_copied>
  1. 前面的命令会将 Anaconda 安装程序下载到当前用户的主目录中。 要对其进行验证,可以使用ls命令。 现在,要将此文件设置为可执行文件,我们将使用以下命令:
chmod +x Anaconda*.sh
  1. 现在,安装程序文件可以由您的系统执行。 要开始执行,请使用以下命令:
./Anaconda*.sh
  1. 安装应开始。 应该显示一个提示,询问您是否接受 Anaconda 软件的许可协议,如下所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第56张图片

  1. 点击Enter继续检查许可证。 您会看到许可证文件。
  2. 点击向下箭头键以阅读协议。 输入yes接受许可证。
  3. 系统将要求您确认 Anaconda 安装的位置,如以下屏幕截图所示:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第57张图片

  1. 点击Enter确认位置。 包提取和安装将开始。 完成此操作后,系统将询问您是否要初始化 Anaconda 环境。 在此处输入yes,如下所示:

  1. 现在,安装程序将完成其任务并退出。 要激活 Anaconda 环境,请使用以下命令:
source ~/.bashrc
  1. 您已经成功安装了 Anaconda 环境并激活了它。 要检查安装是否成功,请在终端中输入以下命令:
python3

如果以下命令的输出在第二行包含单词 Anaconda,Inc.,则表明安装成功。 您可以在以下屏幕截图中看到它:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Yk7ZRtMM-1681785128431)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/6c35a0ee-26db-44c3-831f-a9c4ba78407c.png)]

现在,您可以在此环境上开始运行深度学习脚本。 但是,您将来可能希望向此环境添加更多工具库,例如 PyTorch 或 TensorFlow 或任何其他包。 由于本书假定读者熟悉 Python,因此我们不会详细讨论pip工具。

现在让我们看一下如何在 VM 上安装 TensorFlow。

在 VM 上安装 TensorFlow

TensorFlow 是执行深度学习的绝佳框架。

要安装它,可以使用以下命令:

# TensorFlow 1 with CPU only support
python3 -m pip install tensorflow==1.15

# TensorFlow 1 with GPU support
python3 -m pip install tensorflow-gpu==1.15

# TensorFlow 2 with CPU only support
python3 -m pip install tensorflow

# Tensorflow 2 with GPU support
python3 -m pip install tensorflow-gpu

Python 中另一个经常安装的流行库是自然语言工具包(NLTK)库。 我们将在接下来的部分中演示其安装过程。

在 VM 上安装 NLTK 并下载包

要在 VM 上安装 NLTK 并为其下载数据包,请执行以下步骤:

  1. 使用pip安装 NLTK:
python3 -m pip install nltk
  1. NLTK 有几种不同的数据包。 在大多数情况下,您并不需要全部。 要列出 NLTK 的所有可用数据包,请使用以下命令:
python3 -m nltk.downloader

前面命令的输出将允许您交互式地查看所有可用的包,选择所需的包,然后下载它们。

  1. 但是,如果您只希望下载一个包,请使用以下命令:
python3 -m nltk.downloader stopwords

前面的命令将下载 NLTK 的stopwords数据包。 在极少数情况下,您可能会发现自己需要或使用 NLTK 中可用的所有数据包。

通过这种设置,您应该能够在云 VM 上运行大多数深度学习脚本。

在下一部分中,我们将研究如何在本地系统上安装 Dart。

安装 Dart SDK

Dart 是 Google 开发的一种面向对象的语言。 它用于移动和 Web 应用开发。 Flutter 是用 Dart 构建的。 Dart 具有即时JIT)开发周期,该状态与有状态的热重载兼容,并且具有提前编译的功能,可以快速启动并提供可预测的性能,这使其成为了可能。 适用于 Flutter。

以下各节讨论如何在 Windows,macOS 和 Linux 上安装 Dart。

Windows

在 Windows 中安装 Dart 的最简单方法是使用 Chocolatey。 只需在终端中运行以下命令:

 C:\> choco install dart-sdk

接下来,我们将研究如何在 Mac 系统上安装 Dart。

MacOS

要在 macOS 上安装 Dart,请执行以下步骤:

  1. 通过在终端中运行以下命令来安装 Homebrew:
$ /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
  1. 运行以下命令以安装 Dart:
$brew tap dart-lang/dart
$brew install dart

接下来,我们将研究如何在 Linux 系统上安装 Dart。

Linux

Dart SDK 可以如下安装在 Linux 中:

  1. 执行以下一次性设置:
$sudo apt-get update
$sudo apt-get install apt-transport-https
$sudo sh -c 'wget -qO- https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add -'
$sudo sh -c 'wget -qO- https://storage.googleapis.com/download.dartlang.org/linux/debian/dart_stable.list > /etc/apt/sources.list.d/dart_stable.list'
  1. 安装稳定版本:
$sudo apt-get update
$sudo apt-get install dart

接下来,我们将研究如何在本地计算机上安装 Flutter SDK。

安装 Flutter SDK

Flutter 是 Google 的一个工具包,用于使用单个代码库构建本地编译的 Android,iOS 和 Web 应用。 Flutter 具有热重载的快速开发,易于构建的表达性 UI 和本机性能等功能,这些都使 Flutter 成为应用开发人员的首选。

以下各节讨论如何在 Windows,macOS 和 Linux 上安装 Flutter SDK。

Windows

以下步骤详细概述了如何在 Windows 上安装 Flutter:

  1. 从这里下载最新的 Flutter SDK 稳定版本。
  2. 解压缩 ZIP 文件夹,并导航到要安装 Flutter SDK 的目录,以放置flutter文件夹。

避免将flutter放在可能需要特殊特权的目录中,例如C:\Program Files\

  1. 在“开始”搜索栏中输入env,然后选择“编辑环境变量”。
  2. 使用;作为分隔符,将flutter/bin的完整路径附加到用户变量下的路径

如果缺少Path条目,只需创建一个新的Path变量并将path设置为flutter/bin作为其值。

  1. 在终端中运行flutter doctor

flutter doctor分析整个 Flutter 的安装,以检查是否需要更多工具才能在计算机上成功运行 Flutter。

接下来,我们将研究如何在 Mac 系统上安装 Flutter。

MacOS

Flutter 可以如下安装在 macOS 上:

  1. 从这里下载最新的稳定 SDK。
  2. 将下载的 ZIP 文件夹解压缩到合适的位置,如下所示:
$cd ~/
$unzip ~/Downloads/flutter_macos_v1.9.1+hotfix.6-stable.zip
  1. flutter工具添加到路径变量:$ export PATH=pwd/flutter/bin:$PATH
  2. 打开bash_profile以永久更新PATH
$cd ~
$nano .bash_profile
  1. 将以下行添加到bash_profile
$export PATH=$HOME/flutter/bin:$PATH
  1. 运行flutter doctor

Linux

以下步骤概述了如何在 Linux 上安装 Flutter:

  1. 从这里下载 SDK 的最新稳定版本。
  2. 将文件提取到合适的位置:
 $cd ~/development
 $tar xf ~/Downloads/flutter_linux_v1.9.1+hotfix.6-stable.tar.xz
  1. flutter添加到path变量中:
$export PATH="$PATH:`pwd`/flutter/bin"
  1. 运行flutter doctor

接下来,我们将研究如何配置 Firebase 以提供 ML Kit 和自定义模型。

配置 Firebase

Firebase 提供了可促进应用开发并帮助支持大量用户的工具。 Firebase 可以轻松用于 Android,iOS 和 Web 应用。 Firebase 提供的产品(例如 Cloud Firestore,ML Kit,Cloud Functions,Authentication,Crashlytics,Performance Monitoring,Cloud Messaging 和 Dynamic Links)有助于构建应用,从而在不断发展的业务中提高应用质量。

要集成 Firebase 项目,您需要创建一个 Firebase 项目并将其集成到您的 Android 或 iOS 应用中。 以下各节讨论如何创建 Firebase 项目并将其集成到 Android 和 iOS 项目中。

创建 Firebase 项目

首先,我们需要创建一个 Firebase 项目并将其链接到我们的 Android 和 iOS 项目。 此链接有助于我们利用 Firebase 提供的功能。

要创建 Firebase 项目,请执行以下步骤:

  1. 通过这里访问 Firebase 控制台。
  2. 单击“添加项目”以添加新的 Firebase 项目:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第58张图片

  1. 为您的项目提供一个名称:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第59张图片

  1. 根据您的要求启用/禁用 Google Analytics(分析)。 通常建议您保持启用状态。

Google Analytics 是一种免费且不受限制的分析解决方案,可在 Firebase Crashlytics,Cloud Messaging,应用内消息传递,远程配置,A/B 测试,预测和 Cloud Functions 中实现目标定位,报告等功能。

  1. 如果您选择 Firebase Analytics,则还需要选择一个帐户:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第60张图片

在 Firebase 控制台上创建项目后,您将需要分别为 Android 和 iOS 平台进行配置。

配置 Android 项目

以下步骤讨论了如何配置 Android 项目以支持 Firebase:

  1. 导航到 Firebase 控制台上的应用。 在项目概述页面的中心,单击 Android 图标以启动工作流程设置:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第61张图片

  1. 添加包名称以在 Firebase 控制台上注册该应用。 此处填写的包名称应与您的应用的包名称匹配。 此处提供的包名称用作标识的唯一密钥:

此外,您可以提供昵称和调试签名证书 SHA-1。

  1. 下载google-services.json文件并将其放在app文件夹中:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第62张图片

google-services.json文件存储开发人员凭据和配置设置,并充当 Firebase 项目和 Android 项目之间的桥梁。

  1. 用于 Gradle 的 Google 服务插件会加载您刚刚下载的google-services.json文件。 项目级别的build.gradle/build.gradle)应该进行如下修改,以使用该插件:
buildscript {
  repositories {
    // Check that you have the following line (if not, add it):
    google()  // Google's Maven repository
  }
  dependencies {
    ...
    // Add this line
    classpath 'com.google.gms:google-services:4.3.3'
  }
}

allprojects {
  ...
  repositories {
    // Check that you have the following line (if not, add it):
    google()  // Google's Maven repository
    ...
  }
}

  1. 这是应用级别的build.gradle roject>/build.gradle):
apply plugin: 'com.android.application'
// Add this line
apply plugin: 'com.google.gms.google-services'

dependencies {
  // add SDKs for desired Firebase products
  // https://firebase.google.com/docs/android/setup#available-libraries
}

现在,您都可以在 Android 项目中使用 Firebase。

配置 iOS 项目

以下步骤演示了如何配置 iOS 项目以支持 Firebase:

  1. 导航到 Firebase 控制台上的应用。 在项目概述页面的中心,单击 iOS 图标以启动工作流程设置:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第63张图片

  1. 添加 iOS 捆绑包 ID 名称,以在 Firebase 控制台上注册该应用。 您可以在“常规”选项卡中的捆绑包标识符中找到应用主要目标的 Xcode。 它用作标识的唯一密钥:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第64张图片

此外,您可以提供昵称和 App Store ID。

  1. 下载GoogleService-Info.plist文件:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第65张图片

  1. 将刚刚下载的GoogleService-Info.plist文件移到 Xcode 项目的根目录中,并将其添加到所有目标中。

Google 服务使用 CocoaPods 来安装和管理依赖项。

  1. 打开一个终端窗口,然后导航到您的应用的 Xcode 项目的位置。 如果没有,请在此文件夹中创建一个 Podfile:
pod init
  1. 打开您的 Podfile 并添加以下内容:
# add pods for desired Firebase products # https://firebase.google.com/docs/ios/setup#available-pods
  1. 保存文件并运行:
pod install

这将为您的应用创建一个.xcworkspace文件。 使用此文件进行应用的所有将来开发。

  1. 要在应用启动时连接到 Firebase,请将以下初始化代码添加到主AppDelegate类中:
import UIKit
import Firebase

@UIApplicationMain
class AppDelegate: UIResponder, UIApplicationDelegate {

  var window: UIWindow?

  func application(_ application: UIApplication,
    didFinishLaunchingWithOptions launchOptions:
      [UIApplicationLaunchOptionsKey: Any]?) -> Bool {
    FirebaseApp.configure()
    return true
  }
}

现在,您都可以在 Android 项目中使用 Firebase。

设置 VS 代码

Visual StudioVS)Code 是由 Microsoft 开发的轻型代码编辑器。 它的简单性和广泛的插件存储库使其成为开发人员的便捷工具。 凭借其 Dart 和 Flutter 插件,以及应用执行和调试支持,Flutter 应用非常易于开发。

在接下来的部分中,我们将演示如何设置 VS Code 以开发 Flutter 应用。 我们将从这里下载最新版本的 VS Code 开始。

安装 Flutter 和 Dart 插件

首先,我们需要在 VS Code 上安装 Flutter 和 Dart 插件。

可以按照以下步骤进行:

  1. 在计算机上加载 VS Code。
  2. 导航到“查看 | 命令面板”。
  3. 开始输入install,然后选择扩展:安装扩展。
  4. 在扩展搜索字段中键入flutter,从列表中选择 Flutter,然后单击安装。 这还将安装所需的 Dart 插件。
  5. 或者,您可以导航到侧栏来安装和搜索扩展:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第66张图片

成功安装 Flutter 和 Dart 扩展后,我们需要验证设置。 下一节将对此进行描述。

用 Flutter Doctor 验证设置

通常建议您验证设置以确保一切正常。

Flutter 安装可以通过以下方式验证:

  1. 导航到“查看 | 命令面板”。
  2. 输入doctor,然后选择Flutter: Run Flutter Doctor
  3. 查看“输出”窗格中的输出。 输出中列出了所有错误或缺少库。
  4. 另外,您可以在终端上运行flutter doctor来检查一切是否正常:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第67张图片

上面的屏幕快照显示,尽管 Flutter 很好用,但其他一些相关的配置却丢失了。 在这种情况下,您可能需要安装所有支持软件并重新运行flutter doctor以分析设置。

在 VS Code 上成功设置 Flutter 之后,我们可以继续创建我们的第一个 Flutter 应用。

创建第一个 Flutter 应用

创建第一个 Flutter 应用非常简单。 执行以下步骤:

  1. 导航到“查看 | 命令面板”。

  2. 开始输入flutter,然后选择Flutter: New Project

  3. 输入项目名称,例如my_sample_app

  4. 点击Enter

  5. 创建或选择新项目文件夹的父目录。

  6. 等待项目创建完成,然后显示main.dart文件。

有关更多详细信息,请参阅这个页面上的文档。

在下一节中,我们将讨论如何运行您的第一个 Flutter 应用。

运行应用

一个新的 Flutter 项目的创建带有一个模板代码,我们可以直接在移动设备上运行它。 创建第一个模板应用后,可以尝试如下运行它:

  1. 导航至“VS Code”状态栏(即窗口底部的蓝色栏):

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第68张图片

  1. 从设备选择器区域中选择您喜欢的设备:
  • 如果没有可用的设备,并且要使用设备模拟器,请单击“无设备”并启动模拟器:

  • 您也可以尝试设置用于调试的真实设备。
  1. 单击设置按钮-位于右上角的齿轮图标齿轮(现已标记为红色或橙色指示器),位于DEBUG文本框旁边,显示为No Configuration。 选择 Flutter,然后选择调试配置以创建仿真器(如果已关闭)或运行仿真器或已连接的设备。
  2. 导航到“调试 | 开始调试”或按F5
  3. 等待应用启动,进度会显示在DEBUG CONSOLE视图中:

应用构建完成后,您应该在设备上看到已初始化的应用:

TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11_第69张图片

在下一节中,我们将介绍 Flutter 的热重载功能,该功能有助于快速开发。

尝试热重载

Flutter 提供的快速开发周期使其适合于时间优化的开发。 它支持有状态热重载,这意味着您可以重载正在运行的应用的代码,而不必重新启动或丢失应用状态。 热重装可以描述为一种方法,您可以通过该方法对应用源进行更改,告诉命令行工具您要热重装,并在几秒钟内在设备或仿真器上查看更改。

在 VS Code 中,可以按以下方式执行热重装:

  1. 打开lib/main.dart

  2. You have pushed the button this many times:字符串更改为You have clicked the button this many times:。 不要停止您的应用。 让您的应用运行。

  3. 保存更改:调用全部保存,或单击Hot Reload

你可能感兴趣的:(人工智能,人工智能,tensorflow,经验分享,笔记,计算机视觉)