SimpleNN 是 AdamYuan 在高中一年级时用 1 天时间写出来的简易 CNN, 使用 SFML 做 UI, 用于交互式输入手写数字,这个数字被训练好的 CNN 网络执行推理得到识别结果, 它的运行效果如下:
这一篇我们来分析 UI 界面的代码, 规划如下:
实际用时: 10:40~14:30
原版代码使用 Makefile, 其中添加了 -std=c++11, 换了 g++ 为 clang++, 我是在 macOS 下:
all: MnistTrainer MnistUI MnistTest
MnistTrainer: mnist_trainer.cpp */*.hpp */*.cpp
clang++ -std=c++11 mnist_trainer.cpp */*.cpp -Ofast -o MnistTrainer -lm -lpthread
MnistUI: mnist_ui.cpp NN/NN.* NN/Util.hpp MNIST/Util.hpp
clang++ -std=c++11 mnist_ui.cpp NN/NN.cpp -Ofast -o MnistUI -lm -lsfml-system -lsfml-window -lsfml-graphics
MnistTest: mnist_test.cpp NN/NN.* MNIST/Loader.* NN/Util.hpp MNIST/Util.hpp
clang++ -std=c++11 mnist_test.cpp NN/NN.cpp MNIST/Loader.cpp -Ofast -o MnistTest -lm
为什么不用 Makefile: 因为 makefile 没有内置的包管理器, pkg-config 配置多个包的话感觉很麻烦. 使用 CMake 稍微缓解一些。
找到了 3 个 main(
函数, 和 makefile 里的 3 个 target 对应:
➜ SimpleNN git:(master) ✗ ag 'main\(' --ignore-dir build
mnist_ui.cpp
113:int main(int argc, char **argv)
mnist_test.cpp
6:int main(int argc, char **argv)
mnist_trainer.cpp
7:int main(int argc, char **argv)
对于 UI 界面显示, 不需要 mnist_trainer.cpp
和 mnist_test.cpp
, 因此写出 CMakeLists.txt:
cmake_minimum_required(VERSION 3.20)
project(SimpleNN)
set(CMAKE_CXX_STANDARD 11)
add_executable(MnistUI
mnist_ui.cpp
MNIST/Loader.cpp
NN/NN.cpp
NN/Trainer.cpp
)
find_package(SFML 2.6 COMPONENTS system window graphics REQUIRED)
target_link_libraries(MnistUI PRIVATE
pthread
sfml-system
sfml-window
sfml-graphics
)
为了后续源码分析和测试方便, 再增加一个 MnistUI_my
的可执行文件目标:
add_executable(MnistUI_my
mnist_ui_my.cpp
MNIST/Loader.cpp
NN/NN.cpp
NN/Trainer.cpp
)
target_link_libraries(MnistUI_my PRIVATE
pthread
sfml-system
sfml-window
sfml-graphics
)
拆解为: 确定 UI 相关的代码文件; 粗略分析 UI 代码组成部分.
涉及的文件:
mnist_ui.cpp
: UI 代码, 170 行ui/VCR_OSD_MONO_1.001.ttf
: 字体文件下面是 mnist_ui.cpp
的简单解读:
使用了全局变量 snn
, 从传入的参数表示的文件来加载 cnn 网络相关的内容:
SimpleNN snn;
int main(int argc, char **argv)
{
if(argc != 2)
{
printf("Usage: ./MnistUI [snn filename]\n");
return EXIT_FAILURE;
}
snn.Load(argv[1]);
...
}
InitWindow(); // 窗口部件的创建、 布局的设定
Clear(); // 设定鼠标绘制区域的颜色
while(window.isOpen())
{
while(window.pollEvent(event))
{
// 事件处理
}
// 如果鼠标左键按下了, 那么渲染鼠标的轨迹
if(mouse_down)
Paint();
window.draw(paint_sprite);
// 渲染输入纹理
window.draw(input_sprite);
// 渲染输出纹理
window.draw(output_sprite);
// 渲染输出数字纹理
window.draw(output_digits_sprite);
// 渲染鼠标为圆形
Cursor();
window.display(); // 绘制
}
这一节是通过拆解 UI 代码的部件, 对每个部件进行代码粗略分析, 并摘录出用到的代码到单独的文件 Mnist_UI_my.cpp 中验证效果.
整体布局
这一小节,需要看的是 InitWindow()
函数, 以及 main()
函数里 window.draw()
相关的几句调用。
在 InitWindow()
里, 设置了各个部件的大小:
window.draw(sf::Sprite(paint_tex.getTexture()));
sf::Sprite input_sprite{input_tex.getTexture()};
input_sprite.setPosition(kSize, 0);
window.draw(input_sprite);
sf::Sprite output_sprite{output_tex.getTexture()};
output_sprite.setPosition(kSize*2, 0);
window.draw(output_sprite);
InitWindow() 详细注释
void InitWindow()
{
window.create(sf::VideoMode(kSize*2 + kOutSize, kSize), "Mnist Demo", sf::Style::Titlebar | sf::Style::Close);
paint_tex.create(kSize, kSize); // kSize=20*28, 这是560x560方形纹理
input_tex.create(kSize, kSize);
output_tex.create(kOutSize, kSize); // kOutSize=kSize/10=2*28=56, 56x560的大小
output_digits_tex.create(kOutSize, kSize); // 56x560的大小, 是一个竖条形状
sf::Font font; font.loadFromFile("./ui/VCR_OSD_MONO_1.001.ttf");
sf::Text text;
text.setFont(font); text.setCharacterSize(kOutSize);
text.setFillColor(sf::Color(0, 0, 0, 255));
// 竖条分成 10 部分, 每个部分是 56x56 的方格, 每个方格绘制一个数字
for(unsigned i = 0; i < 10; ++i)
{
text.setPosition(0, i * kOutSize);
text.setString(std::to_string(i));
output_digits_tex.draw(text);
}
output_digits_tex.display();
// sf::CircleShape brush_circle, cursor_circle; 这里猜测是鼠标绘制时, 鼠标自身 以及 刷子 的形状
brush_circle.setFillColor(sf::Color(0, 0, 0));
cursor_circle.setFillColor(sf::Color(0, 0, 0, 100));
brush_circle.setRadius(radius);
cursor_circle.setRadius(radius);
// sf::RectangleShape input_rect, output_rect; 这里暂时没看出来用途。
input_rect.setSize(sf::Vector2f(kGridSize, kGridSize)); //20x20
output_rect.setSize(sf::Vector2f(kOutSize, kOutSize)); //56x56
}
Clear()函数
void Clear()
{
paint_tex.clear(sf::Color(255, 255, 255));
}
Clear()
把屏幕左侧的 paint_tex 区域背景颜色设定为白色.
完整代码
这里说的完整代码, 是把刚刚分析的代码摘录出来, 放到 Mnist_UI_my.cpp 里, 并编译运行
#include
sf::RenderWindow window;
sf::Event event;
constexpr int kGridSize = 20, kSize = 28*kGridSize, kOutSize = kSize / 10;
constexpr float kMinRadius = 8.0, kMaxRadius = 30.0, kRadiusStep = 1.0;
sf::RenderTexture paint_tex, input_tex, output_tex, output_digits_tex;
float radius{(kMinRadius + kMaxRadius) * 0.5f};
sf::CircleShape brush_circle, cursor_circle;
sf::RectangleShape input_rect, output_rect;
void InitWindow()
{
window.create(sf::VideoMode(kSize*2 + kOutSize, kSize), "Mnist Demo", sf::Style::Titlebar | sf::Style::Close);
paint_tex.create(kSize, kSize);
input_tex.create(kSize, kSize);
output_tex.create(kOutSize, kSize);
output_digits_tex.create(kOutSize, kSize);
const std::string asset_dir = "../";
sf::Font font; font.loadFromFile(asset_dir+"/ui/VCR_OSD_MONO_1.001.ttf");
sf::Text text;
text.setFont(font); text.setCharacterSize(kOutSize);
text.setFillColor(sf::Color(0, 0, 0, 255));
for(unsigned i = 0; i < 10; ++i)
{
text.setPosition(0, i * kOutSize);
text.setString(std::to_string(i));
output_digits_tex.draw(text);
}
output_digits_tex.display();
brush_circle.setFillColor(sf::Color(0, 0, 0));
cursor_circle.setFillColor(sf::Color(0, 0, 0, 100));
brush_circle.setRadius(radius);
cursor_circle.setRadius(radius);
input_rect.setSize(sf::Vector2f(kGridSize, kGridSize));
output_rect.setSize(sf::Vector2f(kOutSize, kOutSize));
}
void Clear()
{
paint_tex.clear(sf::Color(255, 255, 255));
}
int main()
{
InitWindow();
Clear();
while(window.isOpen())
{
while(window.pollEvent(event))
{
if(event.type == sf::Event::EventType::Closed)
{
window.close();
}
}
sf::Sprite paint_sprite{paint_tex.getTexture()};
auto paint_sprite_position = paint_sprite.getPosition();
printf("paint_sprite_position: %f, %f\n", paint_sprite_position.x, paint_sprite_position.y);
window.draw(sf::Sprite(paint_tex.getTexture()));
sf::Sprite input_sprite{input_tex.getTexture()};
input_sprite.setPosition(kSize, 0);
window.draw(input_sprite);
sf::Sprite output_sprite{output_tex.getTexture()};
output_sprite.setPosition(kSize*2, 0);
window.draw(output_sprite);
sf::Sprite output_digits_sprite{output_digits_tex.getTexture()};
output_digits_sprite.setPosition(kSize*2, 0);
window.draw(output_digits_sprite);
window.display();
}
return 0;
}
由于省略了 event 的处理, 鼠标事件自然是没有响应的, 界面非常枯燥, 看起来只有左右的白色、黑色两个部分:
需要先开启鼠标和键盘事件的处理, 然后再启用 paint_tex 的绘制。
处理鼠标事件
main()
函数里处理鼠标事件:
while(window.pollEvent(event))
{
...
if(event.type == sf::Event::EventType::MouseButtonPressed)
mouse_down = true;
if(event.type == sf::Event::EventType::MouseButtonReleased)
mouse_down = false;
}
if(mouse_down)
Paint();
处理键盘事件
main()
函数中处理键盘事件: 如果用户按下了空格键, 那么调用 Clear()
函数来把左侧输入区域显示的内容清空:
while(window.pollEvent(event))
{
...
if(event.type == sf::Event::EventType::KeyReleased
&& event.key.code == sf::Keyboard::Space)
{
// window.setTitle("Recognize: " + std::to_string(Recognize())); 目前不需要调用 Recognize函数,先注释掉
Clear();
}
}
由于 Clear()
本身是一个不复杂的函数调用, 仅仅是把 input_tex 这个纹理的颜色设定为白色。 如果是稍微耗时一些的任务,通常是在事件处理函数的地方做判断, 在外部处理。
void Clear()
{
paint_tex.clear(sf::Color(255, 255, 255));
}
绘制 paint 区域
调用的 Paint()
函数是本小节的关键
void Paint()
{
// 获取鼠标在窗口 window 内的位置
sf::Vector2i xy = sf::Mouse::getPosition(window);
// 如果鼠标坐标在窗口内部
if(xy.x >= 0 && xy.x < kSize && xy.y >= 0 && xy.y < kSize)
{
// 如果鼠标不在左侧的 input_tex 范围, 那么就做 clip
int x = std::max(0, std::min(xy.x, kSize)) - radius;
// 在纵向方向上, 也做了 clip, 因此如果打算在界面布局上再增加底栏,也是能处理鼠标在 input_tex 的显示的
int y = std::max(0, std::min(xy.y, kSize)) - radius;
// 设置笔刷的坐标
brush_circle.setPosition(x, y);
// 在 paint_tex 上绘制笔刷
paint_tex.draw(brush_circle);
}
paint_tex.display();
}
其中存在 sf::CirleShape
-> sf::Texture
的对象“存放”关系: 把一个 shape 存放到一个 texture 中。
而在 main()
中则进一步做了 sf::Texture
-> sf::Sprite
的处理:
window.draw(sf::Sprite(paint_tex.getTexture()));
在官方教程 https://www.sfml-dev.org/tutorials/2.6/graphics-sprite.php 里给出了解释:
Most (if not all) of you are already familiar with these two very common objects, so let’s define them very briefly.
A texture is an image. But we call it “texture” because it has a very specific role: being mapped to a 2D entity.
A sprite is nothing more than a textured rectangle.
纹理(texture)是一幅图像(image)。但我们称它为 texture,因为它有一个非常具体的作用:被映射到一个2D实体上。
精灵(sprite)只不过是一个带有纹理的矩形.
为什么使用 texture + sprite, 而不是 RectangleShape?
从 SFML 的代码层更容易理解: window.draw()
我们目前写过的代码, 主要是绘制形状, 也绘制过顶点 sf::Vertex
. 对于绘制形状:
class Window
{
public:
...
void draw(const Drawable& drawable, const RenderStates& states = RenderStates::Default);
};
因此, 如果要绘制 texture, 就需要让 texture 继承自 sf::Drawable
. 但是 sf::Texture
和 sf::RenderTexture
都没有继承自 sf::Drawable
:
class SFML_GRAPHICS_API Texture : GlResource
{
...
};
class SFML_GRAPHICS_API RenderTexture : public RenderTarget
{
...
};
而 sf::Sprite
则是继承了 sf::Drawable
, 并且能从 sf::Texture
创建对象:
class SFML_GRAPHICS_API Sprite : public Drawable, public Transformable
{
public:
explicit Sprite(const Texture& texture); // 从整个 texture 创建 sprite
Sprite(const Texture& texture, const IntRect& rectangle); // 从 ROI 创建 sprite
...
};
因此, 目前遇到的三种绘制方式:
sf::CircleShape
-> window.draw(circle)
sf::Vertex
-> window.draw(vertex, 2, sf::Lines)
sf::CirleShape
-> sf::Texture
-> sf::Sprite
-> window.draw(sprite)
第三种方式中的 Sprite 是为了承载 Texture, 那么 Texture 是为了什么呢? 准确的说, 是 sf::RenderTexture
对象的 .getTexture()
方法返回的 sf::Texture
对象:
sf::RenderTexture paint_tex, input_tex, output_tex, output_digits_tex;
...
sf::Sprite input_sprite{input_tex.getTexture()};
input_sprite.setPosition(kSize, 0);
window.draw(input_sprite);
而 sf::RenderTexture
和 sf::Texture
没有直接的继承关系:
class SFML_GRAPHICS_API RenderTexture : public RenderTarget
{
...
};
对于 input_tex
这个 sf::RenderTexture
来说, 它仅仅是被创建 (.create()
), 然后就没有主动调用什么方法了; input_sprite
则是对它设定了位置:
input_tex.create(kSize, kSize);
sf::Sprite input_sprite{input_tex.getTexture()};
input_sprite.setPosition(kSize, 0);
window.draw(input_sprite);
为什么能设定位置? 因为 sf::Sprite
继承了 Transformable
类:
class SFML_GRAPHICS_API Sprite : public Drawable, public Transformable
看起来好像用 sf::RectangleShape
也能完成同样功能, GPT4 给的解释是:
sf::Texture
这个纹理数据是被上传到 GPU 显存中, GPU 处理的速度快; 如果有多个 sf::Sprite
实例共享使用同一个 texture, 那么不需要重新上传, 只需要上传一次, 减少了显存使用和数据传输的开销。
完整的代码
把用到的代码抽取出来, 放到 Mnist_UI_my.cpp 中, 本节的代码能够在左侧区域中,使用鼠标绘制, 使用空格键清理:
#include
sf::RenderWindow window;
sf::Event event;
constexpr int kGridSize = 20, kSize = 28*kGridSize, kOutSize = kSize / 10;
constexpr float kMinRadius = 8.0, kMaxRadius = 30.0, kRadiusStep = 1.0;
sf::RenderTexture paint_tex, input_tex, output_tex, output_digits_tex;
float radius{(kMinRadius + kMaxRadius) * 0.5f};
sf::CircleShape brush_circle, cursor_circle;
sf::RectangleShape input_rect, output_rect;
void InitWindow()
{
window.create(sf::VideoMode(kSize*2 + kOutSize, kSize), "Mnist Demo", sf::Style::Titlebar | sf::Style::Close);
paint_tex.create(kSize, kSize);
input_tex.create(kSize, kSize);
output_tex.create(kOutSize, kSize);
output_digits_tex.create(kOutSize, kSize);
const std::string asset_dir = "../";
sf::Font font; font.loadFromFile(asset_dir+"/ui/VCR_OSD_MONO_1.001.ttf");
sf::Text text;
text.setFont(font); text.setCharacterSize(kOutSize);
text.setFillColor(sf::Color(0, 0, 0, 255));
for(unsigned i = 0; i < 10; ++i)
{
text.setPosition(0, i * kOutSize);
text.setString(std::to_string(i));
output_digits_tex.draw(text);
}
output_digits_tex.display();
brush_circle.setFillColor(sf::Color(0, 0, 0));
cursor_circle.setFillColor(sf::Color(0, 0, 0, 100));
brush_circle.setRadius(radius);
cursor_circle.setRadius(radius);
input_rect.setSize(sf::Vector2f(kGridSize, kGridSize));
output_rect.setSize(sf::Vector2f(kOutSize, kOutSize));
}
void Clear()
{
paint_tex.clear(sf::Color(255, 255, 255));
}
void Paint()
{
sf::Vector2i xy = sf::Mouse::getPosition(window);
if(xy.x >= 0 && xy.x < kSize && xy.y >= 0 && xy.y < kSize)
{
int x = std::max(0, std::min(xy.x, kSize)) - radius, y = std::max(0, std::min(xy.y, kSize)) - radius;
brush_circle.setPosition(x, y);
paint_tex.draw(brush_circle);
}
paint_tex.display();
}
int main()
{
InitWindow();
Clear();
bool mouse_down = false;
while(window.isOpen())
{
while(window.pollEvent(event))
{
if(event.type == sf::Event::EventType::Closed)
{
window.close();
}
if(event.type == sf::Event::EventType::KeyReleased
&& event.key.code == sf::Keyboard::Space)
{
//window.setTitle("Recognize: " + std::to_string(Recognize()));
Clear();
}
if(event.type == sf::Event::EventType::MouseButtonPressed)
mouse_down = true;
if(event.type == sf::Event::EventType::MouseButtonReleased)
mouse_down = false;
}
if(mouse_down)
Paint();
sf::Sprite paint_sprite{paint_tex.getTexture()};
auto paint_sprite_position = paint_sprite.getPosition();
printf("paint_sprite_position: %f, %f\n", paint_sprite_position.x, paint_sprite_position.y);
window.draw(sf::Sprite(paint_tex.getTexture()));
sf::Sprite input_sprite{input_tex.getTexture()};
input_sprite.setPosition(kSize, 0);
window.draw(input_sprite);
sf::Sprite output_sprite{output_tex.getTexture()};
output_sprite.setPosition(kSize*2, 0);
window.draw(output_sprite);
sf::Sprite output_digits_sprite{output_digits_tex.getTexture()};
output_digits_sprite.setPosition(kSize*2, 0);
window.draw(output_digits_sprite);
window.display();
}
return 0;
}
所谓 input 纹理, 说的是把窗口左侧的 paint 区域得到的内容, 做处理后, 能够作为 cnn 网络输入的时候(或者之前一点点), 这个处理过的输入是什么样子。 换言之, 是 CNN 网络看到的图像对应的纹理, 我们对它做一个可视化。 可视化的时候, 为了看的清楚, 肯定不是 28x28 那么小的输入,但是 cnn 网络的输入大概是 28x28 的大小。
本小节我们只关注 input 区域的显示, 不关注 cnn 网络的推理, 因此需要展开 Recognize()
函数的大部分, 但也略去其中 snn
对象的 evaluate()
等方法的调用, 也就省略了最终预测结果中的数字的显示。
从键盘事件到Recognize
回顾 main() 中的键盘处理:
if(event.type == sf::Event::EventType::KeyReleased
&& event.key.code == sf::Keyboard::Space)
{
window.setTitle("Recognize: " + std::to_string(Recognize()));
Clear();
}
按下空格键后会执行 Recognize()
Recognize()浅析
Recognize()
函数, 将 paint_tex
区域手绘的内容, 拷贝一份独立的图像, 并将每个 20x20 大小的网格“捏成一个像素”, 捏的手法类似于 area resize / average pooling, 但是原始像素被 0/1 二值化处理了, 因此相当于先做阈值为 1 的二值化, 然后做 area resize, 得到了 28x28=764 大小的一维数组 nn_input, 每个元素是 [0, 1] 范围的浮点数。
对于 nn_input 每个元素, 为了在 input_tex
显示, 让每个像素映射到 [0, 255] 范围整数, 并且 ”填充“ 到 20x20 的区域, 这和原本的 "捏” 动作相反, 但是由于“捏”的过程中已经做了二值化处理, 因此现在 “填充” 回去的时候, 效果是 “像素化” 的。
width_normalize()
函数意义不明, 先注释掉。
至于 snn 网络的推理, 现在先把代码注释掉。
unsigned Recognize()
{
// 根据 paint 区域绘制的纹理, 创建独立的图像拷贝
sf::Image img{paint_tex.getTexture().copyToImage()};
// 获取图像像素的 raw buffer
const sf::Uint8 *ptr = img.getPixelsPtr();
// 网络输入是 28x28=784 大小,float 类型
std::vector<float> nn_input(784);
// 将每个 grid 区域(kGridSize x kGridSize, 20x20) 捏成一个像素
for(unsigned i = 0; i < 784; ++i)
{
float v = 0.0;
unsigned gx = i % 28;
unsigned gy = i / 28;
unsigned px = gx * (kGridSize << 2);
unsigned py = gy * kGridSize;
// 对于每个 20x20 大小的方格, 如果不是 0,那么计数器加 1, 如果是 0 则计数器不变
for(unsigned y = py; y < py + kGridSize; ++y)
{
for(unsigned x = px; x < px + (kGridSize << 2); x += 4)
{
v += float(ptr[y * (kSize << 2) + x] == 0);
}
}
// 统计了 20x20 方格区域内非 0 元素数量 v, 数量 v 除以总数 20x20, 这个比值作为 28x28 网络输入的一个元素。
nn_input[i] = v / float(kGridSize * kGridSize);
}
// width_normalize(&nn_input); 先不调用它,看是什么效果
for(unsigned i = 0; i < 784; ++i)
{
// 把 nn_input[i], 从 [0, 1] 范围的浮点数转到 [0, 255] 范围的整数 c
unsigned c = 255 * nn_input[i];
c = std::min(c, 255u);
// 在 20x20 的区域内, 绘制相同的颜色 c
unsigned gx = i % 28;
unsigned gy = i / 28;
input_rect.setPosition(gx * kGridSize, gy * kGridSize);
input_rect.setFillColor(sf::Color(c, c, c, 255));
input_tex.draw(input_rect); // 在一个 texture 的 ROI 区域上进行绘制
//putchar(nn_input[i] >= 0.25 ? (nn_input[i] >= 0.5 ? (nn_input[i] >= 0.75 ? '@' : '?') : '.') : ' ');
//if(i % 28 == 27) putchar('\n');
}
input_tex.display(); // 更新 target texture 内容。 如果不调用,我观察到的是上下颠倒的内容
// 先不看 output 的处理
#if 0
{
snn.Evaluate(nn_input);
unsigned res = std::max_element(snn.GetOutput(), snn.GetOutput() + 10) - snn.GetOutput();
for(unsigned i = 0; i < 10; ++i)
{
unsigned c = 255 * snn.GetOutput()[i];
c = std::min(c, 255u);
output_rect.setPosition(0, i * kOutSize);
output_rect.setFillColor(sf::Color(c, c, c, 255));
output_tex.draw(output_rect);
}
output_tex.display();
}
#endif
return 0;
}
补充说明 input_tex.display()
的调用: 它是更新纹理绘制的内容, 如果不调用, 那么内容是 “垃圾值”, 我在 M1 mac-mini 上的结果是, 不调用它会得到上下颠倒的内容。
效果和代码
#include
sf::RenderWindow window;
sf::Event event;
constexpr int kGridSize = 20, kSize = 28*kGridSize, kOutSize = kSize / 10;
constexpr float kMinRadius = 8.0, kMaxRadius = 30.0, kRadiusStep = 1.0;
sf::RenderTexture paint_tex, input_tex, output_tex, output_digits_tex;
float radius{(kMinRadius + kMaxRadius) * 0.5f};
sf::CircleShape brush_circle, cursor_circle;
sf::RectangleShape input_rect, output_rect;
void InitWindow()
{
window.create(sf::VideoMode(kSize*2 + kOutSize, kSize), "Mnist Demo", sf::Style::Titlebar | sf::Style::Close);
paint_tex.create(kSize, kSize);
input_tex.create(kSize, kSize);
output_tex.create(kOutSize, kSize);
output_digits_tex.create(kOutSize, kSize);
const std::string asset_dir = "../";
sf::Font font; font.loadFromFile(asset_dir+"/ui/VCR_OSD_MONO_1.001.ttf");
sf::Text text;
text.setFont(font); text.setCharacterSize(kOutSize);
text.setFillColor(sf::Color(0, 0, 0, 255));
for(unsigned i = 0; i < 10; ++i)
{
text.setPosition(0, i * kOutSize);
text.setString(std::to_string(i));
output_digits_tex.draw(text);
}
output_digits_tex.display();
brush_circle.setFillColor(sf::Color(0, 0, 0));
cursor_circle.setFillColor(sf::Color(0, 0, 0, 100));
brush_circle.setRadius(radius);
cursor_circle.setRadius(radius);
input_rect.setSize(sf::Vector2f(kGridSize, kGridSize));
output_rect.setSize(sf::Vector2f(kOutSize, kOutSize));
}
void Clear()
{
paint_tex.clear(sf::Color(255, 255, 255));
}
void Paint()
{
sf::Vector2i xy = sf::Mouse::getPosition(window);
if(xy.x >= 0 && xy.x < kSize && xy.y >= 0 && xy.y < kSize)
{
int x = std::max(0, std::min(xy.x, kSize)) - radius, y = std::max(0, std::min(xy.y, kSize)) - radius;
brush_circle.setPosition(x, y);
paint_tex.draw(brush_circle);
}
paint_tex.display();
}
unsigned Recognize()
{
sf::Image img{paint_tex.getTexture().copyToImage()};
const sf::Uint8 *ptr = img.getPixelsPtr();
std::vector<float> nn_input(784);
for(unsigned i = 0; i < 784; ++i)
{
float v = 0.0;
unsigned gx = i % 28, gy = i / 28;
unsigned px = gx * (kGridSize << 2), py = gy * kGridSize;
for(unsigned y = py; y < py + kGridSize; ++y)
for(unsigned x = px; x < px + (kGridSize << 2); x += 4)
v += float(ptr[y * (kSize << 2) + x] == 0);
nn_input[i] = v / float(kGridSize * kGridSize);
}
// width_normalize(&nn_input);
for(unsigned i = 0; i < 784; ++i)
{
unsigned gx = i % 28, gy = i / 28, c = 255 * nn_input[i];
c = std::min(c, 255u);
input_rect.setPosition(gx * kGridSize, gy * kGridSize);
input_rect.setFillColor(sf::Color(c, c, c, 255));
input_tex.draw(input_rect);
//putchar(nn_input[i] >= 0.25 ? (nn_input[i] >= 0.5 ? (nn_input[i] >= 0.75 ? '@' : '?') : '.') : ' ');
//if(i % 28 == 27) putchar('\n');
}
input_tex.display(); // 更新 target texture 内容。 如果不调用,我观察到的是上下颠倒的内容
#if 0
{
snn.Evaluate(nn_input);
unsigned res = std::max_element(snn.GetOutput(), snn.GetOutput() + 10) - snn.GetOutput();
for(unsigned i = 0; i < 10; ++i)
{
unsigned c = 255 * snn.GetOutput()[i];
c = std::min(c, 255u);
output_rect.setPosition(0, i * kOutSize);
output_rect.setFillColor(sf::Color(c, c, c, 255));
output_tex.draw(output_rect);
}
output_tex.display();
}
#endif
return 0;
}
int main()
{
InitWindow();
Clear();
bool mouse_down = false;
while(window.isOpen())
{
while(window.pollEvent(event))
{
if(event.type == sf::Event::EventType::Closed)
{
window.close();
}
if(event.type == sf::Event::EventType::KeyReleased
&& event.key.code == sf::Keyboard::Space)
{
window.setTitle("Recognize: " + std::to_string(Recognize()));
Clear();
}
if(event.type == sf::Event::EventType::MouseButtonPressed)
mouse_down = true;
if(event.type == sf::Event::EventType::MouseButtonReleased)
mouse_down = false;
}
if(mouse_down)
Paint();
sf::Sprite paint_sprite{paint_tex.getTexture()};
auto paint_sprite_position = paint_sprite.getPosition();
printf("paint_sprite_position: %f, %f\n", paint_sprite_position.x, paint_sprite_position.y);
window.draw(sf::Sprite(paint_tex.getTexture()));
sf::Sprite input_sprite{input_tex.getTexture()};
input_sprite.setPosition(kSize, 0);
window.draw(input_sprite);
sf::Sprite output_sprite{output_tex.getTexture()};
output_sprite.setPosition(kSize*2, 0);
window.draw(output_sprite);
sf::Sprite output_digits_sprite{output_digits_tex.getTexture()};
output_digits_sprite.setPosition(kSize*2, 0);
window.draw(output_digits_sprite);
window.display();
}
return 0;
}
加载网络文件
int main(int argc, char **argv)
{
if(argc != 2)
{
printf("Usage: ./MnistUI [snn filename]\n");
return EXIT_FAILURE;
}
snn.Load(argv[1]);
...
}
width_normalize(): 裁剪掉无效图像区域
没调用 width_normalize()
时, input_tex
里存在大量空白区域(黑色), 数字大小和绘制大小一样的;
调用 width_normalize()
后, 相当于获取了 bounding box, 并将 bounding box 外部的区域建材掉, 将剩余的有效区域像素放大到了 28x28 大小; 识别准确率也上来了:
关于 width_normalize()
的源码, 本篇不做分析, 下一篇剖析 SimpleNN 实现的代码时再分析。
鼠标滚轮控制 cursor 大小
while(window.isOpen())
{
while(window.pollEvent(event))
{
...
if(event.type == sf::Event::EventType::MouseWheelScrolled)
{
radius += kRadiusStep * (event.mouseWheel.x > 0 ? -1 : 1);
radius = std::min(std::max(kMinRadius, radius), kMaxRadius);
brush_circle.setRadius(radius);
cursor_circle.setRadius(radius);
}
}
}
把鼠标形状改为圆球: Cursor()
实际上是鼠标周围一圈有一个圆形, 就像是拖着一个墨球:
int main()
{
while() {
while() {
...
sf::Sprite output_digits_sprite{output_digits_tex.getTexture()};
output_digits_sprite.setPosition(kSize*2, 0);
window.draw(output_digits_sprite);
Cursor(); /// 此处修改鼠标形状
window.display();
}
}
void Cursor()
{
sf::Vector2i xy = sf::Mouse::getPosition(window);
if(xy.x >= 0 && xy.x < kSize && xy.y >= 0 && xy.y < kSize)
{
int x = std::max(0, std::min(xy.x, kSize)) - radius, y = std::max(0, std::min(xy.y, kSize)) - radius;
cursor_circle.setPosition(x, y);
window.draw(cursor_circle);
}
}
这里贴出我做测试、添加了一些注释的 Mnist_UI_my.cpp 代码, 大部分是本篇解读过的, SimpleNN snn
对应的 NN.hpp
, 以及 width_normalize()
对应的 MNIST/Util.hpp
则不在这个文件里, 使用原版的。
#include
#include "NN/NN.hpp"
#include "MNIST/Util.hpp"
sf::RenderWindow window;
sf::Event event;
constexpr int kGridSize = 20, kSize = 28*kGridSize, kOutSize = kSize / 10;
constexpr float kMinRadius = 8.0, kMaxRadius = 30.0, kRadiusStep = 1.0;
sf::RenderTexture paint_tex, input_tex, output_tex, output_digits_tex;
float radius{(kMinRadius + kMaxRadius) * 0.5f};
sf::CircleShape brush_circle, cursor_circle;
sf::RectangleShape input_rect, output_rect;
SimpleNN snn;
void InitWindow()
{
window.create(sf::VideoMode(kSize*2 + kOutSize, kSize), "Mnist Demo", sf::Style::Titlebar | sf::Style::Close);
paint_tex.create(kSize, kSize);
input_tex.create(kSize, kSize);
output_tex.create(kOutSize, kSize);
output_digits_tex.create(kOutSize, kSize);
const std::string asset_dir = "../";
sf::Font font; font.loadFromFile(asset_dir+"/ui/VCR_OSD_MONO_1.001.ttf");
sf::Text text;
text.setFont(font); text.setCharacterSize(kOutSize);
text.setFillColor(sf::Color(0, 0, 0, 255));
for(unsigned i = 0; i < 10; ++i)
{
text.setPosition(0, i * kOutSize);
text.setString(std::to_string(i));
output_digits_tex.draw(text);
}
output_digits_tex.display();
brush_circle.setFillColor(sf::Color(0, 0, 0));
cursor_circle.setFillColor(sf::Color(0, 0, 0, 100));
brush_circle.setRadius(radius);
cursor_circle.setRadius(radius);
input_rect.setSize(sf::Vector2f(kGridSize, kGridSize));
output_rect.setSize(sf::Vector2f(kOutSize, kOutSize));
}
void Clear()
{
paint_tex.clear(sf::Color(255, 255, 255));
}
void Cursor()
{
sf::Vector2i xy = sf::Mouse::getPosition(window);
if(xy.x >= 0 && xy.x < kSize && xy.y >= 0 && xy.y < kSize)
{
int x = std::max(0, std::min(xy.x, kSize)) - radius, y = std::max(0, std::min(xy.y, kSize)) - radius;
cursor_circle.setPosition(x, y);
window.draw(cursor_circle);
}
}
void Paint()
{
sf::Vector2i xy = sf::Mouse::getPosition(window);
if(xy.x >= 0 && xy.x < kSize && xy.y >= 0 && xy.y < kSize)
{
int x = std::max(0, std::min(xy.x, kSize)) - radius, y = std::max(0, std::min(xy.y, kSize)) - radius;
brush_circle.setPosition(x, y);
paint_tex.draw(brush_circle);
}
paint_tex.display();
}
unsigned Recognize()
{
// 根据 paint 区域绘制的纹理, 创建独立的图像拷贝
sf::Image img{paint_tex.getTexture().copyToImage()};
// 获取图像像素的 raw buffer
const sf::Uint8 *ptr = img.getPixelsPtr();
// 网络输入是 28x28=784 大小,float 类型
std::vector<float> nn_input(784);
// 将每个 grid 区域(kGridSize x kGridSize, 20x20) 捏成一个像素
for(unsigned i = 0; i < 784; ++i)
{
float v = 0.0;
unsigned gx = i % 28;
unsigned gy = i / 28;
unsigned px = gx * (kGridSize << 2);
unsigned py = gy * kGridSize;
// 对于每个 20x20 大小的方格, 如果不是 0,那么计数器加 1, 如果是 0 则计数器不变
for(unsigned y = py; y < py + kGridSize; ++y)
{
for(unsigned x = px; x < px + (kGridSize << 2); x += 4)
{
v += float(ptr[y * (kSize << 2) + x] == 0);
}
}
// 统计了 20x20 方格区域内非 0 元素数量 v, 数量 v 除以总数 20x20, 这个比值作为 28x28 网络输入的一个元素。
nn_input[i] = v / float(kGridSize * kGridSize);
}
width_normalize(&nn_input); // 负责砍掉图像周围的空白区域
for(unsigned i = 0; i < 784; ++i)
{
// 把 nn_input[i], 从 [0, 1] 范围的浮点数转到 [0, 255] 范围的整数 c
unsigned c = 255 * nn_input[i];
c = std::min(c, 255u);
// 在 20x20 的区域内, 绘制相同的颜色 c
unsigned gx = i % 28;
unsigned gy = i / 28;
input_rect.setPosition(gx * kGridSize, gy * kGridSize);
input_rect.setFillColor(sf::Color(c, c, c, 255));
input_tex.draw(input_rect); // 在一个 texture 的 ROI 区域上进行绘制
//putchar(nn_input[i] >= 0.25 ? (nn_input[i] >= 0.5 ? (nn_input[i] >= 0.75 ? '@' : '?') : '.') : ' ');
//if(i % 28 == 27) putchar('\n');
}
input_tex.display(); // 更新 target texture 内容。 如果不调用,我观察到的是上下颠倒的内容
#if 1
{
snn.Evaluate(nn_input);
unsigned res = std::max_element(snn.GetOutput(), snn.GetOutput() + 10) - snn.GetOutput();
for(unsigned i = 0; i < 10; ++i)
{
unsigned c = 255 * snn.GetOutput()[i];
c = std::min(c, 255u);
output_rect.setPosition(0, i * kOutSize);
output_rect.setFillColor(sf::Color(c, c, c, 255));
output_tex.draw(output_rect);
}
output_tex.display();
return res;
}
#endif
return 0;
}
int main(int argc, char **argv)
{
if(argc != 2)
{
printf("Usage: ./MnistUI [snn filename]\n");
return EXIT_FAILURE;
}
snn.Load(argv[1]);
InitWindow();
Clear();
bool mouse_down = false;
while(window.isOpen())
{
while(window.pollEvent(event))
{
if(event.type == sf::Event::EventType::Closed)
{
window.close();
}
if(event.type == sf::Event::EventType::KeyReleased
&& event.key.code == sf::Keyboard::Space)
{
window.setTitle("Recognize: " + std::to_string(Recognize()));
Clear();
}
if(event.type == sf::Event::EventType::MouseButtonPressed)
mouse_down = true;
if(event.type == sf::Event::EventType::MouseButtonReleased)
mouse_down = false;
}
if(mouse_down)
Paint();
sf::Sprite paint_sprite{paint_tex.getTexture()};
auto paint_sprite_position = paint_sprite.getPosition();
printf("paint_sprite_position: %f, %f\n", paint_sprite_position.x, paint_sprite_position.y);
window.draw(sf::Sprite(paint_tex.getTexture()));
sf::Sprite input_sprite{input_tex.getTexture()};
input_sprite.setPosition(kSize, 0);
window.draw(input_sprite);
sf::Sprite output_sprite{output_tex.getTexture()};
output_sprite.setPosition(kSize*2, 0);
window.draw(output_sprite);
sf::Sprite output_digits_sprite{output_digits_tex.getTexture()};
output_digits_sprite.setPosition(kSize*2, 0);
window.draw(output_digits_sprite);
Cursor();
window.display();
}
return 0;
}
本篇分析了 SimpleNN 的 UI 部分, 它是基于 SFML 实现的交互式手写数字识别程序, 提供了鼠标绘制手写数字, 空格键触发输入的处理和显示、 网络的推理和结果显示, 即使不会写代码也能使用它执行手写数字识别。
具体的代码分析中, 先从界面布局的划分(paint region, input region, output region) 入手, 然后对 paint, input 两个区域的内容的绘制、 鼠标键盘的交互, 做了比较详细的分析。 而输入数据的归一化, 网络的推理, 分析的稍微粗糙一些, 主要是因为相关图像处理内容比较熟悉, 不是 UI 界面的重点。
对于 Texture 的使用, 增加了一些经验, 目前遇到过的处理方式有这几种:
sf::CircleShape
-> window.draw(circle)
sf::Vertex
-> window.draw(vertex, 2, sf::Lines)
sf::CirleShape
-> sf::Texture
-> sf::Sprite
-> window.draw(sprite)
sf::RectangleShape input_rect;
input_rect.setSize(sf::Vector2f(kGridSize, kGridSize));
input_rect.setPosition(gx * kGridSize, gy * kGridSize);
sf::RenderTexture input_tex;
input_tex.create(kSize, kSize);
input_tex.draw(input_rect); // 区域渲染
input_tex.display(); // update content
这也让我想到前一篇基于 SFML 实现的 tic-tac-toe 井字棋游戏,渲染的代码写的不太好,是对 3x3 每个区域分别绘制纹理,其实可以制作一个整个的纹理, 然后更新每个 grid 区域。
因此后续的方向有这几个: