再读CVPR2022-CrossPoint(附代码)

目录

写在前面

网络架构回顾 

项目架构

DGCNN

ResNet()

DGCNN()

DGCNN_PartSeg

Train CrossPoint

写在后面

参考


 

再读CVPR2022-CrossPoint(附代码)_第1张图片

写在前面

经过上一篇对CrossPoint的阅读,个人对网络整体而言都有了一个大致的理解,但具体怎么实现还是懵懵懂懂,于是,打算配上作者开源在Github上的代码再进行一下阅读。话不多说,直接开始。

项目地址放这了

CrossPoint项目icon-default.png?t=M276https://github.com/MohamedAfham/CrossPoint

网络架构回顾 

再读CVPR2022-CrossPoint(附代码)_第2张图片

简单回顾一下整体的网络架构叭。 

点云,image分别经过各自的extractor(f_{\theta}),接一个mlp(g_{\theta}),然后两个loss function使同类型输入尽可能输出相似。

项目架构

再读CVPR2022-CrossPoint(附代码)_第3张图片

 可以看到,作者把项目分为了四个目录,分别是下载数据集脚本,数据集加载文件(包括数据集加载类,数据处理工具(比如前文提到的对点云的Scale,Rotation,Normalize,Sample等等),任务模型(这里直接使用的是DGCNN作为点云的extractor,Resnet作为image的extractor,最后各自拼一个mlp,点击我查看DGCNN论文),训练与测试脚本(名为CrossPoint的Pretrain,DGCNN的Part Segmentation))。

下面对其具体讲解(下载数据集脚本、数据集的加载与处理不做讲解)。

DGCNN

讲解DGCNN之前,放上DGCNN原文提供的网络架构

再读CVPR2022-CrossPoint(附代码)_第4张图片

这里简单说明一下,spatial transform其实就是point net中使用到的T-Net(对点云模型做一个矫正),然后跟point net++还是挺类似,比如knn,然后max pooling,再然后拼接,吧啦吧啦,创新点就是引入了一个叫edge convolution的东西,后面不断更新边,所以叫dynamic graph,想具体了解的话点击搞懂DGCNN,这篇就够了!

再读CVPR2022-CrossPoint(附代码)_第5张图片

 这是DGCNN的模块,可以看到其中包括了基本dgcnn基本模块(看名字应该能猜到是对应上述的哪个),这里要说一下就是ResNet是残差神经网络(你必须要知道CNN模型:ResNet),做image的extractor的,也就是f_{\theta _{i}},那当然DGCNN(cls=-1)自然就是f_{\theta _{p}}了。

ResNet(f_{\theta_{i}}+g_{\phi _{i}}

再读CVPR2022-CrossPoint(附代码)_第6张图片

 这里的话,其实就在创建模型的时候把pytorch的内置ResNet模型传到model参数里直接使用(点击我了解ResNet),至于哪个Identity,网上说的是一个占一个网络层数的,不改变输入,然后那个inv_head应该就是g_{\phi_{i}}了。

DGCNN(f_{\theta_{p}}+g_{\phi_{p}}

再读CVPR2022-CrossPoint(附代码)_第7张图片

再读CVPR2022-CrossPoint(附代码)_第8张图片

 这里的话其实还是比较简单的,对着前面给的DGCNN架构(上面做分类那个),然后重复的这个EdgeConv,最后做一个拼接得到全局特征,就是extractor了,不过并不需要spatial transform,原因应该是在训练CrossPoint时本来输入就是两个(还有一个是augmented version),cls=-1的作用是判断是不是做分类用的,也就是拼接完后后面调整到类别数的操作。

看一下forward

再读CVPR2022-CrossPoint(附代码)_第9张图片

前面几大块都是EdgeConv,到了x1,x2这里就有区别了,这里做了一个自适应池化,然后拼接(至于为什么要这样就不清楚了,反正本来就是要做最大池化,跳过),然后就是判断需不需要分类,其实最后返回的inv_feat才是最后需要的输出。

DGCNN_PartSeg

再读CVPR2022-CrossPoint(附代码)_第10张图片

再读CVPR2022-CrossPoint(附代码)_第11张图片

 这里的话前面与分类十分类似,到了判断是否有pretrain这里,如果没有就是完整的一个做分类的DGCNN了(if not下面一大坨),如果有就会截至在做完max pooling这,然后进入inv_head,不用走那么长了。

Train CrossPoint

下面挑一些重点来看吧

再读CVPR2022-CrossPoint(附代码)_第12张图片

 这里的话,其实就是point cloud以及image的f_{\theta}以及g_{\phi}了,当然train CrossPoint选的是"dgcnn"了

再读CVPR2022-CrossPoint(附代码)_第13张图片

 这里的话不明白model.load_state_dict的model哪来的,前面也没有定义,后面的话就是选择传进去优化器的参数了(这里args.use_sgd为True,所以选的model.parameters也有点懵)。

再读CVPR2022-CrossPoint(附代码)_第14张图片

再读CVPR2022-CrossPoint(附代码)_第15张图片

最后保存模型,pretrain完毕。

这个Train_PartSeg就不说了,就是加载pretrained的模型,继续训练DGCNN_PartSeg。 

写在后面

不得不说,看代码确实是一份苦力活,特别是没什么注释的,其中部分还有没看懂的,期待评论区解答与补充。

参考

[1]. 搞懂DGCNN,这篇就够了!论文及代码完全解析

[2]. pytorch torch.nn.Identity() 是干啥的,解释。

[3]. AdaptiveMaxPool的作用

[4]. torchvision.models.resnet.resnet50

[5]. 你必须要知道CNN模型:ResNet

你可能感兴趣的:(python,计算机视觉,人工智能)