std::variant visitor and pattern matching

Amateur notes, mostly parroting what others have said; not high-quality content.

Contents

  • std::variant
  • pattern matching
  • single dispatch
    • dynamic dispatch
    • static dispatch
  • double dispatch(visitor pattern)
  • std::visit
    • Using Function Objects as Visitors
    • Using Generic Lambdas as Visitors
    • Using Overloaded Lambdas as Visitors

std::variant

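In 《Notes on CppCon 2016: Ben Deane “Using Types Effectively”》 I mentioned that Ben considers std::variant and std::optional to be the most important new features of C++. In those notes, however, I only said that std::variant is a type-safe union related to pattern matching in ML or Haskell. This post introduces std::visit and pattern matching, which go hand in hand with std::variant.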

Being a type-safe union is only part of the story. More importantly, std::variant can replace some traditional uses of inheritance-based polymorphism. CppCon 2016: David Sankel “Variants: Past, Present, and Future” gives a brief summary:

Inheritance                   | Variant
------------------------------|---------------------------
open to new alternatives      | closed to new alternatives
closed to new operations      | open to new operations
multi-level                   | single level
OO                            | functional
complex                       | simple

pattern matching

The only functional programming language I am familiar with is ML (I learned it in the Programming Languages course on Coursera), so I will use pattern matching in ML as the example.

datatype exp = Constant of int
			| Negate    of exp
			| Add       of exp * exp
			| Multiply  of exp * exp

Note: the code above is taken from 《Coursera - Programming Languages》.
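datatype declares a new data type. Like std::variant, such a type is called a sum type, a disjoint union, and so on. The code above recursively defines a type exp that represents arithmetic expressions.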

We can write a function that finds the largest constant in an expression:

fun max_constant e = 
	case e of 
			Constant i => i
		| Negate e2 => max_constant e2
		| Add(e1, e2) => if max_constant e1 > max_constant e2
						then max_constant e1
						else max_constant e2
		| Multiply(e1, e2) => if max_constant e1 > max_constant e2
						then max_constant e1
						else max_constant e2

Note: the code above is taken from 《Coursera - Programming Languages》.

The case expression above is what pattern matching looks like. C++ (C++17 included) has no built-in pattern matching, but we can emulate it in an OOP style:

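#include <algorithm>
#include <iostream>
#include <memory>
#include <utility>

// Abstract base playing the role of "datatype exp".
// Named Exp rather than exp, because exp may also be declared as a math
// function in the global namespace by standard headers.
class Exp {
public:
  virtual ~Exp() = default;              // needed to delete through Exp*
  virtual int max_constant() const = 0;  // dispatched through the vtable
};

class Constant : public Exp {
  int v;

public:
  explicit Constant(int v) : v(v) {}
  int max_constant() const override final { return v; }
};

// Negate wraps a sub-expression, as in the ML datatype.
class Negate : public Exp {
  std::unique_ptr<Exp> e;

public:
  explicit Negate(std::unique_ptr<Exp> e) : e(std::move(e)) {}
  int max_constant() const override final { return e->max_constant(); }
};

class Add : public Exp {
  std::unique_ptr<Exp> e1;
  std::unique_ptr<Exp> e2;

public:
  Add(std::unique_ptr<Exp> e1, std::unique_ptr<Exp> e2)
      : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const override final {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

class Multiply : public Exp {
  std::unique_ptr<Exp> e1;
  std::unique_ptr<Exp> e2;

public:
  Multiply(std::unique_ptr<Exp> e1, std::unique_ptr<Exp> e2)
      : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const override final {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

int main() {
  // Multiply(Add(Negate(Constant 10), Constant 10), Multiply(Constant 10, Constant 30))
  Multiply e(std::make_unique<Add>(std::make_unique<Negate>(std::make_unique<Constant>(10)),
                                   std::make_unique<Constant>(10)),
             std::make_unique<Multiply>(std::make_unique<Constant>(10),
                                        std::make_unique<Constant>(30)));
  std::cout << e.max_constant() << std::endl;  // prints 30
}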

pattern matching is a dispatch mechanism: choosing which variant of a function is the correct one to call. - 《Pattern Matching》

If std::variant were merely a type-safe union, its potential would go untapped. It needs a companion mechanism for processing the values it holds.

single dispatch

Single dispatch comes in two forms: dynamic dispatch and static dispatch. In 《Programming Language, Part C - Week 1 lecture notes》 I noted that Dan calls dynamic dispatch the most essential feature of OOP.

dynamic dispatch

In computer science, dynamic dispatch is the process of selecting which implementation of a polymorphic operation (method or function) to call at run time. It is commonly employed in, and considered a prime characteristic of, object-oriented programming (OOP) languages and systems.

The C++ code above implements single dispatch (dynamic dispatch) through the vtable.

The following summary of dynamic dispatch comes from 《CppCon 2018: Mateusz Pusz “Effective replacement of dynamic polymorphism with std::variant”》:

single dynamic dispatch

  • Open to new alternatives
    - new derived types may be added by clients at any point of time (long after base class implementation is finished)
  • Closed to new operations
    - clients cannot add new operations to dynamic dispatch
  • Multi-level
    - many levels of inheritance possible
  • Object Oriented
    - whole framework is based on objects

The corresponding class diagram:
[Figure 1: class diagram]

static dispatch

In computing, static dispatch is a form of polymorphism fully resolved during compile time. It is a form of method dispatch, which describes how a language or environment will select which implementation of a method or function to use.

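In C++, dynamic dispatch is implemented by virtual functions through the vtable. RTTI is what typeid and dynamic_cast rely on; ordinary virtual calls do not need it. Static dispatch can be implemented with CRTP. For example, the code below implements single dispatch (static dispatch) with CRTP.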

#include <algorithm>
#include <iostream>
#include <memory>
#include <utility>

// CRTP base: max_constant() forwards to Derived::max_constant() via static_cast,
// so the call is resolved at compile time (no vtable).
template<typename Derived>
class Exp {
public:
  int max_constant() const {
    return static_cast<const Derived*>(this)->max_constant();
  }
};

class Constant : public Exp<Constant> {
  int v;

public:
  explicit Constant(int v) : v(v) {}
  int max_constant() const { return v; }
};

// Negate wraps a sub-expression, as in the ML datatype.
template<typename T>
class Negate : public Exp<Negate<T>> {
  std::unique_ptr<T> e;

public:
  explicit Negate(std::unique_ptr<T> e) : e(std::move(e)) {}
  int max_constant() const { return e->max_constant(); }
};

// Operands are held as std::unique_ptr<T> (the concrete type), not std::unique_ptr<Exp<T>>:
// Exp has no virtual destructor, so deleting a derived object through Exp<T>* would be UB.
template<typename T, typename U>
class Add : public Exp<Add<T, U>> {
  std::unique_ptr<T> e1;
  std::unique_ptr<U> e2;

public:
  Add(std::unique_ptr<T> e1, std::unique_ptr<U> e2)
      : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

template<typename T, typename U>
class Multiply : public Exp<Multiply<T, U>> {
  std::unique_ptr<T> e1;
  std::unique_ptr<U> e2;

public:
  Multiply(std::unique_ptr<T> e1, std::unique_ptr<U> e2)
      : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

// Accepts any Exp<T>; which max_constant() runs is fixed at compile time.
template<typename T>
int max_const(const Exp<T>& e) { return e.max_constant(); }

int main() {
  // I still haven't found a way to avoid spelling out the class template arguments.
  auto e = Multiply<Add<Negate<Constant>, Constant>, Multiply<Constant, Constant>>(
    std::make_unique<Add<Negate<Constant>, Constant>>(
      std::make_unique<Negate<Constant>>(std::make_unique<Constant>(10)),
      std::make_unique<Constant>(10)),
    std::make_unique<Multiply<Constant, Constant>>(std::make_unique<Constant>(10), std::make_unique<Constant>(30)));
  std::cout << max_const(e) << std::endl;  // prints 30
}

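As we can see, the core of the CRTP approach is template inheritance plus static_cast. Writing all those templates is painful; the payoff is that everything is resolved at compile time.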

In short, single dispatch selects the method or function to execute based on the type of a single object.

double dispatch(visitor pattern)

The Visitor Pattern article gives a brief explanation of double dispatch:

Multiple dispatch is a concept that allows method dispatch to be based not only on the receiving object but also on the parameters of the method’s invocation.

To explain what double dispatch is for, take an AST as an example: suppose we want to traverse a syntax tree and print information about each node.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

class StatementAST {
public:
	virtual ~StatementAST() = default;
	virtual void print() const = 0;
};

class Expr : public StatementAST {
public:
	void print() const override { std::cout << "This is Expr" << '\n'; }
};

class NumberExpr : public Expr {
public:
	void print() const override { std::cout << "This is NumberExpr" << '\n'; }
};

class StringLiteral : public Expr {
public:
	void print() const override { std::cout << "This is StringLiteral" << '\n'; }
};

// ...
class CallExpr : public Expr {
public:
	std::string name;                        // callee name
	std::vector<std::unique_ptr<Expr>> argu; // call arguments

	void print() const override {
		std::cout << name << "(";
		for (const auto & a : argu) {
			a->print();
			std::cout << ", ";
		}
		std::cout << ")";
	}
};

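Simple enough: we add print() as a virtual method to each AST node. But many other tasks also need different handling for each kind of AST node, such as semantic analysis or IR generation. Following the pattern of print(), we could add a corresponding virtual method to every node, e.g. emitIr():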

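#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct ir {};  // placeholder for the IR value produced by the back end (hypothetical)

class StatementAST {
public:
	virtual ~StatementAST() = default;
	virtual void print() const = 0;
	virtual ir emitIr() const = 0;  // every new pass adds another virtual method here
};

class Expr : public StatementAST {
public:
	void print() const override { std::cout << "This is Expr" << '\n'; }
};

class NumberExpr : public Expr {
public:
	void print() const override { std::cout << "This is NumberExpr" << '\n'; }
	ir emitIr() const override { /* ... */ return {}; }
};

class StringLiteral : public Expr {
public:
	void print() const override { std::cout << "This is StringLiteral" << '\n'; }
	ir emitIr() const override { /* ... */ return {}; }
};

// ...
class CallExpr : public Expr {
public:
	std::string name;
	std::vector<std::unique_ptr<Expr>> argu;

	void print() const override {
		std::cout << name << "(";
		for (const auto & a : argu) {
			a->print();
			std::cout << ", ";
		}
		std::cout << ")";
	}
	ir emitIr() const override {
		// (1) emit code for the arguments, (2) emit the call expression
		return {};
	}
};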

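But there are many, many such requirements. Putting IR generation code directly into the AST nodes is not modular: front end, middle end and back end get mixed together. One workable alternative is the visitor pattern: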

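#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Forward declarations, so that Visitor can name every node type.
class StatementAST;
class Expr;
class NumberExpr;
class StringLiteral;
class CallExpr;

class Visitor {
public:
	virtual ~Visitor() = default;
	virtual void visit(const StatementAST *S) = 0;
	virtual void visit(const Expr *E) = 0;
	virtual void visit(const NumberExpr *NE) = 0;
	virtual void visit(const StringLiteral *SL) = 0;
	// ...
	virtual void visit(const CallExpr *CE) = 0;
};

class StatementAST {
public:
	virtual ~StatementAST() = default;
	// First dispatch: the virtual call selects accept() by the node's dynamic type.
	virtual void accept(Visitor& visitor) const = 0;
};

class Expr : public StatementAST {
public:
	// Not final: NumberExpr, StringLiteral, ... override accept() again.
	void accept(Visitor& visitor) const override {
		// Second dispatch: the virtual call selects visit() by the visitor's dynamic type;
		// overload resolution on the static type of this picks visit(const Expr *).
		visitor.visit(this);
	}
};

class NumberExpr : public Expr {
public:
	void accept(Visitor& visitor) const override final {
		visitor.visit(this);
	}
};

class StringLiteral : public Expr {
public:
	void accept(Visitor& visitor) const override final {
		visitor.visit(this);
	}
};

// ...
class CallExpr : public Expr {
public:
	std::string name;                        // callee name
	std::vector<std::unique_ptr<Expr>> argu; // call arguments

	void accept(Visitor& visitor) const override final {
		visitor.visit(this);
	}
};

// You can define many kinds of visitors, e.g. a print visitor or an IR generation visitor.
class PrintVisitor : public Visitor {
public:
	void visit(const StatementAST *S) override {
		std::cout << "This is StatementAST" << '\n';
	}
	void visit(const Expr *E) override {
		std::cout << "This is Expr" << '\n';
	}
	void visit(const NumberExpr *NE) override {
		std::cout << "This is NumberExpr" << '\n';
	}
	void visit(const StringLiteral *SL) override {
		std::cout << "This is StringLiteral" << '\n';
	}
	// ...
	void visit(const CallExpr *CE) override {
		std::cout << CE->name << "(";
		for (const auto & a : CE->argu) {
			a->accept(*this);
			std::cout << ", ";
		}
		std::cout << ")";
	}
};

class IRGenerateVisitor : public Visitor {
public:
	void visit(const StatementAST *S) override { /**/ }
	void visit(const Expr *E) override { /**/ }
	void visit(const NumberExpr *NE) override { /**/ }
	void visit(const StringLiteral *SL) override { /**/ }
	// ...
	void visit(const CallExpr *CE) override { /**/ }
};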

The corresponding diagram for double dispatch:
[Figure 2: double dispatch diagram]
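The visitor pattern here is still not ideal: every time we add a new AST node class, we have to modify Visitor and every concrete visitor. Double dispatch is achieved with two virtual calls. accept dispatches on the node's dynamic type, and visit dispatches on the visitor's dynamic type. Inside each accept, overload resolution on the static type of this chooses which visit overload is called. We could also overload the accept method itself to obtain different behaviors, which implements double dispatch with a vtable plus overloading.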

《CppCon 2018: Mateusz Pusz “Effective replacement of dynamic polymorphism with std::variant”》 also gives a summary of double dispatch:
double dynamic dispatch

  • Closed to new alternatives
    - adding a new derived type requires changing the Visitor interface and every visitor
  • Open to new operations
    - clients can add new operations by writing a new visitor
  • Multi-level
    - many levels of inheritance possible
  • Object Oriented
    - whole framework is based on objects

std::visit

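So how can we express the Exp example from the beginning with std::variant? There is no direct way. Fundamentally, C++ has no recursive variant, and the alternative types must be complete where the variant is used as a member. The code below does not compile; see the related question 《C++ Mutually Recursive Variant Type (Again)》. The book 《C++17 in Detail》 compares boost::variant and std::variant as follows: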
[Figure 3: comparison of boost::variant and std::variant, from C++17 in Detail]

class Constant {
  int v;

public:
  Constant(int v) : v(v) {}
  int value() const { return v; }
};

class Negate {
  int v;

public:
  Negate(int v) : v(v) {}
  int value() const { return v; }
};

class Add {
  // Does not compile: Multiply has not been declared yet, and Add is still
  // incomplete inside its own definition.
  std::variant<Constant, Negate, Add, Multiply> e1;
  std::variant<Constant, Negate, Add, Multiply> e2;

public:
  Add(const std::variant<Constant, Negate, Add, Multiply>& e1, const std::variant<Constant, Negate, Add, Multiply>& e2) : e1(e1), e2(e2) {}
  std::variant<Constant, Negate, Add, Multiply> getOperand1() const { return e1; }
  std::variant<Constant, Negate, Add, Multiply> getOperand2() const { return e2; }
};

class Multiply {
  std::variant<Constant, Negate, Add, Multiply> e1;
  std::variant<Constant, Negate, Add, Multiply> e2;

public:
  Multiply(std::variant<Constant, Negate, Add, Multiply> e1, std::variant<Constant, Negate, Add, Multiply> e2) : e1(e1), e2(e2) {}
  std::variant<Constant, Negate, Add, Multiply> getOperand1() const { return e1; }
  std::variant<Constant, Negate, Add, Multiply> getOperand2() const { return e2; }
};

struct ExpMaxConstVisitor {
  int operator()(Constant c) const { return c.value(); }
  int operator()(Negate c) const { return c.value(); }
  int operator()(Add a) const {
    return std::max(std::visit(*this, a.getOperand1()),
                    std::visit(*this, a.getOperand2()));
  }
  int operator()(Multiply m) const {
    return std::max(std::visit(*this, m.getOperand1()),
                    std::visit(*this, m.getOperand2()));
  }
};

Using Function Objects as Visitors

Let me simplify the code above so that it compiles, and use it to introduce the visitor:

#include <algorithm>
#include <iostream>
#include <variant>

class Constant {
  int v;

public:
  Constant(int v) : v(v) {}
  int value() const { return v; }
};

class Negate {
  int v;

public:
  Negate(int v) : v(v) {}
  int value() const { return v; }
};

class Add {
  int e1;
  int e2;

public:
  Add(int e1, int e2) : e1(e1), e2(e2) {}
  int getOperand1() const { return e1; }
  int getOperand2() const { return e2; }
};

class Multiply {
  int e1;
  int e2;

public:
  Multiply(int e1, int e2) : e1(e1), e2(e2) {}
  int getOperand1() const { return e1; }
  int getOperand2() const { return e2; }
};

struct ExpMaxConstVisitor {
  int operator()(Constant c) const { return c.value(); }
  int operator()(Negate c) const { return c.value(); }
  int operator()(Add a) const {
    // If std::variant were recursive, this would be a recursive visit:
    // return std::max(std::visit(*this, a.getOperand1()),
    //                 std::visit(*this, a.getOperand2()));
    return std::max(a.getOperand1(), a.getOperand2());
  }
  int operator()(Multiply m) const {
    // If std::variant were recursive, this would be a recursive visit:
    // return std::max(std::visit(*this, m.getOperand1()),
    //                 std::visit(*this, m.getOperand2()));
    return std::max(m.getOperand1(), m.getOperand2());
  }
};

int main() {
  std::variant<Constant, Negate, Add, Multiply> exp(Add(10, 30));
  std::cout << std::visit(ExpMaxConstVisitor(), exp) << std::endl;
  return 0;
}

ExpMaxConstVisitor in the code above is the most common kind of std::variant visitor: a function object with one operator() overload per alternative.

The call of visit() is a compile-time error if not all possible types are supported by an operator() or if the call is ambiguous.

std::visit also guarantees type safety. If the visitor does not cover every case, compilation fails, and the compiler may emit an error message such as:

`std::visit` requires the visitor to be exhaustive

Using Generic Lambdas as Visitors

Generic lambdas are a C++14 feature. A traditional lambda looks like this:

auto add = [](int a, int b) -> int { return a + b; };

A generic lambda looks like this:

auto add = [](auto a, auto b) { return a + b; };

Compiling it with the -std=c++11 option produces the following errors:


#1 with x86-64 clang 9.0.0
<source>:3:19: error: 'auto' not allowed in lambda parameter
    auto add = [](auto a, auto b) { return a + b ;};
                  ^~~~
<source>:3:27: error: 'auto' not allowed in lambda parameter
    auto add = [](auto a, auto b) { return a + b ;};
                          ^~~~
<source>:4:15: error: invalid operands to binary expression ('std::ostream' (aka 'basic_ostream') and 'void')
    std::cout << add(10, 20) << std::endl;

With C++ Insights, we can see that the generic lambda add is equivalent to the following code:

class __lambda_3_16 {
public: 
  template<class type_parameter_0_0, class type_parameter_0_1>
  inline /*constexpr */ auto operator()(type_parameter_0_0 a, type_parameter_0_1 b) const {
    return a + b;
  }
private: 
  template<class type_parameter_0_0, class type_parameter_0_1>
  static inline auto __invoke(type_parameter_0_0 a, type_parameter_0_1 b) {
    return a + b;
  }  
};

A visitor written as a generic lambda looks like this:

#include <type_traits>  // std::decay_t, std::is_same_v

int main() {
  std::variant<Constant, Negate, Add, Multiply> exp(Add(10, 30));
  std::cout << std::visit([](auto& val) {
    // val is a reference (e.g. Add&), so strip it before comparing types.
    using T = std::decay_t<decltype(val)>;
    if constexpr (std::is_same_v<T, Constant>) {
      return val.value();
    } else if constexpr (std::is_same_v<T, Negate>) {
      return val.value();
    } else if constexpr (std::is_same_v<T, Add>) {
      return std::max(val.getOperand1(), val.getOperand2());
    } else if constexpr (std::is_same_v<T, Multiply>) {
      return std::max(val.getOperand1(), val.getOperand2());
    }
  }, exp) << std::endl;
  return 0;
}

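Here we used compile-time if, i.e. if constexpr, introduced in C++17. Again, expanding the code with C++ Insights shows that it all rests on template argument deduction. The lambda's operator() is instantiated once for each alternative, and if constexpr discards the branches that do not apply to that alternative.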

Using Overloaded Lambdas as Visitors

By using an overloader for function objects and lambdas, you can also define a set of lambdas where the best match is used as a visitor.

template<typename... Ts> struct overload : Ts... {
  using Ts::operator()...;
};
template<typename... Ts>
overload(Ts...) -> overload<Ts...>;

int main() {
  std::variant<Constant, Negate, Add, Multiply> exp(Add(10, 30));
  std::cout << std::visit(overload{
    [](auto a) { return a.value(); },
    [](Add a) {return std::max(a.getOperand1(), a.getOperand2());},
    [](Multiply a) {return std::max(a.getOperand1(), a.getOperand2());}
  }, exp) << std::endl;
  return 0;
}

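std::variant still has plenty of room for improvement, for example recursive variants. Writing the visitor part is reasonably simple, and it can replace some uses of class inheritance. Compared with pattern matching in functional programming languages it is still somewhat cumbersome, but having it is better than not having it.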

There is much more to say about recursive variants. As the Variant design review puts it:

Recursive variants are variants that (conceptually) have itself as one of the alternatives. There are good reasons to add support for a recursive variant; for instance to build AST nodes. There are also good reasons not to do so, and to instead use unique_ptr> as an alternative. A recursive variant can be implemented as an extension to variant, see for instance what is done for boost::variant. The proposals does not contain support for recursive variants; they also do not preclude a proposal for them
