ComPtr源码分析

ComPtr源码分析

ComPtr是微软提供的用来管理COM组件的智能指针。DirectX的API是由一系列的COM组件来管理的,形如ID3D12DeviceIDXGISwapChain等的接口类最终都继承自IUnknown接口类,这个接口类包含AddRefRelease两个方法,分别用来增加和减少内部的引用计数。当引用计数为0时,内存才会真正释放。在实际使用中,我们肯定是不希望,对这类指针,都去人工地调用这两个方法来维护引用计数。这样做的心智负担太大,如果忘记释放某个接口类指针,就会造成内存泄漏;而如果忘记增加引用计数,则可能会在错误的时机提前释放了内存,导致运行时错误。

ComPtr就是为了解决上述问题而存在的。我们来看下面这个例子:

#include 
#include 

using Microsoft::WRL::ComPtr;
using namespace std;

class A
{
public:
    unsigned long ref = 0;
    void AddRef()
    {
        ref++;
        cout << "incr ref, cur ref " << ref << endl;
    }
    unsigned long Release()
    {
        ref--;
        cout << "decr ref, cur ref " << ref << endl;
        if(ref == 0)
        {
            cout << "release!" << endl;
        }
        return ref;
    }
};

int main()
{
    A* p = new A;
    ComPtr p1 = p;
    ComPtr p2 = p;
    return 0;
}

这里我们实现了一个类A,它包含AddRefRelease两个方法,分别用来增减引用计数ref。当ref为0时,打印一个release的log。例子的运行结果如下:

ComPtr源码分析1

可以看到,使用ComPtr之后,我们无需对A类指针p进行计数管理,ComPtr会帮我们维护好p的引用计数。当p1和p2离开作用域时,会对p的引用计数减一,当为0时触发真正的release,这里就是打印一句log。

在了解了ComPtr的基本用途之后,我们来欣赏一下ComPtr的源码。它的实现位于Windows SDK的client.h文件中。ComPtr类的数据成员只有一个原始指针,因此不会有额外的空间开销。它首先对原始指针的AddRefRelease方法进行了封装,后面的方法调用都围绕着这两个封装方法展开:

template 
class ComPtr
{
public:
    typedef T InterfaceType;

protected:
    InterfaceType *ptr_;
    template friend class ComPtr;

    void InternalAddRef() const throw()
    {
        if (ptr_ != nullptr)
        {
            ptr_->AddRef();
        }
    }

    unsigned long InternalRelease() throw()
    {
        unsigned long ref = 0;
        T* temp = ptr_;

        if (temp != nullptr)
        {
            ptr_ = nullptr;
            ref = temp->Release();
        }

        return ref;
    }
}

可以看到InternalRelease函数会将持有的原始指针置为空,并调用原始指针的Release函数返回当前的引用计数。

ComPtr提供了若干类型的构造函数:

ComPtr() throw() : ptr_(nullptr)
{
}

ComPtr(decltype(__nullptr)) throw() : ptr_(nullptr)
{
}

template
ComPtr(_In_opt_ U *other) throw() : ptr_(other)
{
    InternalAddRef();
}

ComPtr(const ComPtr& other) throw() : ptr_(other.ptr_)
{
    InternalAddRef();
}

// copy constructor that allows to instantiate class when U* is convertible to T*
template
ComPtr(const ComPtr &other, typename Details::EnableIf::value, void *>::type * = 0) throw() :
    ptr_(other.ptr_)
{
    InternalAddRef();
}

ComPtr(_Inout_ ComPtr &&other) throw() : ptr_(nullptr)
{
    if (this != reinterpret_cast(&reinterpret_cast(other)))
    {
        Swap(other);
    }
}

// Move constructor that allows instantiation of a class when U* is convertible to T*
template
ComPtr(_Inout_ ComPtr&& other, typename Details::EnableIf::value, void *>::type * = 0) throw() :
    ptr_(other.ptr_)
{
    other.ptr_ = nullptr;
}

如果构造函数传入的参数中包含原始指针,那么这里会调用InternalAddRef来增加原始指针的引用计数。如果传入的参数为类型U的ComPtr,那么还需要判断U类型的指针是否能成功转换为T类型指针,如果不能编译期就要报错,要做到这个就需要借助模板的力量:

typename Details::EnableIf::value, void *>::type * = 0

如果U*可以转换到T*,那么IsConvertible的value成员值为true,进而EnableIf的type类型就可以推导为void *,编译可以正常通过;反之则type类型将不存在,那么编译就会报错,通过这个手段就可以在编译期把问题抛出来。比如以下代码:

#include 
#include 

using Microsoft::WRL::ComPtr;
using namespace std;

class A
{
public:
    unsigned long ref = 0;
    void AddRef()
    {
        ref++;
        cout << "incr ref, cur ref " << ref << endl;
    }
    unsigned long Release()
    {
        ref--;
        cout << "decr ref, cur ref " << ref << endl;
        if(ref == 0)
        {
            cout << "release!" << endl;
        }
        return ref;
    }
};

class B
{
public:
    unsigned long ref = 0;
    void AddRef()
    {
        ref++;
        cout << "incr ref, cur ref " << ref << endl;
    }
    unsigned long Release()
    {
        ref--;
        cout << "decr ref, cur ref " << ref << endl;
        if(ref == 0)
        {
            cout << "release!" << endl;
        }
        return ref;
    }
};

int main()
{
    A* p = new A;
    ComPtr p1 = p;
    ComPtr p2 = p1;
    return 0;
}

由于A和B类型没啥关系,所以指针也是不能互相转换的,那么编译期就会报错:

ComPtr源码分析2

那什么样的A和B类型指针可以互相转换呢?看下面这个例子:

#include 
#include 

using Microsoft::WRL::ComPtr;
using namespace std;

class A
{
public:
    unsigned long ref = 0;
    void AddRef()
    {
        ref++;
        cout << "incr ref, cur ref " << ref << endl;
    }
    unsigned long Release()
    {
        ref--;
        cout << "decr ref, cur ref " << ref << endl;
        if(ref == 0)
        {
            cout << "release!" << endl;
        }
        return ref;
    }
};

class B : public A
{
};

int main()
{
    B* p = new B;
    ComPtr p1 = p;
    ComPtr p2 = p1;
    return 0;
}

这里B类型继承A类型,那么B类型的指针就可以安全地转换为A类型的指针,编译就能顺利通过了。

对于参数为右值引用的构造函数,根据语义,需要把传入ComPtr的原始指针进行转移。既然只是转移,就不需要对原始指针的引用计数进行增减。注意到构造函数实现里有一句:

this != reinterpret_cast(&reinterpret_cast(other))

这句代码的作用其实就是判断传入的other对象是否就是当前的this对象。不直接使用取地址操作符来判断this != &other的原因是因为ComPtr重载了取地址操作符,只能转而使用这种很trick的手段。

与构造函数类似,ComPtr也提供了与之对应的赋值操作符重载的函数。内部实现基本上都是新创建一个对象,然后与当前的this对象进行交换,这里就不展开了。

我们刚刚说过,ComPtr提供了取地址操作符的重载函数,但它又提供了一个名为GetAddressOf的函数,那么它们的区别是什么呢?关于这一点,MSDN上特别做了说明:

This method differs from ComPtr::GetAddressOf in that this method releases a reference to the interface pointer. Use ComPtr::GetAddressOf when you require the address of the interface pointer but don’t want to release that interface.

也就是说,调用取地址操作符时会触发一次Release操作,而GetAddressOf是不会的。我们从源码上也能看出端倪:

Details::ComPtrRef> operator&() throw()
{
    return Details::ComPtrRef>(this);
}

const Details::ComPtrRef> operator&() const throw()
{
    return Details::ComPtrRef>(this);
}

而ComPtrRef类中有个类型转换函数:

operator InterfaceType**() throw()
{
    return this->ptr_->ReleaseAndGetAddressOf();
}

当转换为原始类型的二级指针时,会触发ComPtr的ReleaseAndGetAddressOf函数,这个函数的定义如下:

T** ReleaseAndGetAddressOf() throw()
{
    InternalRelease();
    return &ptr_;
}

它和GetAddressOf函数的实现就多了一句Release:

T** GetAddressOf() throw()
{
    return &ptr_;
}

那么区别就非常明显了。最后我们写个例子来验证一下:

#include 
#include 

using Microsoft::WRL::ComPtr;
using namespace std;

class A
{
public:
    unsigned long ref = 0;
    void AddRef()
    {
        ref++;
    }
    unsigned long Release()
    {
        ref--;
        return ref;
    }
};

void f(A** pp)
{

}

int main()
{
    A* p = new A;
    ComPtr p1 = p;
    ComPtr p2 = p;
    f(&p1);
    f(p2.GetAddressOf());

    cout << boolalpha;
    cout << "p1 nullptr " << ( p1.Get() == nullptr ) << endl;
    cout << "p2 nullptr " << ( p2.Get() == nullptr ) << endl;
    return 0;
}

运行结果如下:

ComPtr源码分析3

如果你觉得我的文章有帮助,欢迎关注我的微信公众号 我是真的想做游戏啊

Reference

[1] ComPtr Class

[2] DirectX11–ComPtr智能指针

你可能感兴趣的:(c++,directx,游戏引擎)