利用指向数据成员的指针实现容器对象数据成员的筛选

// TF v2.11.1 
// tensorflow\compiler\xla\mlir_hlo\lib\Dialect\mhlo\IR\hlo_ops.cc
// tensorflow\compiler\xla\xla_data.proto:468
// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window
// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution
// https://www.tensorflow.org/xla/operation_semantics#reducewindow
// https://www.tensorflow.org/xla/operation_semantics#conv_convolution
// One spatial dimension of an XLA/MHLO window (cf. xla_data.proto's
// WindowDimension message): how a reduce_window / convolution window is
// sized, strided, padded and dilated along a single dimension.
struct WindowDimension {
  int64_t size = 0;             // window extent along this dimension
  int64_t stride = 1;           // step between adjacent window positions
  int64_t paddingLow = 0;       // padding added before the first element
  int64_t paddingHigh = 0;      // padding added after the last element
  int64_t windowDilation = 1;   // dilation applied to the window elements
  int64_t baseDilation = 1;     // dilation applied to the base (input) area
  bool windowReversal = false;  // whether the window is traversed in reverse
};
// One entry per dimension (e.g. 4 entries for a 2-D NHWC pooling op).
using DimensionVector = std::vector<WindowDimension>;

// Collects one property from every WindowDimension in `dims` into `props`
// and returns a view over it.
//
// ptm  — pointer to the data member to extract; a NULL member pointer is the
//        sentinel for the padding case, which emits TWO values per dimension
//        (paddingLow then paddingHigh).
// props — caller-owned container; llvm::ArrayRef is a non-owning view, so the
//        buffer must outlive the returned ArrayRef and must not reallocate
//        (callers reserve() up front). A function-local container would
//        dangle, hence the out-parameter.
// num_dims is kept for signature compatibility with the bool overload.
template<typename T, typename Cont = std::vector<T>, int num_dims = 4>
llvm::ArrayRef<T> getKernelProps(const DimensionVector& dims, Cont& props, const T WindowDimension::* ptm /* pointer to T member */){
	// The sentinel test is loop-invariant: hoist it out of the per-dimension
	// loop instead of re-checking it for every element.
	if (ptm == nullptr) {
		// padding is a dual property: low and high per dimension
		for (const auto& d : dims) {
			props.push_back(d.paddingLow);
			props.push_back(d.paddingHigh);
		}
	} else {
		for (const auto& d : dims) {
			props.push_back(d.*ptm); // read d's member through the pointer-to-member
		}
	}
	return llvm::ArrayRef(props);
}

// Overload for bool (NOT a partial specialization — function templates cannot
// be partially specialized): std::vector<bool> is a packed-bit proxy
// container with surprising semantics, so we take a std::array<bool, N>
// instead. Precondition: num_dims == dims.size().
// Fix: the original parameter was declared `const T WindowDimension::*`, but
// T is not a template parameter of this overload, so it did not compile.
template<int num_dims = 4>
llvm::ArrayRef<bool> getKernelProps(const DimensionVector& dims, std::array<bool, num_dims>& props, const bool WindowDimension::* ptm /* pointer to bool member */){
	int index = 0;
	for (const auto& d : dims) {
		// Use the member pointer instead of hard-coding windowReversal;
		// behavior is unchanged (windowReversal is the only bool member),
		// but any future bool member works too.
		props[index++] = d.*ptm;
	}
	return llvm::ArrayRef(props);
}

// client code
// Client code: extract per-dimension window properties for a MaxPool rewrite.
template<int num_dims = 4> // num_dims = 4: 2d MaxPool rewrite, NHWC
void ConvertMaxPoolOp(){
	DimensionVector dims;
	// set up dims
	// ...
	
	// get a normal int64_t prop
	std::vector<int64_t> size_vector;
	// num_dims is always equal to dims.size(); reserve up front so the buffer
	// the returned ArrayRef views is never reallocated while in use
	size_vector.reserve(dims.size());
	auto size_props = getKernelProps(dims, size_vector, &WindowDimension::size);

	// special case 1: padding is a dual prop (low and high), requested via a
	// null member pointer; the cast is required so T deduces as int64_t
	std::vector<int64_t> pad_vector;
	pad_vector.reserve(dims.size() * 2);
	auto pad_props = getKernelProps(dims, pad_vector, static_cast<int64_t WindowDimension::*>(nullptr));
	
	// special case 2: get a bool prop. Value-initialize the array so no
	// element is ever indeterminate, even if dims.size() < num_dims.
	std::array<bool, num_dims> wr_array{};
	auto wr_props = getKernelProps<num_dims>(dims, wr_array, &WindowDimension::windowReversal);
}

A Simpler Solution

Using std::array for every property (instead of std::vector) sidesteps the std::vector&lt;bool&gt; packed-bit proxy problem entirely, so a single function template covers all cases — no separate bool overload is needed.

// Single template covering all property types, including bool, by taking a
// std::array out-parameter (no std::vector<bool> proxy issues).
// Precondition: array_size == dims.size(), or 2 * dims.size() for the
// padding case (ptm == nullptr), which emits paddingLow then paddingHigh
// per dimension. The returned ArrayRef is a non-owning view over `props`,
// so `props` must outlive it.
template<typename T, int array_size>
llvm::ArrayRef<T> getKernelProps(const DimensionVector& dims, std::array<T, array_size>& props, const T WindowDimension::* ptm /* pointer to T member */){
	int index = 0;
	// The sentinel test is loop-invariant: check once, not per element.
	if (ptm == nullptr) {
		for (const auto& d : dims) {
			props[index++] = d.paddingLow;
			props[index++] = d.paddingHigh;
		}
	} else {
		for (const auto& d : dims) {
			props[index++] = d.*ptm; // read d's member through the pointer-to-member
		}
	}
	return llvm::ArrayRef(props);
}

// client code (snippet — assumes a compile-time int `num_dims` is in scope)
// bool prop: explicit template arguments select <bool, num_dims>
std::array<bool, num_dims>  wr_array;
auto wr_props = getKernelProps<bool, num_dims>(dims, wr_array, &WindowDimension::windowReversal);

// padding
// padding emits two values (low, high) per dimension, so the array is 2x
constexpr auto double_num_dims = 2 * num_dims;
std::array<int64_t, double_num_dims> pad_array;
// the cast produces a null member pointer AND lets T deduce as int64_t
auto pad_props = getKernelProps<int64_t, double_num_dims>(dims, pad_array, static_cast<int64_t WindowDimension::*>(0));

你可能感兴趣的:(设计模式,c++)