libtorch学习第四

切片和索引操作

#include
#include

using std::cout; using std::endl;

int main()
{
	auto b = torch::rand({ 10,3,28,28 });
	/*cout << b[0].sizes() << endl;
	cout << b[0][0].sizes() << endl;
	cout << b[0][0][0].sizes() << endl;*/

	// 选择第0维中的index为[0,3,3]的数据组成新的tensor
	//cout << b.index_select(0, torch::tensor({ 0,3,3 })).sizes() << endl;

	//cout << b.index_select(1, torch::tensor({ 0,2 })).sizes() << endl;

	// 取第2维中[0-8]
	//cout << b.index_select(2, torch::arange(0, 8)).sizes() << endl;

	// 取第1维,从index为0开始,长度为2
	//cout << b.narrow(1, 0, 2).sizes() << endl;

	// 选择第3维,index为2
	//cout << b.select(3, 2).sizes() << endl;

	//
	auto c = torch::randn({ 3,4 });
	auto mask = torch::zeros({ 3,4 });
	mask[0][0] = 1; mask[0][2] = 1;
	cout << c << endl;

	auto d = c.index({ mask.to(torch::kBool) });
	//cout << d << endl;

	auto e = c.index_put_({ mask.to(torch::kBool) }, c.index({ mask.to(torch::kBool) }) + 1.5);
	cout << e << endl;
	

	return 0;
}

你可能感兴趣的:(pytorch,学习,算法)