PyTorch Source Code Reading Notes

THArgCheck

The THArgCheck macro is defined in pytorch/aten/src/TH/THGeneral.h.in. It expands to a call to _THArgCheck, forwarding __FILE__ and __LINE__ along with a condition, the index of the offending argument, and a printf-style message, so error reports point at the call site.
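
A minimal, self-contained sketch of the pattern (my_arg_check and MY_ARG_CHECK are illustrative stand-ins, not PyTorch APIs; the real _THArgCheck raises a TH error instead of exiting):

#include <cstdio>
#include <cstdlib>

// Illustrative stand-in for _THArgCheck (the real one lives in
// THGeneral.cpp and raises a TH error rather than exiting).
static void my_arg_check(const char* file, int line, int condition,
                         int arg_number, const char* msg)
{
  if (!condition) {
    std::fprintf(stderr, "%s:%d: bad argument #%d: %s\n",
                 file, line, arg_number, msg);
    std::exit(1);
  }
}

// Like THArgCheck, capture the call site via __FILE__/__LINE__.
#define MY_ARG_CHECK(cond, argn, msg) \
  my_arg_check(__FILE__, __LINE__, (cond), (argn), (msg))

static void resize(long* size_out, long new_size)
{
  MY_ARG_CHECK(new_size >= 0, 2, "size must be non-negative");  // blame argument 2
  *size_out = new_size;
}

int main()
{
  long s;
  resize(&s, -1);  // prints "bad argument #2" and exits
  return 0;
}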

Generic programming via macro substitution

In pytorch/aten/src/TH/THGeneral.h.in, C-style generic programming is implemented by macro-substituting function names. Note the two-level *_EXPAND pattern: the extra indirection forces the arguments to be fully macro-expanded before they are stringized (#) or token-pasted (##).

#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y)
#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y

#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z)
#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z

#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w)
#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w

#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y)
#define TH_CONCAT_2_EXPAND(x,y) x ## y

#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z)
#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z

#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w
#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w)

#define THMin(X, Y)  ((X) < (Y) ? (X) : (Y))
#define THMax(X, Y)  ((X) > (Y) ? (X) : (Y))
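
These macros are what the generated TH headers use to build per-type function names. For example, THTensor.h defines THTensor_(NAME) as TH_CONCAT_4(TH,Real,Tensor_,NAME), so with Real expanding to Float, THTensor_(add) becomes THFloatTensor_add. A self-contained demo of that expansion (Real is fixed to Float by hand here; in TH it is set by the generic-include machinery):

#include <cstdio>

#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w)
#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w

// Mirrors THTensor.h's THTensor_(NAME) macro.
#define Real Float
#define THTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME)

// The _EXPAND indirection lets `Real` expand to `Float` before ## pastes
// the tokens; pasting directly would produce THRealTensor_add instead.
void THFloatTensor_add() { std::printf("THFloatTensor_add called\n"); }

int main()
{
  THTensor_(add)();  // expands to THFloatTensor_add()
  return 0;
}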

THStorage implementation

As of PyTorch 0.4.0, THStorage is instead defined as the Storage class in aten/src/ATen/Storage.h:

struct AT_API Storage {
public:
  Storage() = delete;
  Storage(StorageImpl* storage_impl) : storage_impl_(storage_impl) {}
  Storage(
      at::ScalarType,
      size_t size,
      Allocator* allocator,
      bool resizable = false);
  Storage(
      at::ScalarType,
      at::DataPtr,
      size_t size,
      const std::function<void(void*)>& deleter,
      bool resizable = false);
  ~Storage();
  // There are reasonable interpretations of these constructors, but they're to
  // be implemented on demand.
  Storage(Storage&) = delete;
  Storage(const Storage&) = delete;
  Storage(Storage&&) = delete;
  Storage(const Storage&&) = delete;
  void set_pImpl(StorageImpl* storage_impl) {
    storage_impl_ = storage_impl;
  }
  StorageImpl* pImpl() {
    return storage_impl_;
  }
  StorageImpl* pImpl() const {
    return storage_impl_;
  }
  StorageImpl* retained_pImpl() const {
    storage_impl_->retain();
    return storage_impl_;
  }

 protected:
  StorageImpl* storage_impl_;
};
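
Storage is thus a thin handle over a reference-counted StorageImpl, a classic pimpl split with copy and move deleted so ownership stays explicit. Below is a minimal sketch of the idea with illustrative names (Retainable here is a hand-rolled stand-in for ATen's intrusive refcount base, and release is called manually for clarity):

#include <atomic>
#include <cstddef>
#include <cstdio>

// Illustrative stand-in for ATen's intrusive-refcount base class.
struct Retainable {
  std::atomic<int> refcount{1};
  void retain() { ++refcount; }
  void release() { if (--refcount == 0) delete this; }
  virtual ~Retainable() = default;
};

// Stand-in impl: owns the actual buffer.
struct MyStorageImpl : Retainable {
  float* data;
  explicit MyStorageImpl(std::size_t n) : data(new float[n]) {}
  ~MyStorageImpl() override { delete[] data; }
};

// The handle only holds a raw impl pointer; like at::Storage in 0.4.0,
// copying and moving are deleted.
struct MyStorage {
  explicit MyStorage(MyStorageImpl* impl) : impl_(impl) {}
  MyStorage(const MyStorage&) = delete;
  MyStorage(MyStorage&&) = delete;
  MyStorageImpl* pImpl() const { return impl_; }
  // Mirrors Storage::retained_pImpl(): hand out the pointer plus one reference.
  MyStorageImpl* retained_pImpl() const { impl_->retain(); return impl_; }
 private:
  MyStorageImpl* impl_;
};

int main() {
  MyStorage s(new MyStorageImpl(4));
  s.pImpl()->data[0] = 1.0f;
  std::printf("%f\n", s.pImpl()->data[0]);
  s.pImpl()->release();  // drop the initial reference; the impl deletes itself
  return 0;
}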

The underlying StorageImpl is defined in aten/src/ATen/StorageImpl.h:

struct AT_API StorageImpl : public Retainable {
 public:
  StorageImpl() = delete;
  virtual ~StorageImpl() {};
  StorageImpl(
      at::ScalarType scalar_type,
      ptrdiff_t size,
      at::DataPtr data_ptr,
      at::Allocator* allocator,
      bool resizable);
  StorageImpl(
      at::ScalarType scalar_type,
      ptrdiff_t size,
      at::Allocator* allocator,
      bool resizable);
  StorageImpl(StorageImpl&) = delete;
  StorageImpl(const StorageImpl&) = delete;
  // NB: Don't move ref count!
  StorageImpl(StorageImpl&& other) = delete;
  StorageImpl(const StorageImpl&&) = delete;
  StorageImpl& operator=(StorageImpl&& other) = delete;

  // TODO: Rename this into th_data, and move it out of the class;
  // the real data shouldn't call th::from_type
  template <typename T>
  inline T* data() const {
    auto scalar_type_T = at::CTypeToScalarType<T>::to();
    if (scalar_type_ != scalar_type_T) {
      AT_ERROR(
          "Attempt to access StorageImpl having data type ",
          at::toString(scalar_type_),
          " as data type ",
          at::toString(scalar_type_T));
    }
    return unsafe_data();
  }

  template <typename T>
  inline T* unsafe_data() const {
    return static_cast<T*>(this->data_ptr_.get());
  }

  void release_resources() {
    if (finalizer_) {
      (*finalizer_)();
    }
    finalizer_ = nullptr;
    data_ptr_.clear();
  }

  void operator=(const StorageImpl&) = delete;

  virtual size_t elementSize() const {
    return at::elementSize(scalar_type_);
  }

  Type& type();

  // TODO: Rename to size() and size to size_
  ptrdiff_t size() const {
    return size_;
  };
  void set_size(ptrdiff_t size) {
    size_ = size;
  };
  bool resizable() const {
    return resizable_;
  };
  at::DataPtr& data_ptr() {
    return data_ptr_;
  };
  void set_data_ptr(at::DataPtr&& data_ptr) {
    data_ptr_ = std::move(data_ptr);
  };
  void* data() {
    return data_ptr_.get();
  };
  const void* data() const {
    return data_ptr_.get();
  };
  at::Allocator* allocator() {
    return allocator_;
  };
  at::ScalarType& scalar_type() {
    return scalar_type_;
  };
  const at::Allocator* allocator() const {
    return allocator_;
  };
  int getDevice() const {
    return data_ptr_.device().index();
  }
  void set_resizable(bool resizable) {
    resizable_ = resizable;
  }

 private:
  at::ScalarType scalar_type_;
  at::DataPtr data_ptr_;
  ptrdiff_t size_;
  bool resizable_;

 public:
  at::Allocator* allocator_;
  std::unique_ptr<THFinalizer> finalizer_;
};
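
The templated data<T>() accessor is worth a closer look: the storage keeps a runtime ScalarType tag, and at::CTypeToScalarType<T> maps the compile-time T back to a tag, so a mismatch is caught before the raw pointer is cast. A standalone sketch of the same check (Buffer, the two-entry ScalarType enum, and abort-instead-of-AT_ERROR are all simplifications):

#include <cstdio>
#include <cstdlib>

// Illustrative runtime type tag (at::ScalarType has many more entries).
enum class ScalarType { Float, Long };

// Maps a C++ type back to its tag, like at::CTypeToScalarType<T>.
template <typename T> struct CTypeToScalarType;
template <> struct CTypeToScalarType<float> {
  static ScalarType to() { return ScalarType::Float; }
};
template <> struct CTypeToScalarType<long> {
  static ScalarType to() { return ScalarType::Long; }
};

// Stand-in for StorageImpl: a void* buffer plus its runtime type tag.
struct Buffer {
  ScalarType scalar_type;
  void* raw;

  template <typename T>
  T* data() const {
    // Same check as StorageImpl::data<T>(); AT_ERROR throws, we just abort.
    if (CTypeToScalarType<T>::to() != scalar_type) {
      std::fprintf(stderr, "attempt to access buffer as the wrong data type\n");
      std::abort();
    }
    return static_cast<T*>(raw);
  }
};

int main() {
  float values[2] = {1.0f, 2.0f};
  Buffer b{ScalarType::Float, values};
  std::printf("%f\n", b.data<float>()[0]);  // OK: tags match
  // b.data<long>();  // would abort: Float buffer accessed as Long
  return 0;
}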

Module forward()

Calling a module goes through the overloaded __call__, which, besides running forward(), also invokes the registered hooks, including those added with register_backward_hook() and register_forward_pre_hook(). To keep hook execution ordered, the hooks are stored in an OrderedDict from Python's collections module.

Tensor.backward()

The key is the Variable._execution_engine.run_backward() function exposed through torch._C; it is implemented in pytorch/torch/csrc/autograd/python_engine.cpp:

// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  _maybe_reinitialize_engine_after_fork();
  PyObject *tensors = nullptr;
  PyObject *grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject *inputs = nullptr;
  unsigned char allow_unreachable = 0;
  const char *accepted_kwargs[] = {
      "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
      "allow_unreachable", nullptr
  };
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
    return nullptr;

  THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
      "be a tuple, but got %s", THPUtils_typename(tensors));
  THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
      "expected to be a tuple, but got %s", THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld "
      "gradients", num_tensors, num_gradients);

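  // Build one root edge per output tensor: (grad_fn, output_nr) pairs that
  // seed the engine, paired with the incoming gradient for each output.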
  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (int i = 0; i < num_tensors; i++) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
    THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
        "tuple is not a Tensor", i);
    auto& variable = ((THPVariable*)_tensor)->cdata;
    auto gradient_edge = variable.gradient_edge();
    THPUtils_assert(gradient_edge.function,
        "element %d of tensors does not require grad and does not have a grad_fn", i);
    roots.push_back(std::move(gradient_edge));

    PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      grads.push_back(((THPVariable*)grad)->cdata);
    } else {
      THPUtils_assert(grad == Py_None,
          "element %d of gradients tuple is not a Tensor or None", i);
      THPUtils_assert(!variable.requires_grad(),
          "element %d of gradients tuple is None, but the corresponding Tensor requires grad", i);
    }
  }

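  // If backward() was called with explicit `inputs`, record an output edge
  // for each of them so the engine returns those gradients instead of
  // accumulating them into .grad.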
  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      THPUtils_assert(THPVariable_Check(input),
          "all inputs have to be Tensors, but got %s", THPUtils_typename(input));
      THPVariable *input_var = (THPVariable*)input;
      const auto output_nr = input_var->cdata.output_nr();
      auto grad_fn = input_var->cdata.grad_fn();
      if (!grad_fn) {
          grad_fn = input_var->cdata.try_get_grad_accumulator();
      }
      THPUtils_assert(input_var->cdata.requires_grad(),
          "One of the differentiated Tensors does not require grad");
      if (!grad_fn) {
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }

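  // Run the engine without the GIL: the backward pass is pure C++ and may
  // execute functions on worker threads.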
  variable_list outputs;
  {
    AutoNoGIL no_gil;
    outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
  }

  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
    if (!py_outputs) return nullptr;
    for (int i = 0; i < num_inputs; i++) {
      THPUtils_assert(allow_unreachable || outputs[i].defined(), "One of the "
                      "differentiated Tensors appears to not have been used "
                      "in the graph. Set allow_unused=True if this is the "
                      "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}
