自定义异步可插入协议 (代码记录)

参考:

http://www.codeguru.com/cpp/com-tech/atl/misc/article.php/c37/Asynchronous-Pluggable-Protocol-Implementation-with-ATL.htm

http://blog.csdn.net/cumtzly/article/details/40072613

    // http_protocol.cc  
      
    #include "trident/glue/protocol_impl/http_protocol.h"  
    #include "base/logging.h"  
    #include <WinInet.h>  
    #include <ExDisp.h>  
      
    namespace trident {  
      
    // Constructs the protocol object.
    //
    // pOuterUnknown: controlling IUnknown when the object is created as the
    // inner part of a COM aggregate, or NULL for a standalone object. It is
    // stored as a raw pointer without AddRef (an aggregated inner object must
    // not hold a counted reference on its outer object).
    //
    // The reference count starts at 0; the creator takes the first reference.
    // Every member is initialized here, in declaration order.
    HttpProtocol::HttpProtocol(IUnknown* pOuterUnknown)
        : reference_count_(0),
          outer_unknown_(pOuterUnknown),
          inner_unknown_(NULL),
          grf_BindF_(0),
          has_lock_request_(false) {
      // INonDelegatingUnknown has the same slot layout as IUnknown
      // (QueryInterface/AddRef/Release), so its subobject can serve as this
      // object's inner IUnknown.
      inner_unknown_ = reinterpret_cast<IUnknown*>(static_cast<INonDelegatingUnknown*>(this));
      ZeroMemory(&bind_info_, sizeof(BINDINFO));
      // GetBindInfo requires the caller to fill in cbSize.
      bind_info_.cbSize = sizeof(BINDINFO);
    }
      
    // Frees what IInternetBindInfo::GetBindInfo stored in bind_info_ (strings,
    // the STGMEDIUM holding request data, and the pUnk reference). Those
    // resources belong to the receiver of the BINDINFO and must be released
    // with ReleaseBindInfo. If Start never ran, bind_info_ is still the
    // zero-filled structure set up by the constructor and nothing is freed.
    HttpProtocol::~HttpProtocol() {
      ::ReleaseBindInfo(&bind_info_);
    }
      
    // INonDelegatingUnknown  
    // Returns an interface implemented by this object itself, without
    // delegating to the outer unknown. Only protocol interfaces are exposed;
    // sink interfaces are not.
    //
    // Returns:
    //   E_INVALIDARG  ppvObject is NULL.
    //   S_OK          *ppvObject receives an AddRef'ed pointer.
    //   E_NOINTERFACE *ppvObject is NULL and the reference count is untouched.
    //
    // The reference is taken through the returned pointer, as COM aggregation
    // requires:
    //   - For IID_IUnknown this is the non-delegating (inner) count.
    //   - For every other interface, AddRef forwards to the outer unknown when
    //     the object is aggregated.
    //
    // A failed query never changes the count. A freshly created object at
    // count 0 is therefore not destroyed by it.
    STDMETHODIMP HttpProtocol::NonDelegatingQueryInterface(REFIID riid, void** ppvObject) {
      if (ppvObject == NULL) {
        return E_INVALIDARG;
      }

      *ppvObject = NULL;
      if (riid == IID_IUnknown) {
        *ppvObject = static_cast<INonDelegatingUnknown*>(this);
      } else if (riid == IID_IInternetProtocolRoot) {
        *ppvObject = static_cast<IInternetProtocolRoot*>(this);
      } else if (riid == IID_IInternetProtocol) {
        *ppvObject = static_cast<IInternetProtocol*>(this);
      } else if (riid == IID_IInternetProtocolEx) {
        *ppvObject = static_cast<IInternetProtocolEx*>(this);
      } else if (riid == IID_IInternetProtocolInfo) {
        *ppvObject = static_cast<IInternetProtocolInfo*>(this);
      }

      if (*ppvObject == NULL) {
        return E_NOINTERFACE;
      }

      // Each returned pointer begins with QueryInterface/AddRef/Release vtable
      // slots. For the INonDelegatingUnknown subobject, these are the
      // NonDelegating* methods, so calling AddRef through IUnknown is valid
      // for all of them.
      static_cast<IUnknown*>(*ppvObject)->AddRef();
      return S_OK;
    }
      
    // Atomically bumps this object's own (inner) reference count and reports
    // the incremented value.
    STDMETHODIMP_(ULONG) HttpProtocol::NonDelegatingAddRef() {
      const LONG incremented = ::InterlockedIncrement(&reference_count_);
      return static_cast<ULONG>(incremented);
    }
      
    // Atomically decrements the inner reference count and deletes the object
    // when it reaches zero. Returns the decremented count.
    //
    // The value returned by InterlockedDecrement is the one used, because:
    //   - re-reading reference_count_ could race with a concurrent AddRef or
    //     Release;
    //   - after `delete this`, the member no longer exists.
    STDMETHODIMP_(ULONG) HttpProtocol::NonDelegatingRelease() {
      const LONG remaining = ::InterlockedDecrement(&reference_count_);
      if (remaining == 0) {
        delete this;
      }
      return static_cast<ULONG>(remaining);
    }
      
    // IUnknown  
    // Forwards to the controlling unknown: the outer object when aggregated,
    // otherwise this object's own non-delegating implementation.
    STDMETHODIMP HttpProtocol::QueryInterface(REFIID riid, void** ppvObject) {
      IUnknown* const controlling = (outer_unknown_ != NULL) ? outer_unknown_ : inner_unknown_;
      return controlling->QueryInterface(riid, ppvObject);
    }
      
    // Forwards AddRef to the outer unknown when aggregated, else to the inner
    // (non-delegating) unknown.
    STDMETHODIMP_(ULONG) HttpProtocol::AddRef() {
      IUnknown* const controlling = (outer_unknown_ != NULL) ? outer_unknown_ : inner_unknown_;
      return controlling->AddRef();
    }
      
    // Forwards Release to the outer unknown when aggregated, else to the inner
    // (non-delegating) unknown.
    STDMETHODIMP_(ULONG) HttpProtocol::Release() {
      IUnknown* const controlling = (outer_unknown_ != NULL) ? outer_unknown_ : inner_unknown_;
      return controlling->Release();
    }
      
    // IInternetProtocolRoot. Windows XP SP2 and earlier enter the protocol
    // through this method; StartEx forwards here on later versions.
    //
    // Records the URL, the protocol sink and the bind-info provider. It then
    // looks up an IServiceProvider (from the sink first, then from the bind
    // info) and fetches the BINDINFO.
    //
    // Returns:
    //   E_INVALIDARG  url, protocol_sink or bind_info is NULL.
    //   <failure>     the HRESULT of IInternetBindInfo::GetBindInfo, if it fails.
    //   S_OK          otherwise.
    //
    // NOTE(review): no data is produced yet. Nothing here calls
    // ReportData/ReportResult on the sink, so the bind does not complete until
    // that is implemented.
    STDMETHODIMP HttpProtocol::Start(LPCWSTR url, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
      if (bind_info == NULL || protocol_sink == NULL || url == NULL)
        return E_INVALIDARG;

      bind_url_ = GURL(url);

      spSink_ = protocol_sink;
      spBindinfo_ = bind_info;

      // Drop a provider left over from an earlier Start. CComPtr::operator&
      // expects an empty pointer and would otherwise leak the old interface.
      spServiceProvider_.Release();
      spSink_->QueryInterface(IID_IServiceProvider, (void**)&spServiceProvider_);
      if (!spServiceProvider_)
        spBindinfo_->QueryInterface(IID_IServiceProvider, (void**)&spServiceProvider_);
      DCHECK(spServiceProvider_);

      // BINDINFO
      // http://msdn.microsoft.com/en-us/library/ie/aa767897(v=vs.85).aspx
      // http://msdn.microsoft.com/en-us/library/ie/aa741006(v=vs.85).aspx#Handling_BINDINFO_St
      // Free anything a previous GetBindInfo stored, then reset the structure.
      // GetBindInfo requires cbSize to be set by the caller.
      ::ReleaseBindInfo(&bind_info_);
      ZeroMemory(&bind_info_, sizeof(BINDINFO));
      bind_info_.cbSize = sizeof(BINDINFO);
      grf_BindF_ = 0;

      HRESULT result = spBindinfo_->GetBindInfo(&grf_BindF_, &bind_info_);
      DCHECK(result == S_OK);
      if (FAILED(result))
        return result;

      // Fall back to the system ANSI code page when the caller supplied none.
      if (!bind_info_.dwCodePage)
        bind_info_.dwCodePage = ::GetACP();

      // TODO: report progress and data through the sink, e.g.
      //   spSink_->ReportProgress(BINDSTATUS_FINDINGRESOURCE, ...);
      //   spSink_->ReportProgress(BINDSTATUS_CONNECTING, ...);
      //   spSink_->ReportProgress(BINDSTATUS_SENDINGREQUEST, ...);
      //   spSink_->ReportProgress(BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, <mime type>);
      //   spSink_->ReportData(BSCF_FIRSTDATANOTIFICATION, 0, <length>);
      //   spSink_->ReportData(BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE, <length>, <length>);
      return S_OK;
    }
      
    // Accepts and ignores posted protocol data. Nothing in this file hands work
    // to IInternetProtocolSink::Switch, so there is no PROTOCOLDATA to process.
    STDMETHODIMP HttpProtocol::Continue(PROTOCOLDATA* /*pProtocolData*/) {
      return S_OK;
    }

    // No-op that always succeeds. Under IE6/IE8, debug assertions showed that
    // Abort can still be called after Terminate, so it must be safe to call in
    // any state.
    STDMETHODIMP HttpProtocol::Abort(HRESULT /*reason*/, DWORD /*options*/) {
      return S_OK;
    }
      
    // Releases the interface references captured in Start.
    //
    // Per the Terminate contract, the handler drops its IInternetProtocolSink
    // and IInternetBindInfo pointers here. The sink typically also references
    // this protocol object, so holding them past Terminate can form a
    // reference cycle.
    //
    // CComPtr::Release on an empty pointer does nothing, so repeated calls
    // (and a later Abort) stay harmless.
    STDMETHODIMP HttpProtocol::Terminate(DWORD options) {
      spServiceProvider_.Release();
      spBindinfo_.Release();
      spSink_.Release();
      return S_OK;
    }
      
    // Suspending a download is not supported by this handler.
    STDMETHODIMP HttpProtocol::Suspend() { return E_NOTIMPL; }

    // Resuming a download is not supported by this handler.
    STDMETHODIMP HttpProtocol::Resume() { return E_NOTIMPL; }
      
    // Copies up to `size` bytes of response data into `pv` and stores the byte
    // count in *pcbRead.
    //
    // No data source is implemented yet, so the stream is always empty:
    // *pcbRead (when provided) is set to 0 and S_FALSE ("no more data") is
    // returned. Returning S_OK without writing *pcbRead would leave the caller
    // with an uninitialized count.
    STDMETHODIMP HttpProtocol::Read(void* pv, ULONG size, ULONG* pcbRead) {
      if (pcbRead != NULL)
        *pcbRead = 0;
      return S_FALSE;
    }
      
    // Seeking is not supported. E_NOTIMPL is returned, matching Suspend and
    // Resume, instead of S_OK with *new_position left unwritten.
    STDMETHODIMP HttpProtocol::Seek(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_position) {
      return E_NOTIMPL;
    }
      
    // Notes that the client asked for the request to stay locked; only the
    // flag is recorded.
    STDMETHODIMP HttpProtocol::LockRequest(DWORD /*options*/) {
      has_lock_request_ = true;
      return S_OK;
    }

    // Clears the flag set by LockRequest.
    STDMETHODIMP HttpProtocol::UnlockRequest() {
      has_lock_request_ = false;
      return S_OK;
    }
      
    // IInternetProtocolEx. Windows XP SP3 and later enter the protocol through
    // this method. It extracts the absolute URI string from `uri` and forwards
    // to Start.
    //
    // `uri` is an in-parameter owned by the caller, so it is NOT released here.
    //
    // Returns:
    //   E_INVALIDARG  uri is NULL.
    //   <failure>     the HRESULT of IUri::GetAbsoluteUri, if it fails.
    //   otherwise     whatever Start returns.
    STDMETHODIMP HttpProtocol::StartEx(IUri* uri, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
      if (uri == NULL) {
        return E_INVALIDARG;
      }

      BSTR uri_URL = NULL;
      HRESULT result = uri->GetAbsoluteUri(&uri_URL);
      if (FAILED(result)) {
        return result;
      }

      std::wstring url;
      if (uri_URL != NULL) {
        url = uri_URL;
        ::SysFreeString(uri_URL);
      }

      return Start(url.c_str(), protocol_sink, bind_info, flags, reserved);
    }
      
    // IInternetProtocolInfo: none of these operations is customised by this
    // handler. Each returns INET_E_DEFAULT_ACTION so urlmon applies its own
    // default behaviour.
    //
    // Returning S_OK would claim success without writing the output buffers
    // (ParseUrl, CombineUrl, QueryInfo). For CompareUrl, S_OK would report every
    // pair of URLs as equal.
    STDMETHODIMP  HttpProtocol::ParseUrl(LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult,
      DWORD cchResult,  DWORD *pcchResult, DWORD dwReserved) {
      return INET_E_DEFAULT_ACTION;
    }

    STDMETHODIMP  HttpProtocol::CombineUrl( LPCWSTR pwzBaseUrl,  LPCWSTR pwzRelativeUrl,  DWORD dwCombineFlags, LPWSTR pwzResult,
      DWORD cchResult,DWORD *pcchResult,DWORD dwReserved)  {
      return INET_E_DEFAULT_ACTION;
    }

    STDMETHODIMP  HttpProtocol::CompareUrl( LPCWSTR pwzUrl1,LPCWSTR pwzUrl2,DWORD dwCompareFlags) {
      return INET_E_DEFAULT_ACTION;
    }

    STDMETHODIMP  HttpProtocol::QueryInfo(LPCWSTR pwzUrl,  QUERYOPTION OueryOption, DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer,
      DWORD *pcbBuf, DWORD dwReserved)  {
      return INET_E_DEFAULT_ACTION;
    }
      
    // Returns the verb selected by bind_info_.dwBindVerb: "GET", "POST", "PUT",
    // or bind_info_.szCustomVerb for BINDVERB_CUSTOM.
    //
    // Returns an empty string (and fires the DCHECK in debug builds) when:
    //   - the verb value is not one of these, or
    //   - BINDVERB_CUSTOM carries no custom verb.
    // This avoids building a std::wstring from a NULL pointer, which is
    // undefined behaviour.
    std::wstring HttpProtocol::GetVerbStr() const {
      const wchar_t* verb = NULL;
      switch (bind_info_.dwBindVerb) {
        case BINDVERB_GET:
          verb = L"GET";
          break;
        case BINDVERB_POST:
          verb = L"POST";
          break;
        case BINDVERB_PUT:
          verb = L"PUT";
          break;
        case BINDVERB_CUSTOM:
          verb = bind_info_.szCustomVerb;
          break;
        default:
          break;
      }
      DCHECK(verb);
      return (verb != NULL) ? std::wstring(verb) : std::wstring();
    }
      
      
    bool HttpProtocol::GetDataToSend(char** lplpData, DWORD* pdwSize) const {  
      
      if(bind_info_.dwBindVerb == BINDVERB_GET)  
        return false;  
      
      if (bind_info_.stgmedData.tymed == TYMED_HGLOBAL) {  
        if(lplpData)  
          *lplpData = (char*)bind_info_.stgmedData.hGlobal;  
        if(pdwSize)  
          *pdwSize = bind_info_.cbstgmedData;  
        return true;  
      } else {  
        return false;  
      }  
    }  
      
    }  

    // http_protocol.h
      
    #ifndef TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_  
    #define TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_  
      
    // 实现参考 Win2K 源码  
    // private\inet\urlmon\iapp\cnet.cxx  
    // private\inet\urlmon\iapp\cnethttp.cxx  
      
    #include <atlbase.h>  
    #include <urlmon.h>  
    #include <vector>  
    #include "base/basictypes.h"  
    #include "url/gurl.h"  
      
    namespace trident {  
      
    // Helper interface for COM aggregation: the non-delegating IUnknown.
    // Reference: http://msdn.microsoft.com/en-us/library/windows/desktop/dd390339(v=vs.85).aspx
    // The method order (QueryInterface, AddRef, Release) mirrors IUnknown's
    // vtable layout. HttpProtocol reinterprets a pointer to this interface as
    // IUnknown*, so do not reorder these methods or add any before them.
    struct INonDelegatingUnknown {
      STDMETHOD(NonDelegatingQueryInterface)(REFIID riid, void** ppvObject) = 0;
      STDMETHOD_(ULONG, NonDelegatingAddRef)() = 0;
      STDMETHOD_(ULONG, NonDelegatingRelease)() = 0;
    };
      
    // Asynchronous pluggable protocol handler.
    //
    // It can be created standalone or as the inner object of a COM aggregate.
    // Its own reference count is managed by the INonDelegatingUnknown methods.
    // The public IUnknown methods forward to the outer unknown when one was
    // supplied, and to the non-delegating implementation otherwise.
    class HttpProtocol : public INonDelegatingUnknown,
                         public IInternetProtocolEx,
                         public IInternetProtocolInfo {
    public:
      // pOuterUnknown: controlling unknown when aggregated, or NULL. Stored
      // without AddRef.
      explicit HttpProtocol(IUnknown* pOuterUnknown);
      virtual ~HttpProtocol();

    public:

      // INonDelegatingUnknown
      // Only the protocol interfaces can be queried; sink interfaces cannot.
      STDMETHOD(NonDelegatingQueryInterface)(REFIID riid, void** ppvObject);
      STDMETHOD_(ULONG, NonDelegatingAddRef)();
      STDMETHOD_(ULONG, NonDelegatingRelease)();

      // IUnknown
      STDMETHOD(QueryInterface)(REFIID riid, void** ppvObject);
      STDMETHOD_(ULONG, AddRef)();
      STDMETHOD_(ULONG, Release)();

      // IInternetProtocolRoot
      STDMETHOD(Start)(LPCWSTR url, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved);
      STDMETHOD(Continue)(PROTOCOLDATA* pProtocolData);
      STDMETHOD(Abort)(HRESULT reason, DWORD options);
      STDMETHOD(Terminate)(DWORD options);
      STDMETHOD(Suspend)();
      STDMETHOD(Resume)();

      // IInternetProtocol : public IInternetProtocolRoot
      STDMETHOD(Read)(void* pv, ULONG size, ULONG* pcbRead);
      STDMETHOD(Seek)(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_position);
      STDMETHOD(LockRequest)(DWORD options);
      STDMETHOD(UnlockRequest)();

      // IInternetProtocolEx : public IInternetProtocol
      STDMETHOD(StartEx)(IUri* uri, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved);

      // IInternetProtocolInfo
      STDMETHOD(ParseUrl)(LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult,  DWORD cchResult,  DWORD *pcchResult, DWORD dwReserved) ;
      STDMETHOD(CombineUrl)( LPCWSTR pwzBaseUrl,  LPCWSTR pwzRelativeUrl,  DWORD dwCombineFlags, LPWSTR pwzResult,DWORD cchResult,DWORD *pcchResult,DWORD dwReserved) ;
      STDMETHOD(CompareUrl)( LPCWSTR pwzUrl1,LPCWSTR pwzUrl2,DWORD dwCompareFlags) ;
      STDMETHOD(QueryInfo)(LPCWSTR pwzUrl,  QUERYOPTION OueryOption, DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer, DWORD *pcbBuf, DWORD dwReserved);

    private:
      // Verb string selected by bind_info_.dwBindVerb.
      std::wstring GetVerbStr() const ;
      // Request body held in bind_info_ for non-GET verbs, if any.
      bool  GetDataToSend(char** lplpData, DWORD* pdwSize) const ;

    private:

      // Inner (non-delegating) reference count, updated with Interlocked* calls.
      volatile LONG reference_count_;

      // Controlling unknown when aggregated, otherwise NULL. Not owned.
      IUnknown* outer_unknown_;

      // This object's INonDelegatingUnknown subobject viewed as IUnknown.
      IUnknown* inner_unknown_;

      // Protocol sink passed to Start.
      CComPtr<IInternetProtocolSink> spSink_;

      // Bind-info provider passed to Start.
      CComPtr<IInternetBindInfo> spBindinfo_;

      // Service provider obtained from spSink_, or failing that spBindinfo_.
      CComPtr<IServiceProvider> spServiceProvider_;

      // Filled by IInternetBindInfo::GetBindInfo in Start.
      BINDINFO bind_info_;

      // Bind flags returned together with bind_info_.
      DWORD grf_BindF_;

      // Set by LockRequest, cleared by UnlockRequest.
      bool has_lock_request_;

      // URL the protocol was started with.
      GURL bind_url_;

      DISALLOW_COPY_AND_ASSIGN(HttpProtocol);
    };
      
    }  //namespace trident  
      
    #endif  // TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_  

    // http_protocol_factory.cc
      
    #include "trident/glue/protocol_impl/http_protocol_factory.h"  
    #include "base/logging.h"  
    #include "trident/glue/protocol_impl/http_protocol.h"  
      
    namespace trident {  
      
    // Looks up the class factory of the system's built-in http or https
    // protocol (chosen by is_https_protocol) and keeps it in origin_factory_.
    // The factory begins life with a single reference held by its creator.
    HttpProtocolFactory::HttpProtocolFactory(bool is_https_protocol) : reference_count_(1) {
      const CLSID& builtin_clsid = is_https_protocol ? CLSID_HttpSProtocol : CLSID_HttpProtocol;
      HRESULT result = ::CoGetClassObject(builtin_clsid, CLSCTX_INPROC_SERVER, NULL,
                                          IID_IClassFactory, (void**)&origin_factory_);
      DCHECK(result == S_OK);
      DCHECK(origin_factory_ != NULL);
    }

    // origin_factory_ is released by its CComPtr destructor.
    HttpProtocolFactory::~HttpProtocolFactory() {
    }
      
    // IUnknown  
    // Hands out IUnknown or IClassFactory (the same subobject, since the
    // factory derives only from IClassFactory), AddRef'ed. Any other IID
    // yields E_NOINTERFACE with *ppvObject set to NULL.
    STDMETHODIMP HttpProtocolFactory::QueryInterface(REFIID riid, void** ppvObject) {
      if (ppvObject == NULL)
        return E_INVALIDARG;

      IUnknown* match = NULL;
      if (riid == IID_IUnknown || riid == IID_IClassFactory)
        match = static_cast<IClassFactory*>(this);

      *ppvObject = match;
      if (match == NULL)
        return E_NOINTERFACE;

      match->AddRef();
      return S_OK;
    }
      
    // Atomically increments the factory's reference count; returns the new value.
    STDMETHODIMP_(ULONG) HttpProtocolFactory::AddRef() {
      const ULONG raised = ::InterlockedIncrement(&reference_count_);
      return raised;
    }

    // Atomically decrements the count. The factory deletes itself when the
    // last reference goes away and then reports 0.
    STDMETHODIMP_(ULONG) HttpProtocolFactory::Release() {
      const ULONG remaining = ::InterlockedDecrement(&reference_count_);
      if (remaining != 0)
        return remaining;

      delete this;
      return 0;
    }
      
    // IClassFactory  
    // Creates a new HttpProtocol and returns the requested interface in
    // *ppvObject.
    //
    // Aggregation is supported, but when pUnkOuter is non-NULL only
    // IID_IUnknown may be requested, as COM requires.
    //
    // Returns:
    //   E_POINTER              ppvObject is NULL.
    //   CLASS_E_NOAGGREGATION  pUnkOuter is set and riid is not IID_IUnknown.
    //   E_OUTOFMEMORY          allocation yields NULL.
    //   otherwise              the HRESULT of NonDelegatingQueryInterface.
    // On failure, *ppvObject is NULL.
    //
    // A temporary reference is held around the query:
    //   - If the query fails, the final NonDelegatingRelease destroys the
    //     object exactly once.
    //   - If it succeeds, the object survives with the reference the query
    //     handed out.
    STDMETHODIMP HttpProtocolFactory::CreateInstance(IUnknown* pUnkOuter, REFIID riid, void** ppvObject) {
      if (ppvObject == NULL) {
        return E_POINTER;
      }
      *ppvObject = NULL;

      if (pUnkOuter && riid != IID_IUnknown) {
        return CLASS_E_NOAGGREGATION;
      }

      HttpProtocol* http_protocol = new HttpProtocol(pUnkOuter);
      if (http_protocol == NULL) {
        return E_OUTOFMEMORY;
      }

      http_protocol->NonDelegatingAddRef();
      HRESULT result = http_protocol->NonDelegatingQueryInterface(riid, ppvObject);
      http_protocol->NonDelegatingRelease();
      return result;
    }
      
    // Locking takes an extra reference on the factory; unlocking drops it.
    // Always reports success.
    STDMETHODIMP HttpProtocolFactory::LockServer(BOOL fLock) {
      if (!fLock) {
        Release();
      } else {
        AddRef();
      }
      return S_OK;
    }
      
    }  

    // http_protocol_factory.h
      
    #ifndef TRIDENT_HTTP_PROTOCOL_FACTORY_H_  
    #define TRIDENT_HTTP_PROTOCOL_FACTORY_H_  
      
    #include <atlbase.h>  
    #include <Unknwn.h>  
    #include "base/basictypes.h"  
      
    namespace trident {  
      
    class HttpProtocolFactory : public IClassFactory {  
    public:  
      
      explicit HttpProtocolFactory(bool is_https_protocol);  
      
      // IUnknown  
      STDMETHOD(QueryInterface)(REFIID riid, void** ppvObject);  
      STDMETHOD_(ULONG, AddRef)();  
      STDMETHOD_(ULONG, Release)();  
      
      // IClassFactory  
      STDMETHOD(CreateInstance)(IUnknown* pUnkOuter, REFIID riid, void** ppvObject);  
      STDMETHOD(LockServer)(BOOL fLock);  
      
    private:  
      virtual ~HttpProtocolFactory();  
      
      volatile ULONG reference_count_;  
      CComPtr<IClassFactory> origin_factory_;  
      DISALLOW_IMPLICIT_CONSTRUCTORS(HttpProtocolFactory);  
    };  
      
    }  
      
    #endif  // TRIDENT_HTTP_PROTOCOL_FACTORY_H_  


你可能感兴趣的:(自定义异步可插入协议 (代码记录))