These macros are defined in pytorch/aten/src/TH/THGeneral.h.in, which implements C-style generic programming by splicing type names into function names through preprocessor substitution:
#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y)
#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y
#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z)
#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z
#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w)
#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w
#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y)
#define TH_CONCAT_2_EXPAND(x,y) x ## y
#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z)
#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z
#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w
#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w)
#define THMin(X, Y) ((X) < (Y) ? (X) : (Y))
#define THMax(X, Y) ((X) > (Y) ? (X) : (Y))
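These concatenation macros are what let the TH library stamp out one copy of each function per scalar type. Below is a minimal sketch of the pattern that compiles on its own; the TH_CONCAT_4 definitions are copied from above, while the Real and THStorage_(NAME) definitions follow the convention used in the TH headers but are reproduced from memory rather than quoted from the tree:

#include <stdio.h>

#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w)
#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w

/* With Real defined as Float, THStorage_(size) expands to THFloatStorage_size.
   TH's generate-type headers redefine Real once per scalar type and re-include
   the generic implementation, producing one function per type. */
#define Real Float
#define THStorage_(NAME) TH_CONCAT_4(TH, Real, Storage_, NAME)

/* Expands to: long THFloatStorage_size(long n) { return n; } */
long THStorage_(size)(long n) { return n; }

int main(void) {
  /* The generated, type-specific name can be called directly. */
  printf("%ld\n", THFloatStorage_size(8));
  return 0;
}

The two-level macro (TH_CONCAT_4 forwarding to TH_CONCAT_4_EXPAND) is what forces Real to be expanded to Float before the ## pasting happens.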
In PyTorch 0.4.0, THStorage was reworked into the Storage class, defined in aten/src/ATen/Storage.h:
struct AT_API Storage {
 public:
  Storage() = delete;
  Storage(StorageImpl* storage_impl) : storage_impl_(storage_impl) {}
  Storage(
      at::ScalarType,
      size_t size,
      Allocator* allocator,
      bool resizable = false);
  Storage(
      at::ScalarType,
      at::DataPtr,
      size_t size,
      const std::function<void(void*)>& deleter,
      bool resizable = false);
  ~Storage();
  // There are reasonable interpretations of these constructors, but they're to
  // be implemented on demand.
  Storage(Storage&) = delete;
  Storage(const Storage&) = delete;
  Storage(Storage&&) = delete;
  Storage(const Storage&&) = delete;
  void set_pImpl(StorageImpl* storage_impl) {
    storage_impl_ = storage_impl;
  }
  StorageImpl* pImpl() {
    return storage_impl_;
  }
  StorageImpl* pImpl() const {
    return storage_impl_;
  }
  StorageImpl* retained_pImpl() const {
    storage_impl_->retain();
    return storage_impl_;
  }
 protected:
  StorageImpl* storage_impl_;
};
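Storage itself is only a thin handle: it stores a raw StorageImpl* and hands it out through pImpl(), while retained_pImpl() first bumps the intrusive reference count that StorageImpl inherits from Retainable. Here is a minimal, self-contained sketch of this handle-plus-intrusive-refcount pattern; the Retainable, Impl and Handle classes are illustrative stand-ins, not the ATen types:

#include <atomic>
#include <cstdio>

// Illustrative stand-in for ATen's Retainable: an intrusive reference count.
struct Retainable {
  std::atomic<int> refcount{1};
  void retain() { ++refcount; }
  void release() {
    if (--refcount == 0) delete this;
  }
  virtual ~Retainable() = default;
};

struct Impl : Retainable {
  ~Impl() override { std::puts("Impl destroyed"); }
};

// Illustrative stand-in for Storage: a thin handle around the impl pointer.
struct Handle {
  explicit Handle(Impl* impl) : impl_(impl) {}
  Impl* pImpl() const { return impl_; }   // borrow, no refcount change
  Impl* retained_pImpl() const {          // hand out an owning reference
    impl_->retain();
    return impl_;
  }
  Impl* impl_;
};

int main() {
  Impl* impl = new Impl();                // refcount == 1
  Handle handle(impl);
  Impl* owned = handle.retained_pImpl();  // refcount == 2
  owned->release();                       // back to 1
  impl->release();                        // 0 -> destroyed
}

The design point is that retained_pImpl() gives the caller an owning reference that must be paired with a release, whereas pImpl() only borrows.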
StorageImpl is defined in aten/src/ATen/StorageImpl.h:
struct AT_API StorageImpl : public Retainable {
 public:
  StorageImpl() = delete;
  virtual ~StorageImpl() {};
  StorageImpl(
      at::ScalarType scalar_type,
      ptrdiff_t size,
      at::DataPtr data_ptr,
      at::Allocator* allocator,
      bool resizable);
  StorageImpl(
      at::ScalarType scalar_type,
      ptrdiff_t size,
      at::Allocator* allocator,
      bool resizable);
  StorageImpl(StorageImpl&) = delete;
  StorageImpl(const StorageImpl&) = delete;
  // NB: Don't move ref count!
  StorageImpl(StorageImpl&& other) = delete;
  StorageImpl(const StorageImpl&&) = delete;
  StorageImpl& operator=(StorageImpl&& other) = delete;
  // TODO: Rename this into th_data, and move it out of the class;
  // the real data shouldn't call th::from_type
  template <typename T>
  inline T* data() const {
    auto scalar_type_T = at::CTypeToScalarType<th::from_type<T>>::to();
    if (scalar_type_ != scalar_type_T) {
      AT_ERROR(
          "Attempt to access StorageImpl having data type ",
          at::toString(scalar_type_),
          " as data type ",
          at::toString(scalar_type_T));
    }
    return unsafe_data<T>();
  }
  template <typename T>
  inline T* unsafe_data() const {
    return static_cast<T*>(this->data_ptr_.get());
  }
  void release_resources() {
    if (finalizer_) {
      (*finalizer_)();
    }
    finalizer_ = nullptr;
    data_ptr_.clear();
  }
  void operator=(const StorageImpl&) = delete;
  virtual size_t elementSize() const {
    return at::elementSize(scalar_type_);
  }
  Type& type();
  // TODO: Rename to size() and size to size_
  ptrdiff_t size() const {
    return size_;
  };
  void set_size(ptrdiff_t size) {
    size_ = size;
  };
  bool resizable() const {
    return resizable_;
  };
  at::DataPtr& data_ptr() {
    return data_ptr_;
  };
  void set_data_ptr(at::DataPtr&& data_ptr) {
    data_ptr_ = std::move(data_ptr);
  };
  void* data() {
    return data_ptr_.get();
  };
  const void* data() const {
    return data_ptr_.get();
  };
  at::Allocator* allocator() {
    return allocator_;
  };
  at::ScalarType& scalar_type() {
    return scalar_type_;
  };
  const at::Allocator* allocator() const {
    return allocator_;
  };
  int getDevice() const {
    return data_ptr_.device().index();
  }
  void set_resizable(bool resizable) {
    resizable_ = resizable;
  }
 private:
  at::ScalarType scalar_type_;
  at::DataPtr data_ptr_;
  ptrdiff_t size_;
  bool resizable_;
 public:
  at::Allocator* allocator_;
  std::unique_ptr<THFinalizer> finalizer_;
};
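The templated data<T>() accessor is the important safety hook here: the storage records its element type at runtime (scalar_type_), and typed access is only granted when the compile-time T maps back to that same ScalarType, otherwise AT_ERROR is raised; unsafe_data<T>() skips the check. The following minimal, self-contained sketch shows the same check; ScalarType, scalar_type_of and MiniStorage are illustrative stand-ins for at::ScalarType, at::CTypeToScalarType/th::from_type and StorageImpl:

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Illustrative scalar-type tag, standing in for at::ScalarType.
enum class ScalarType { Float, Long };

// Illustrative trait mapping a C++ type to its ScalarType tag.
template <typename T> struct scalar_type_of;
template <> struct scalar_type_of<float>   { static constexpr ScalarType value = ScalarType::Float; };
template <> struct scalar_type_of<int64_t> { static constexpr ScalarType value = ScalarType::Long; };

class MiniStorage {
 public:
  MiniStorage(ScalarType t, std::size_t nbytes) : scalar_type_(t), bytes_(nbytes) {}

  // Typed access: refuse to reinterpret the buffer as the wrong element type.
  template <typename T>
  T* data() {
    if (scalar_type_of<T>::value != scalar_type_) {
      throw std::runtime_error("attempt to access storage with the wrong data type");
    }
    return unsafe_data<T>();
  }

  // Unchecked access, mirroring unsafe_data<T>().
  template <typename T>
  T* unsafe_data() { return reinterpret_cast<T*>(bytes_.data()); }

 private:
  ScalarType scalar_type_;
  std::vector<char> bytes_;
};

int main() {
  MiniStorage s(ScalarType::Float, 4 * sizeof(float));
  float* ok = s.data<float>();   // allowed: tags match
  (void)ok;
  try {
    s.data<int64_t>();           // throws: stored type is Float
  } catch (const std::runtime_error&) {
  }
}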
This is implemented by overriding __call__, which also runs the registered hook functions, including those added with register_backward_hook() and register_forward_pre_hook(). To keep the hooks ordered, they are stored in an OrderedDict from Python's collections module.
The key piece is Variable._execution_engine.run_backward(), implemented in the underlying torch._C extension, in pytorch/torch/csrc/autograd/python_engine.cpp:
// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  _maybe_reinitialize_engine_after_fork();
  PyObject *tensors = nullptr;
  PyObject *grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject *inputs = nullptr;
  unsigned char allow_unreachable = 0;
  const char *accepted_kwargs[] = {
      "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
      "allow_unreachable", nullptr
  };
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
    return nullptr;

  THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
      "be a tuple, but got %s", THPUtils_typename(tensors));
  THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
      "expected to be a tuple, but got %s", THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld "
      "gradients", num_tensors, num_gradients);

  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (int i = 0; i < num_tensors; i++) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
    THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
        "tuple is not a Tensor", i);
    auto& variable = ((THPVariable*)_tensor)->cdata;
    auto gradient_edge = variable.gradient_edge();
    THPUtils_assert(gradient_edge.function,
        "element %d of tensors does not require grad and does not have a grad_fn", i);
    roots.push_back(std::move(gradient_edge));

    PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      grads.push_back(((THPVariable*)grad)->cdata);
    } else {
      THPUtils_assert(grad == Py_None,
          "element %d of gradients tuple is not a Tensor or None", i);
      THPUtils_assert(!variable.requires_grad(),
          "element %d of gradients tuple is None, but the corresponding Tensor requires grad");
    }
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      THPUtils_assert(THPVariable_Check(input),
          "all inputs have to be Tensors, but got %s", THPUtils_typename(input));
      THPVariable *input_var = (THPVariable*)input;
      const auto output_nr = input_var->cdata.output_nr();
      auto grad_fn = input_var->cdata.grad_fn();
      if (!grad_fn) {
        grad_fn = input_var->cdata.try_get_grad_accumulator();
      }
      THPUtils_assert(input_var->cdata.requires_grad(),
          "One of the differentiated Tensors does not require grad");
      if (!grad_fn) {
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }

  variable_list outputs;
  {
    AutoNoGIL no_gil;
    outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
  }

  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
    if (!py_outputs) return nullptr;
    for (int i = 0; i < num_inputs; i++) {
      THPUtils_assert(allow_unreachable || outputs[i].defined(), "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}
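For context on how this C++ function becomes callable as Variable._execution_engine.run_backward(): it is listed in the engine type's PyMethodDef table with METH_VARARGS | METH_KEYWORDS, which is why it receives both args and kwargs and parses them with PyArg_ParseTupleAndKeywords. Below is a minimal, self-contained sketch of that CPython registration pattern; the demo module and function names are illustrative, not the actual torch._C setup:

#include <Python.h>

// A keyword-accepting C function, in the same style as THPEngine_run_backward.
static PyObject* demo_run(PyObject* self, PyObject* args, PyObject* kwargs) {
  PyObject* tensors = nullptr;
  unsigned char keep_graph = 0;
  static const char* accepted_kwargs[] = {"tensors", "keep_graph", nullptr};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|b", (char**)accepted_kwargs,
        &tensors, &keep_graph))
    return nullptr;
  Py_RETURN_NONE;
}

// METH_VARARGS | METH_KEYWORDS makes CPython pass (self, args, kwargs).
static PyMethodDef demo_methods[] = {
  {"run", (PyCFunction)demo_run, METH_VARARGS | METH_KEYWORDS, "demo entry point"},
  {nullptr, nullptr, 0, nullptr}
};

static struct PyModuleDef demo_module = {
  PyModuleDef_HEAD_INIT, "demo", nullptr, -1, demo_methods
};

PyMODINIT_FUNC PyInit_demo(void) {
  return PyModule_Create(&demo_module);
}

Built as an extension module, import demo; demo.run((), keep_graph=True) would route into demo_run, in the same way that a Python-level backward() call ultimately lands in THPEngine_run_backward.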