参考:
http://www.codeguru.com/cpp/com-tech/atl/misc/article.php/c37/Asynchronous-Pluggable-Protocol-Implementation-with-ATL.htm
http://blog.csdn.net/cumtzly/article/details/40072613
// http_protocol.cc #include "trident/glue/protocol_impl/http_protocol.h" #include "base/logging.h" #include <WinInet.h> #include <ExDisp.h> namespace trident { HttpProtocol::HttpProtocol(IUnknown* pOuterUnknown) : reference_count_(0), outer_unknown_(pOuterUnknown), grf_BindF_(0), inner_unknown_(NULL) { inner_unknown_ = reinterpret_cast<IUnknown*>((INonDelegatingUnknown*)(this)); ZeroMemory(&bind_info_, sizeof(BINDINFO)); bind_info_.cbSize = sizeof(BINDINFO); } HttpProtocol::~HttpProtocol() { } // INonDelegatingUnknown STDMETHODIMP HttpProtocol::NonDelegatingQueryInterface(REFIID riid, void** ppvObject) { if(ppvObject == NULL){ return E_INVALIDARG; } HRESULT result = E_NOINTERFACE; *ppvObject = NULL; NonDelegatingAddRef(); if (riid == IID_IUnknown) { *ppvObject = static_cast<INonDelegatingUnknown*>(this); }else if(riid == IID_IInternetProtocolRoot) { *ppvObject = static_cast<IInternetProtocolRoot*>(this); } else if (riid == IID_IInternetProtocol) { *ppvObject = static_cast<IInternetProtocol*>(this); } else if (riid == IID_IInternetProtocolEx) { *ppvObject = static_cast<IInternetProtocolEx*>(this); } else if (riid == IID_IInternetProtocolInfo) { *ppvObject = static_cast<IInternetProtocolInfo*>(this); } if(*ppvObject) result = S_OK; else NonDelegatingRelease(); return result; } STDMETHODIMP_(ULONG) HttpProtocol::NonDelegatingAddRef() { return (ULONG)::InterlockedIncrement(&reference_count_); } STDMETHODIMP_(ULONG) HttpProtocol::NonDelegatingRelease() { ::InterlockedDecrement(&reference_count_); if (reference_count_ == 0) { delete this; } return reference_count_; } // IUnknown STDMETHODIMP HttpProtocol::QueryInterface(REFIID riid, void** ppvObject) { if (outer_unknown_) { return outer_unknown_->QueryInterface(riid, ppvObject); } else { return inner_unknown_->QueryInterface(riid, ppvObject); } } STDMETHODIMP_(ULONG) HttpProtocol::AddRef() { if (outer_unknown_) { return outer_unknown_->AddRef(); } else { return inner_unknown_->AddRef(); } } STDMETHODIMP_(ULONG) HttpProtocol::Release() { if (outer_unknown_) { return outer_unknown_->Release(); } else { return inner_unknown_->Release(); } } // IInternetProtocolRoot , XP SP2及以下版本走这个接口 STDMETHODIMP HttpProtocol::Start(LPCWSTR url, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { if(bind_info == NULL || protocol_sink == NULL || url == NULL) return E_INVALIDARG; bind_url_ = GURL(url); spSink_ = protocol_sink; spBindinfo_ = bind_info; spSink_->QueryInterface(IID_IServiceProvider, (void**)&spServiceProvider_); if(!spServiceProvider_) spBindinfo_->QueryInterface(IID_IServiceProvider, (void**)&spServiceProvider_); DCHECK(spServiceProvider_); // BINDINFO //http://msdn.microsoft.com/en-us/library/ie/aa767897(v=vs.85).aspx //http://msdn.microsoft.com/en-us/library/ie/aa741006(v=vs.85).aspx#Handling_BINDINFO_St HRESULT result = spBindinfo_->GetBindInfo(&grf_BindF_, &bind_info_); DCHECK(result == S_OK); if( !bind_info_.dwCodePage ) bind_info_.dwCodePage = ::GetACP(); /*bind_info_->ReportProgress(BINDSTATUS_FINDINGRESOURCE, strData); bind_info_->ReportProgress(BINDSTATUS_CONNECTING, strData); bind_info_->ReportProgress(BINDSTATUS_SENDINGREQUEST, strData); bind_info_->ReportProgress(BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, CAtlString(m_url.GetMimeType())); bind_info_->ReportData(BSCF_FIRSTDATANOTIFICATION, 0, bind_url_.GetDataLength()); bind_info_->ReportData(BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE, m_url.GetDataLength(), m_url.GetDataLength());*/ return S_OK; return S_OK; } STDMETHODIMP HttpProtocol::Continue(PROTOCOLDATA* pProtocolData) { return S_OK; } // IE6/IE8下有断言,发现调用Terminate后还会调用Abort STDMETHODIMP HttpProtocol::Abort(HRESULT reason, DWORD options) { return S_OK; } STDMETHODIMP HttpProtocol::Terminate(DWORD options) { return S_OK; } STDMETHODIMP HttpProtocol::Suspend() { return E_NOTIMPL; } STDMETHODIMP HttpProtocol::Resume() { return E_NOTIMPL; } STDMETHODIMP HttpProtocol::Read(void* pv, ULONG size, ULONG* pcbRead) { return S_OK; } STDMETHODIMP HttpProtocol::Seek(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_position) { return S_OK; } STDMETHODIMP HttpProtocol::LockRequest(DWORD options) { has_lock_request_ = true; return S_OK; } STDMETHODIMP HttpProtocol::UnlockRequest() { has_lock_request_ = false; return S_OK; } // XP SP3及以上版本走这个接口 STDMETHODIMP HttpProtocol::StartEx(IUri* uri, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { if(uri == NULL) { return E_INVALIDARG; } BSTR uri_URL = NULL; std::wstring url; uri->GetAbsoluteUri(&uri_URL); if (uri_URL != NULL) { url = uri_URL; ::SysFreeString(uri_URL); } uri->Release(); return Start(url.c_str(), protocol_sink, bind_info, flags, reserved); } STDMETHODIMP HttpProtocol::ParseUrl(LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult, DWORD cchResult, DWORD *pcchResult, DWORD dwReserved) { return S_OK; } STDMETHODIMP HttpProtocol::CombineUrl( LPCWSTR pwzBaseUrl, LPCWSTR pwzRelativeUrl, DWORD dwCombineFlags, LPWSTR pwzResult, DWORD cchResult,DWORD *pcchResult,DWORD dwReserved) { return S_OK; } STDMETHODIMP HttpProtocol::CompareUrl( LPCWSTR pwzUrl1,LPCWSTR pwzUrl2,DWORD dwCompareFlags) { return S_OK; } STDMETHODIMP HttpProtocol::QueryInfo(LPCWSTR pwzUrl, QUERYOPTION OueryOption, DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer, DWORD *pcbBuf, DWORD dwReserved) { return S_OK; } std::wstring HttpProtocol::GetVerbStr() const { wchar_t* pszRes = NULL; switch (bind_info_.dwBindVerb) { case BINDVERB_GET : pszRes = L"GET"; break; case BINDVERB_POST : pszRes = L"POST"; break; case BINDVERB_PUT : pszRes = L"PUT"; break; case BINDVERB_CUSTOM : pszRes = bind_info_.szCustomVerb; break; } DCHECK(pszRes); return pszRes; } bool HttpProtocol::GetDataToSend(char** lplpData, DWORD* pdwSize) const { if(bind_info_.dwBindVerb == BINDVERB_GET) return false; if (bind_info_.stgmedData.tymed == TYMED_HGLOBAL) { if(lplpData) *lplpData = (char*)bind_info_.stgmedData.hGlobal; if(pdwSize) *pdwSize = bind_info_.cbstgmedData; return true; } else { return false; } } }
<pre name="code" class="cpp">// http_protocol.h #ifndef TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_ #define TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_ // 实现参考 Win2K 源码 // private\inet\urlmon\iapp\cnet.cxx // private\inet\urlmon\iapp\cnethttp.cxx #include <atlbase.h> #include <urlmon.h> #include <vector> #include "base/basictypes.h" #include "url/gurl.h" namespace trident { // COM组件聚合帮助接口 // 参考:http://msdn.microsoft.com/en-us/library/windows/desktop/dd390339(v=vs.85).aspx struct INonDelegatingUnknown { STDMETHOD(NonDelegatingQueryInterface)(REFIID riid, void** ppvObject) = 0; STDMETHOD_(ULONG, NonDelegatingAddRef)() = 0; STDMETHOD_(ULONG, NonDelegatingRelease)() = 0; }; class HttpProtocol : public INonDelegatingUnknown, public IInternetProtocolEx, public IInternetProtocolInfo{ public: HttpProtocol(IUnknown* pOuterUnknown); virtual ~HttpProtocol(); public: // INonDelegatingUnknown // 只提供Protocol接口查询,不提供Sink接口查询 STDMETHOD(NonDelegatingQueryInterface)(REFIID riid, void** ppvObject); STDMETHOD_(ULONG, NonDelegatingAddRef)(); STDMETHOD_(ULONG, NonDelegatingRelease)(); // IUnknown STDMETHOD(QueryInterface)(REFIID riid, void** ppvObject); STDMETHOD_(ULONG, AddRef)(); STDMETHOD_(ULONG, Release)(); // IInternetProtocolRoot STDMETHOD(Start)(LPCWSTR url, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved); STDMETHOD(Continue)(PROTOCOLDATA* pProtocolData); STDMETHOD(Abort)(HRESULT reason, DWORD options); STDMETHOD(Terminate)(DWORD options); STDMETHOD(Suspend)(); STDMETHOD(Resume)(); // IInternetProtocol : public IInternetProtocolRoot STDMETHOD(Read)(void* pv, ULONG size, ULONG* pcbRead); STDMETHOD(Seek)(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_position); STDMETHOD(LockRequest)(DWORD options); STDMETHOD(UnlockRequest)(); // IInternetProtocolEx : public IInternetProtocol STDMETHOD(StartEx)(IUri* uri, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved); // IInternetProtocolInfo STDMETHOD(ParseUrl)(LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult, DWORD cchResult, DWORD *pcchResult, DWORD dwReserved) ; STDMETHOD(CombineUrl)( LPCWSTR pwzBaseUrl, LPCWSTR pwzRelativeUrl, DWORD dwCombineFlags, LPWSTR pwzResult,DWORD cchResult,DWORD *pcchResult,DWORD dwReserved) ; STDMETHOD(CompareUrl)( LPCWSTR pwzUrl1,LPCWSTR pwzUrl2,DWORD dwCompareFlags) ; STDMETHOD(QueryInfo)(LPCWSTR pwzUrl, QUERYOPTION OueryOption, DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer, DWORD *pcbBuf, DWORD dwReserved); private: std::wstring GetVerbStr() const ; bool GetDataToSend(char** lplpData, DWORD* pdwSize) const ; private: volatile LONG reference_count_; IUnknown* outer_unknown_; IUnknown* inner_unknown_; CComPtr<IInternetProtocolSink> spSink_; CComPtr<IInternetBindInfo> spBindinfo_; CComPtr<IServiceProvider> spServiceProvider_; BINDINFO bind_info_; DWORD grf_BindF_; bool has_lock_request_; GURL bind_url_; DISALLOW_COPY_AND_ASSIGN(HttpProtocol); }; } //namespace trident #endif // TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_
<pre name="code" class="cpp">// http protocol factory.cc #include "trident/glue/protocol_impl/http_protocol_factory.h" #include "base/logging.h" #include "trident/glue/protocol_impl/http_protocol.h" namespace trident { HttpProtocolFactory::HttpProtocolFactory(bool is_https_protocol) : reference_count_(1) { HRESULT result = S_OK; if (is_https_protocol) { result = ::CoGetClassObject(CLSID_HttpSProtocol, CLSCTX_INPROC_SERVER, NULL, IID_IClassFactory, (void**)&origin_factory_); } else { result = ::CoGetClassObject(CLSID_HttpProtocol, CLSCTX_INPROC_SERVER, NULL, IID_IClassFactory, (void**)&origin_factory_); } DCHECK(result == S_OK); DCHECK(origin_factory_ != NULL); } HttpProtocolFactory::~HttpProtocolFactory() { } // IUnknown STDMETHODIMP HttpProtocolFactory::QueryInterface(REFIID riid, void** ppvObject) { if (!ppvObject) { return E_INVALIDARG; } *ppvObject = NULL; HRESULT result = E_NOINTERFACE; if (riid == IID_IUnknown) { *ppvObject = static_cast<IUnknown*>(this); } else if (riid == IID_IClassFactory) { *ppvObject = static_cast<IClassFactory*>(this); } if (*ppvObject) { static_cast<IUnknown*>(*ppvObject)->AddRef(); result = S_OK; } return result; } STDMETHODIMP_(ULONG) HttpProtocolFactory::AddRef() { return ::InterlockedIncrement(&reference_count_); } STDMETHODIMP_(ULONG) HttpProtocolFactory::Release() { ULONG count = ::InterlockedDecrement(&reference_count_); if (count == 0) { delete this; return 0; } return count; } // IClassFactory STDMETHODIMP HttpProtocolFactory::CreateInstance(IUnknown* pUnkOuter, REFIID riid, void** ppvObject) { if (pUnkOuter && riid != IID_IUnknown) { return CLASS_E_NOAGGREGATION; } HttpProtocol* http_protocol = new HttpProtocol(pUnkOuter); if(http_protocol->NonDelegatingQueryInterface(riid, ppvObject) != S_OK) { delete http_protocol; *ppvObject = NULL; return E_NOINTERFACE; }else { return S_OK; } } STDMETHODIMP HttpProtocolFactory::LockServer(BOOL fLock) { if(fLock) AddRef(); else Release(); return S_OK; } }
// http protocol factroy .h #ifndef TRIDENT_HTTP_PROTOCOL_FACTORY_H_ #define TRIDENT_HTTP_PROTOCOL_FACTORY_H_ #include <atlbase.h> #include <Unknwn.h> #include "base/basictypes.h" namespace trident { class HttpProtocolFactory : public IClassFactory { public: explicit HttpProtocolFactory(bool is_https_protocol); // IUnknown STDMETHOD(QueryInterface)(REFIID riid, void** ppvObject); STDMETHOD_(ULONG, AddRef)(); STDMETHOD_(ULONG, Release)(); // IClassFactory STDMETHOD(CreateInstance)(IUnknown* pUnkOuter, REFIID riid, void** ppvObject); STDMETHOD(LockServer)(BOOL fLock); private: virtual ~HttpProtocolFactory(); volatile ULONG reference_count_; CComPtr<IClassFactory> origin_factory_; DISALLOW_IMPLICIT_CONSTRUCTORS(HttpProtocolFactory); }; } #endif // TRIDENT_HTTP_PROTOCOL_FACTORY_H_