最近在尝试把TensoRF在jittor框架下实现(可以去我的仓库查看转换好的模型,目前还在完善中),pytorch中具有的函数在jittor中大多具有同名函数,直接将torch改为jittor就能用,但还是有少部分地方需要进行修改,因此记录一下在转换中遇到的问题,其中红色部分是需要十分注意,否则可能会导致训练失败的地方
该函数是将一个不可训练的tensor转换为可以训练的parameter,在jittor中不需要该函数(jittor为了兼容性保留了该函数,不过它啥事也不做)
注意:在torch中使用了@torch.no_grad()后,仍然可以用Parameter()将其requires_grad设置为True,但在jittor中由于这个函数被取消,因此requires_grad始终为False,会影响你的训练!我的解决方法是在只不需要梯度的地方用with jt.no_grad(),避免影响参数训练
在torch中,用 b=a.data() 或者 b=a.detach() 可以使b共享a的内存地址,并且b的requires_grad为False,但在jittor中,经过我的实验,这两个特性似乎都无法实现
在torch中,max是获得指定维度下的最大值以及索引,而amax仅获得最大值
在jittor中,max仅获得最大值,如果需要索引应该用这个函数:arg_reduce('max')
jittor采用的数据类型是Var,可以直接用jittor.array(data)或者jt.Var(data)将ndarray转换为Var或者生成新的Var,其用法与pytorch基本类似
Var转换为ndarray只需要在数据后面加上.numpy()即可
注意:用jt.array()转换ndarray时,哪怕原数据为float64,其转换得到的array仍为float32,可能会影响精度,你可以用jt.float64()进行转换,避免精度损失
用execute代替forward
用optimizer.step(loss)代替下面三行:
optimizer.zero_grad()
loss.backward()
optimizer.step()
此外,jittor中Dataset的位置有改变,如下所示
from jittor.dataset.dataset import Dataset
在torch中,你需要利用.to(device)手动实现数据在cpu和gpu之间的转换,但在jittor中采用的是统一内存管理策略,即jittor将cpu和gpu的内存统一起来,你不需要(好像也不能。。。)手动对具体的数据进行转换,只需要在代码运行前用下面的代码设置flag便可,其中x为0代表在cpu运行,为1代表尝试在gpu运行(如果显存不够会使用cpu,速度会慢很多),为2时代表强制在gpu下运行
jittor.flags.use_cuda = x
理论上应该只需要在运行的第一个文件前加上就行,但我为了保守起见在每个文件前都加上了,测试了一下好像都一样
在torch中,view()等几个函数并不生成新的数据,只是改变了数据的索引方式,因此其存储是不连续的(指索引与存储位置不一致),有时候需要用contiguous()函数使数据变得一致,否则有些函数无法处理(会报错),但在jittor中并没有这个函数,我尝试直接删去contiguous之后并没有报错,如果你遇到报错的话,可以尝试用reshape()代替view(),因为该函数直接生成了新的数据,不存在数据不连续的情况
对于bool类型的数据data,在torch中用~data就可以对data进行取反,但在jittor中并没有这个函数,可以利用jittor.logic_not(data)达到相同的效果
在jittor中,你可以用下面的代码来清理显存,其中最后一行是显示占用显存的情况。
我的选择是每隔500次迭代清理一次,你可以根据情况自行尝试
jittor.clean_graph()
jittor.sync_all()
jittor.gc()
jittor.display_memory_info()
在torch中这个函数可以指定dtype,但在jittor中不能,其输出与输入的数据类型一致
当grid为空时,torch可以输出一个空的tensor,但在jittor运行时,这个函数会卡死不动(gpu下会卡死,但用cpu的时候可以正常运行,我也没搞懂为什么。。。)
在torch中需要指定device加载模型,如下所示,但在jittor中不需要,直接忽略该参数即可
torch.load(ckpt, map_location=device)
在jittor为jittor.concat()
在torch中有lr的默认值,在jittor中需要手动设置
在jittor中用jittor.set_seed()设置seed
在torch中该函数可以将Image得到的图像进行处理,转换为tensor并进行归一化和转置等操作,得到的tensor形状为(C, H, W),取值范围为[0, 1]
在Jittor中为
jittor.transform.ToTensor()
求逆矩阵操作,在Jittor中如下(跟numpy一致)
jittor.linalg.inv()
矩阵转置,在jittor中没有同名函数,只能用 jittor.transpose(1, 0)
在jittor的ReLU中没有inplace参数,直接删掉即可
在jittor中为jittor.sqr()
在jittor中参数分别改为min_v , max_v
long函数在torch中是将数据转换为int64,但在jittor中是int32,并且用jittor.arrar(data,dtype=jittor.int64)也没法转换,需要采用下面的方法进行转换,其它类型也一样
data.int64()
我对比了jittor和pytorch的部分运算,发现jittor速度快很多,但我转换过的模型却比原版慢,目前还没排除到原因
在将模型参数以Var格式保存后,用jittor.load(ckpt)加载得到的参数却是array格式,很奇怪。。。