【Libtorch】对tensor的切片索引

本人在使用Libtorch对tensor实现对某一维度前K行的提取,简单地说对某一个大小为[N,M]的tensor,提取前K行,获取到[K,M]的tensor。
最关键的函数就是tensor的index_select函数和torch::arange函数。
比如对大小为[5.3]的tensor提取前4行。
代码如下:

	torch::Tensor example_ = torch::rand({ 5,3 });
	std::cout << "example_:" << std::endl;
	std::cout << example_ << std::endl;
	torch::Tensor index__ = torch::arange(0, 4);
	std::cout << "index__:" << std::endl;
	std::cout << index__ << std::endl;
	torch::Tensor example_result = example_.index_select(0, index__);//0代表对第0维度来索引
	std::cout << "example_result:" << std::endl;
	std::cout << example_result << std::endl;

打印结果如下:

example_:
 0.0168  0.9895  0.3301
 0.0651  0.3022  0.5376
 0.3109  0.6183  0.8319
 0.8542  0.6346  0.4402
 0.4982  0.3748  0.8173
[ CPUFloatType{5,3} ]
index__:
 0
 1
 2
 3
[ CPULongType{4} ]
example_result:
 0.0168  0.9895  0.3301
 0.0651  0.3022  0.5376
 0.3109  0.6183  0.8319
 0.8542  0.6346  0.4402
[ CPUFloatType{4,3} ]

Tip:
本人在实际应用中对前向传播后的tensor做上述操作的时候一直有问题,一执行index_select函数就会直接结束,也不会提示问题点是什么。
后来发现原来要索引的tensor在执行index_select时所建立的索引tensor(也就是上述例子中的index__ = torch::arange(0, 4))也需要加载到GPU上,因为前向传播后的tensor此时还在GPU上。
所以在建立index的tensor时加到GPU上:

	torch::Tensor example_ = torch::rand({ 5,3 }).to(torch::Device(torch::kCUDA));
	std::cout << "example_:" << std::endl;
	std::cout << example_ << std::endl;
	torch::Tensor index__ = torch::arange(0, 4).to(torch::Device(torch::kCUDA));
	std::cout << "index__:" << std::endl;
	std::cout << index__ << std::endl;
	torch::Tensor example_result = example_.index_select(0, index__);
	std::cout << "example_result:" << std::endl;
	std::cout << example_result << std::endl;

打印结果如下:

example_:
 0.1363  0.6183  0.2576
 0.7965  0.0724  0.1440
 0.4968  0.7539  0.4834
 0.3361  0.4781  0.1167
 0.7776  0.1284  0.9714
[ CUDAFloatType{5,3} ]
index__:
 0
 1
 2
 3
[ CUDALongType{4} ]
example_result:
 0.1363  0.6183  0.2576
 0.7965  0.0724  0.1440
 0.4968  0.7539  0.4834
 0.3361  0.4781  0.1167
[ CUDAFloatType{4,3} ]

仔细对比可以发现数据类型是不一样的,上面是CPUFloatType,下面是CUDAFloatType。

你可能感兴趣的:(Libtorch,C++,深度学习,深度学习,人工智能)