Amateur notes, mostly borrowed from others; don't take them too seriously.
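In my notes on *CppCon 2016: Ben Deane "Using Types Effectively"*, I mentioned that Ben considers `std::variant` and `std::optional` to be the most important new features of C++. In those notes, however, I only said that `std::variant` is a type-safe union and is related to pattern matching in ML or Haskell. This post covers `std::visit`, the companion of `std::variant`, and pattern matching.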
Beyond being a type-safe union, the more important point is that `std::variant` can replace some traditional uses of inheritance-based polymorphism. *CppCon 2016: David Sankel "Variants: Past, Present, and Future"* gives a short summary:
| Inheritance | Variant |
|---|---|
| open to new alternatives | closed to new alternatives |
| closed to new operations | open to new operations |
| multi-level | single level |
| OO | functional |
| complex | simple |
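To make the "closed to new alternatives / open to new operations" rows of the table concrete, here is a minimal sketch. The `Circle`/`Square` types and the `NameOf`/`CornerCount` visitors are hypothetical and not from the talk. Each new operation is just another visitor, written without touching the shape types; adding a new shape, on the other hand, would require updating every visitor.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical shape types, used only for this illustration.
struct Circle { double r; };
struct Square { double side; };

// The set of alternatives is fixed here ("closed to new alternatives").
using Shape = std::variant<Circle, Square>;

// Operation 1: a name. A new operation is a new visitor ("open to new operations");
// Circle and Square themselves are not modified.
struct NameOf {
  std::string operator()(const Circle&) const { return "circle"; }
  std::string operator()(const Square&) const { return "square"; }
};

// Operation 2: number of corners, added later in the same way.
struct CornerCount {
  int operator()(const Circle&) const { return 0; }
  int operator()(const Square&) const { return 4; }
};

std::string name(const Shape& s) { return std::visit(NameOf{}, s); }
int corners(const Shape& s) { return std::visit(CornerCount{}, s); }
```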
The only functional programming language I'm familiar with is ML (which I learned in the Coursera course *Programming Languages*), so I'll use pattern matching in ML as the example.
datatype exp = Constant of int
             | Negate of exp
             | Add of exp * exp
             | Multiply of exp * exp
Note: the code above is taken from the Coursera course *Programming Languages*.
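`datatype` declares a new data type. Like `std::variant`, it is what is called a sum type, a disjoint union, and so on. The code above recursively defines a type `exp` that represents arithmetic expressions.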
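As a small illustration of what "sum type" means for `std::variant`, here is a sketch using a hypothetical `IntOrString` alias: a value holds exactly one alternative, `index()` and `std::holds_alternative` tell which one, and asking for the wrong one throws `std::bad_variant_access` instead of leading to undefined behavior the way reading the inactive member of a plain `union` does.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical sum type: a value holds exactly one of int or std::string.
using IntOrString = std::variant<int, std::string>;

// Describes the value according to which alternative is currently active.
std::string describe(const IntOrString& v) {
  if (std::holds_alternative<int>(v)) {
    return "int " + std::to_string(std::get<int>(v));
  }
  return "string " + std::get<std::string>(v);
}

// std::get with the inactive alternative throws std::bad_variant_access.
bool get_string_throws(const IntOrString& v) {
  try {
    (void)std::get<std::string>(v);
    return false;
  } catch (const std::bad_variant_access&) {
    return true;
  }
}
```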
We can write a function that finds the largest constant in an expression:
fun max_constant e =
    case e of
        Constant i => i
      | Negate e2 => max_constant e2
      | Add(e1, e2) => if max_constant e1 > max_constant e2
                       then max_constant e1
                       else max_constant e2
      | Multiply(e1, e2) => if max_constant e1 > max_constant e2
                            then max_constant e1
                            else max_constant e2
Note: the code above is taken from the Coursera course *Programming Languages*.
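The `case` expression above is what gives us pattern matching. C++ before C++17 offers nothing like it (and C++17 adds only the library-level `std::variant` and `std::visit`, not language-level pattern matching), but we can emulate it with OOP: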
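#include <algorithm>
#include <iostream>
#include <memory>
#include <utility>

// Named Exp rather than exp to avoid clashing with the C library's exp() function.
class Exp {
public:
  virtual ~Exp() = default;
  virtual int max_constant() const = 0;
};

class Constant : public Exp {
  int v;
public:
  Constant(int v) : v(v) {}
  int max_constant() const override final { return v; }
};

// Simplified: holds an int instead of a sub-expression as in the ML version.
class Negate : public Exp {
  int v;
public:
  Negate(int v) : v(v) {}
  int max_constant() const override final { return v; }
};

class Add : public Exp {
  std::unique_ptr<Exp> e1;
  std::unique_ptr<Exp> e2;
public:
  Add(std::unique_ptr<Exp> e1, std::unique_ptr<Exp> e2)
    : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const override final {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

class Multiply : public Exp {
  std::unique_ptr<Exp> e1;
  std::unique_ptr<Exp> e2;
public:
  Multiply(std::unique_ptr<Exp> e1, std::unique_ptr<Exp> e2)
    : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const override final {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

int main() {
  // Multiply(Add(Negate 10, Constant 10), Multiply(Constant 10, Constant 30))
  std::unique_ptr<Exp> e = std::make_unique<Multiply>(
      std::make_unique<Add>(std::make_unique<Negate>(10), std::make_unique<Constant>(10)),
      std::make_unique<Multiply>(std::make_unique<Constant>(10), std::make_unique<Constant>(30)));
  std::cout << e->max_constant() << std::endl;  // prints 30
}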
Pattern matching is a dispatch mechanism: choosing which variant of a function is the correct one to call. (from *Pattern Matching*)
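If `std::variant` were merely a type-safe union, its potential would go untapped; it needs a matching mechanism for processing variant values, which in the standard library is `std::visit`.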
Single dispatch covers what we usually call dynamic dispatch and static dispatch. In *Programming Languages, Part C: week 1 lecture notes*, Dan points out that dynamic dispatch is the most essential feature of OOP.
In computer science, dynamic dispatch is the process of selecting which implementation of a polymorphic operation (method or function) to call at run time. It is commonly employed in, and considered a prime characteristic of, object-oriented programming (OOP) languages and systems.
The C++ code above implements single dispatch (dynamic dispatch) with a vtable.
The summary of dynamic dispatch below comes from *CppCon 2018: Mateusz Pusz "Effective replacement of dynamic polymorphism with std::variant"*:
[Figure: single dynamic dispatch]
In computing, static dispatch is a form of polymorphism fully resolved during compile time. It is a form of method dispatch, which describes how a language or environment will select which implementation of a method or function to use.
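In C++, dynamic dispatch is done by virtual functions through the vtable (RTTI is what `typeid` and `dynamic_cast` rely on, not virtual calls), while static dispatch can be done with CRTP. For example, the following code uses CRTP to implement single dispatch (static dispatch):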
#include <algorithm>
#include <iostream>
#include <memory>
#include <utility>

template<typename Derived>
class Exp {
public:
  // Forwards to Derived::max_constant; resolved at compile time, no vtable.
  int max_constant() const {
    return static_cast<const Derived*>(this)->max_constant();
  }
};

class Constant : public Exp<Constant> {
  int v;
public:
  Constant(int v) : v(v) {}
  int max_constant() const { return v; }
};

class Negate : public Exp<Negate> {
  int v;
public:
  Negate(int v) : v(v) {}
  int max_constant() const { return v; }
};

// Children are held as their concrete types: Exp<T> has no virtual destructor,
// so deleting through a std::unique_ptr<Exp<T>> would be undefined behavior.
template<typename T, typename U>
class Add : public Exp<Add<T, U>> {
  std::unique_ptr<T> e1;
  std::unique_ptr<U> e2;
public:
  Add(std::unique_ptr<T> e1, std::unique_ptr<U> e2)
    : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

template<typename T, typename U>
class Multiply : public Exp<Multiply<T, U>> {
  std::unique_ptr<T> e1;
  std::unique_ptr<U> e2;
public:
  Multiply(std::unique_ptr<T> e1, std::unique_ptr<U> e2)
    : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

template<typename T>
int max_const(const Exp<T>& e) { return e.max_constant(); }

int main() {
  // I haven't found a way to avoid spelling out the class template arguments.
  auto e = Multiply<Add<Negate, Constant>, Multiply<Constant, Constant>>(
      std::make_unique<Add<Negate, Constant>>(std::make_unique<Negate>(10), std::make_unique<Constant>(10)),
      std::make_unique<Multiply<Constant, Constant>>(std::make_unique<Constant>(10), std::make_unique<Constant>(30)));
  std::cout << max_const(e) << std::endl;  // prints 30
}
As we can see, the core of the CRTP approach is template inheritance plus `static_cast`. Writing the templates is painful; the payoff is that everything is resolved at compile time.
In short, single dispatch means executing the method or function that corresponds to the type of one particular object.
*Visitor Pattern* gives a brief explanation of double dispatch:
Multiple dispatch is a concept that allows method dispatch to be based not only on the receiving object but also on the parameters of the method’s invocation.
To explain what double dispatch is for, let's use an AST as the example. Say we want to traverse a syntax tree and print information about each node.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

class StatementAST {
public:
  virtual ~StatementAST() = default;
  virtual void print() = 0;
};
class Expr : public StatementAST {
public:
  void print() override { std::cout << "This is Expr" << '\n'; }
};
class NumberExpr : public Expr {
public:
  void print() override { std::cout << "This is NumberExpr" << '\n'; }
};
class StringLiteral : public Expr {
public:
  void print() override { std::cout << "This is StringLiteral" << '\n'; }
};
// ...
class CallExpr : public Expr {
  std::string name;                         // callee name
  std::vector<std::unique_ptr<Expr>> argu;  // call arguments
public:
  void print() override {
    std::cout << name << "(";
    for (const auto & a : argu) {
      a->print();
      std::cout << ", ";
    }
    std::cout << ")";
  }
};
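That's simple: we add `print()` as a virtual method to every AST node. But there are many more tasks that need different handling for each kind of AST node, such as semantic analysis or IR generation. Following the pattern of `print()`, we could add a corresponding virtual method, e.g. `emitIr()`, to every node.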
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct ir {};  // placeholder for the IR type (hypothetical)

class StatementAST {
public:
  virtual ~StatementAST() = default;
  virtual void print() = 0;
  virtual ir emitIr() = 0;
};
class Expr : public StatementAST {
public:
  void print() override { std::cout << "This is Expr" << '\n'; }
};
class NumberExpr : public Expr {
public:
  void print() override { std::cout << "This is NumberExpr" << '\n'; }
  ir emitIr() override { /* ... */ return {}; }
};
class StringLiteral : public Expr {
public:
  void print() override { std::cout << "This is StringLiteral" << '\n'; }
  ir emitIr() override { /* ... */ return {}; }
};
// ...
class CallExpr : public Expr {
  std::string name;
  std::vector<std::unique_ptr<Expr>> argu;
public:
  void print() override {
    std::cout << name << "(";
    for (const auto & a : argu) {
      a->print();
      std::cout << ", ";
    }
    std::cout << ")";
  }
  ir emitIr() override {
    // (1) emit code for the arguments, (2) emit the call expression
    return {};
  }
};
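But there are many, many such requirements. Putting the IR generation code directly into the AST nodes is not modular: front end, middle end and back end get mixed together. One workable implementation is the following: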
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Forward declarations so that Visitor can name every node type.
class StatementAST;
class Expr;
class NumberExpr;
class StringLiteral;
class CallExpr;

class Visitor {
public:
  virtual ~Visitor() = default;
  virtual void visit(const StatementAST *S) = 0;
  virtual void visit(const Expr *E) = 0;
  virtual void visit(const NumberExpr *NE) = 0;
  virtual void visit(const StringLiteral *SL) = 0;
  // ...
  virtual void visit(const CallExpr *CE) = 0;
};

class StatementAST {
public:
  virtual ~StatementAST() = default;
  virtual void accept(Visitor& visitor) const = 0;
};

class Expr : public StatementAST {
public:
  // Not final: subclasses of Expr override accept again.
  void accept(Visitor& visitor) const override {
    // 1st dispatch (virtual accept) selected the node type; the static type of
    // `this` picks the visit overload, and the 2nd dispatch selects the visitor.
    visitor.visit(this);
  }
};

class NumberExpr : public Expr {
public:
  void accept(Visitor& visitor) const override final {
    visitor.visit(this);
  }
};

class StringLiteral : public Expr {
public:
  void accept(Visitor& visitor) const override final {
    visitor.visit(this);
  }
};

// ...

class CallExpr : public Expr {
public:
  std::string name;                         // callee name
  std::vector<std::unique_ptr<Expr>> argu;  // call arguments
  void accept(Visitor& visitor) const override final {
    visitor.visit(this);
  }
};

// You can define many visitors, e.g. a print visitor or an IR-generation visitor.
class PrintVisitor : public Visitor {
public:
  void visit(const StatementAST *S) override {
    std::cout << "This is StatementAST" << '\n';
  }
  void visit(const Expr *E) override {
    std::cout << "This is Expr" << '\n';
  }
  void visit(const NumberExpr *NE) override {
    std::cout << "This is NumberExpr" << '\n';
  }
  void visit(const StringLiteral *SL) override {
    std::cout << "This is StringLiteral" << '\n';
  }
  // ...
  void visit(const CallExpr *CE) override {
    std::cout << CE->name << "(";
    for (const auto & a : CE->argu) {
      a->accept(*this);
      std::cout << ", ";
    }
    std::cout << ")";
  }
};

class IRGenerateVisitor : public Visitor {
public:
  void visit(const StatementAST *S) override { /**/ }
  void visit(const Expr *E) override { /**/ }
  void visit(const NumberExpr *NE) override { /**/ }
  void visit(const StringLiteral *SL) override { /**/ }
  // ...
  void visit(const CallExpr *CE) override { /**/ }
};
[Figure: double dispatch]
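This visitor pattern is still not perfect: every time we add a new AST node class, we have to modify `Visitor` (and every concrete visitor). Here double dispatch is implemented with two vtable lookups, one for `accept` and one for `visit`. We could instead overload the `accept` method to get different behavior; double dispatch is then implemented with a vtable plus overloading.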
*CppCon 2018: Mateusz Pusz "Effective replacement of dynamic polymorphism with std::variant"* also gives a summary of double dispatch:
[Figure: double dynamic dispatch]
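`std::visit` also accepts several variants at once and selects the overload from the active alternatives of all of them, which gives double dispatch without an `accept`/`visit` pair. A minimal sketch with hypothetical `Asteroid`/`Ship` types:

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical game objects, for illustration only.
struct Asteroid {};
struct Ship {};

using Object = std::variant<Asteroid, Ship>;

// One overload per (left, right) combination: the result depends on the
// runtime alternative of *both* arguments, i.e. double dispatch.
struct Collide {
  std::string operator()(const Asteroid&, const Asteroid&) const { return "asteroids bounce"; }
  std::string operator()(const Asteroid&, const Ship&) const { return "ship destroyed"; }
  std::string operator()(const Ship&, const Asteroid&) const { return "ship destroyed"; }
  std::string operator()(const Ship&, const Ship&) const { return "ships dock"; }
};

std::string collide(const Object& a, const Object& b) {
  // A missing combination would be a compile-time error, not a runtime one.
  return std::visit(Collide{}, a, b);
}
```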
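So how can we use `std::variant` to express the `Exp` example from the beginning? In fact there is no direct way. Essentially, C++ has no recursive variant: a `std::variant` member needs its alternatives to be complete types, yet `Add` is still incomplete inside its own definition (and `Multiply` is not even declared yet). The code below does not compile; see the related question *C++ Mutually Recursive Variant Type (Again)*. The book *C++17 in Detail* includes a comparison of `boost::variant` and `std::variant`.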
class Constant {
int v;
public:
Constant(int v) : v(v) {}
int value() const { return v; }
};
class Negate {
int v;
public:
Negate(int v) : v(v) {}
int value() const { return v; }
};
class Add {
std::variant<Constant, Negate, Add, Multiply> e1;
std::variant<Constant, Negate, Add, Multiply> e2;
public:
Add(const std::variant<Constant, Negate, Add, Multiply>& e1, std::variant<Constant, Negate, Add, Multiply>& e2) : e1(e1), e2(e2) {}
std::variant<Constant, Negate, Add, Multiply> getOperand1() const { return e1; }
std::variant<Constant, Negate, Add, Multiply> getOperand2() const { return e2; }
};
class Multiply {
std::variant<Constant, Negate, Add, Multiply> e1;
std::variant<Constant, Negate, Add, Multiply> e2;
public:
Multiply(std::variant<Constant, Negate, Add, Multiply> e1, std::variant<Constant, Negate, Add, Multiply> e2) : e1(e1), e2(e2) {}
std::variant<Constant, Negate, Add, Multiply> getOperand1() const { return e1; }
std::variant<Constant, Negate, Add, Multiply> getOperand2() const { return e2; }
};
struct ExpMaxConstVisitor {
int operator()(Constant c) const { return c.value(); }
int operator()(Negate c) const { return c.value(); }
int operator()(Add a) const {
return std::max(std::visit(*this, a.getOperand1()),
std::visit(*this, a.getOperand2()));
}
int operator()(Multiply m) const {
return std::max(std::visit(*this, m.getOperand1()),
std::visit(*this, m.getOperand2()));
}
};
I'll simplify the code above so that it compiles, and use it to introduce the visitor:
#include <algorithm>
#include <iostream>
#include <variant>

class Constant {
int v;
public:
Constant(int v) : v(v) {}
int value() const { return v; }
};
class Negate {
int v;
public:
Negate(int v) : v(v) {}
int value() const { return v; }
};
class Add {
int e1;
int e2;
public:
Add(int e1, int e2) : e1(e1), e2(e2) {}
int getOperand1() const { return e1; }
int getOperand2() const { return e2; }
};
class Multiply {
int e1;
int e2;
public:
Multiply(int e1, int e2) : e1(e1), e2(e2) {}
int getOperand1() const { return e1; }
int getOperand2() const { return e2; }
};
struct ExpMaxConstVisitor {
int operator()(Constant c) const { return c.value(); }
int operator()(Negate c) const { return c.value(); }
int operator()(Add a) const {
// If std::variant could be recursive, this would visit the operands recursively:
// return std::max(std::visit(*this, a.getOperand1()),
// std::visit(*this, a.getOperand2()));
return std::max(a.getOperand1(), a.getOperand2());
}
int operator()(Multiply m) const {
// If std::variant could be recursive, this would visit the operands recursively:
// return std::max(std::visit(*this, m.getOperand1()),
// std::visit(*this, m.getOperand2()));
return std::max(m.getOperand1(), m.getOperand2());
}
};
int main() {
std::variant<Constant, Negate, Add, Multiply> exp(Add(10, 30));
std::cout << std::visit(ExpMaxConstVisitor(), exp) << std::endl;
return 0;
}
`ExpMaxConstVisitor` above is the most common kind of `std::variant` visitor: a function object with one `operator()` overload per alternative.
The call of `visit()` is a compile-time error if not all possible types are supported by an `operator()` or if the call is ambiguous.
`std::visit` also guarantees type safety: if the visitor does not cover every case, compilation fails, and the compiler may report an error such as:
`std::visit` requires the visitor to be exhaustive
Generic lambdas are a C++14 feature. A traditional lambda looks like this:
auto add = [](int a, int b) -> int { return a + b; };
whereas a generic lambda looks like this:
auto add = [](auto a, auto b) { return a + b; };
Compiling it with the `-std=c++11` option produces the following errors:
#1 with x86-64 clang 9.0.0
<source>:3:19: error: 'auto' not allowed in lambda parameter
    auto add = [](auto a, auto b) { return a + b ;};
                  ^~~~
<source>:3:27: error: 'auto' not allowed in lambda parameter
    auto add = [](auto a, auto b) { return a + b ;};
                          ^~~~
<source>:4:15: error: invalid operands to binary expression ('std::ostream' (aka 'basic_ostream<char>') and 'void')
    std::cout << add(10, 20) << std::endl;
Using C++ Insights, we can see that the generic lambda `add` is equivalent to the following code:
class __lambda_3_16 {
public:
template<class type_parameter_0_0, class type_parameter_0_1>
inline /*constexpr */ auto operator()(type_parameter_0_0 a, type_parameter_0_1 b) const {
return a + b;
}
private:
template<class type_parameter_0_0, class type_parameter_0_1>
static inline auto __invoke(type_parameter_0_0 a, type_parameter_0_1 b) {
return a + b;
}
};
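To see the equivalence in practice, here is a sketch that puts the generic lambda next to a hand-written functor with a member function template `operator()` (`AddFn` is a hypothetical name, a simplified version of what C++ Insights shows). Both are instantiated separately for each combination of argument types.

```cpp
#include <cassert>
#include <string>

// Generic lambda: operator() is a template, instantiated per call signature.
auto add = [](auto a, auto b) { return a + b; };

// Hand-written equivalent: a class with a member function template operator().
struct AddFn {
  template <typename A, typename B>
  auto operator()(A a, B b) const { return a + b; }
};
```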
A visitor written as a generic lambda looks like this:
#include <algorithm>
#include <iostream>
#include <type_traits>
#include <variant>

// Uses the Constant, Negate, Add and Multiply classes from the simplified example above.
int main() {
  std::variant<Constant, Negate, Add, Multiply> exp(Add(10, 30));
  std::cout << std::visit([](auto& val) {
    // The lambda body is instantiated once per alternative; T is that alternative.
    using T = std::decay_t<decltype(val)>;
    if constexpr (std::is_same_v<T, Constant>) {
      return val.value();
    } else if constexpr (std::is_same_v<T, Negate>) {
      return val.value();
    } else if constexpr (std::is_same_v<T, Add>) {
      return std::max(val.getOperand1(), val.getOperand2());
    } else if constexpr (std::is_same_v<T, Multiply>) {
      return std::max(val.getOperand1(), val.getOperand2());
    }
  }, exp) << std::endl;
  return 0;
}
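Above we used compile-time if, i.e. the C++17 feature `if constexpr`. Expanding this code with C++ Insights as well shows that it all comes down to template argument deduction: the lambda's `operator()` template is instantiated once per alternative, and `if constexpr` discards the branches that don't apply.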
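A common companion to this style is a `static_assert` in a final `else` branch, so that forgetting an alternative is a compile-time error rather than a missing `return`. The sketch below uses simplified aggregate stand-ins for `Constant`, `Negate` and `Add`, and a hypothetical `always_false_v` helper. In C++17 the assertion's condition must depend on `T`; a plain `static_assert(false)` would fire even when the branch is discarded.

```cpp
#include <algorithm>
#include <cassert>
#include <type_traits>
#include <variant>

// Simplified aggregate stand-ins for the Constant/Negate/Add classes above.
struct Constant { int v; };
struct Negate   { int v; };
struct Add      { int e1, e2; };

// Hypothetical helper: always false, but dependent on T, so the static_assert
// below only fires if the final else branch is actually instantiated.
template <typename T>
inline constexpr bool always_false_v = false;

int max_const(const std::variant<Constant, Negate, Add>& e) {
  return std::visit([](const auto& val) -> int {
    using T = std::decay_t<decltype(val)>;
    if constexpr (std::is_same_v<T, Constant>) {
      return val.v;
    } else if constexpr (std::is_same_v<T, Negate>) {
      return val.v;
    } else if constexpr (std::is_same_v<T, Add>) {
      return std::max(val.e1, val.e2);
    } else {
      // A new alternative without a branch above stops compilation here.
      static_assert(always_false_v<T>, "visitor is not exhaustive");
    }
  }, e);
}
```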
By using an overloader for function objects and lambdas, you can also define a set of lambdas where the best match is used as a visitor.
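#include <algorithm>
#include <iostream>
#include <variant>

// Uses the Constant, Negate, Add and Multiply classes from the simplified example above.
template<typename... Ts> struct overload : Ts... {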
using Ts::operator()...;
};
template<typename... Ts>
overload(Ts...) -> overload<Ts...>;
int main() {
std::variant<Constant, Negate, Add, Multiply> exp(Add(10, 30));
std::cout << std::visit(overload{
[](auto a) { return a.value(); },
[](Add a) {return std::max(a.getOperand1(), a.getOperand2());},
[](Multiply a) {return std::max(a.getOperand1(), a.getOperand2());}
}, exp) << std::endl;
return 0;
}
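`std::variant` still has plenty of room for improvement, e.g. recursive variants. The visitor part is reasonably simple to write, and it lets us avoid some uses of class inheritance. Compared with pattern matching in functional programming languages it is still somewhat cumbersome, but having it is better than nothing.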
There is much more to say about recursive variants. As the *Variant design review* explains:
Recursive variants are variants that (conceptually) have itself as one of the alternatives. There are good reasons to add support for a recursive variant; for instance to build AST nodes. There are also good reasons not to do so, and to instead use unique_ptr<...> as an alternative. A recursive variant can be implemented as an extension to variant, see for instance what is done for boost::variant. The proposals does not contain support for recursive variants; they also do not preclude a proposal for them.
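Following that suggestion, here is a minimal sketch of the recursive `Exp` from the beginning of this post, using `std::unique_ptr` for the recursive alternatives so that they may still be incomplete when the variant is declared. The helper functions `constant`, `negate`, `add` and `multiply` are hypothetical conveniences; unlike the simplified version above, `Negate` holds a sub-expression, as in the ML `datatype`. The price is a heap allocation per node and a move-only `Exp`; `boost::variant` offers `boost::recursive_wrapper` for the same purpose.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <utility>
#include <variant>

struct Constant { int v; };
// Forward declarations: the recursive alternatives are held through
// std::unique_ptr, so they may still be incomplete here.
struct Negate;
struct Add;
struct Multiply;

using Exp = std::variant<Constant,
                         std::unique_ptr<Negate>,
                         std::unique_ptr<Add>,
                         std::unique_ptr<Multiply>>;

struct Negate   { Exp e; };
struct Add      { Exp e1, e2; };
struct Multiply { Exp e1, e2; };

int max_constant(const Exp& e);

// One overload per alternative, mirroring the ML case expression.
struct MaxConstantVisitor {
  int operator()(const Constant& c) const { return c.v; }
  int operator()(const std::unique_ptr<Negate>& n) const {
    return max_constant(n->e);
  }
  int operator()(const std::unique_ptr<Add>& a) const {
    return std::max(max_constant(a->e1), max_constant(a->e2));
  }
  int operator()(const std::unique_ptr<Multiply>& m) const {
    return std::max(max_constant(m->e1), max_constant(m->e2));
  }
};

int max_constant(const Exp& e) { return std::visit(MaxConstantVisitor{}, e); }

// Hypothetical convenience constructors.
Exp constant(int v) { return Constant{v}; }
Exp negate(Exp e) { return std::make_unique<Negate>(Negate{std::move(e)}); }
Exp add(Exp a, Exp b) {
  return std::make_unique<Add>(Add{std::move(a), std::move(b)});
}
Exp multiply(Exp a, Exp b) {
  return std::make_unique<Multiply>(Multiply{std::move(a), std::move(b)});
}
```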