Skip to content

Commit

Permalink
Refactor autograd package to separate Python dependencies. (pytorch#662)
Browse files Browse the repository at this point in the history
The core autograd Variable, Function, and Engine no longer depend on the
Python API. This lets us implement functions in C++. In the future, we
can also multithread the engine and release the GIL for most of the
non-Python backwards.
  • Loading branch information
colesbury committed Feb 14, 2017
1 parent 16d2c3d commit bd53030
Show file tree
Hide file tree
Showing 44 changed files with 2,970 additions and 1,767 deletions.
13 changes: 11 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,23 @@ def run(self):
"torch/csrc/Exceptions.cpp",
"torch/csrc/Tensor.cpp",
"torch/csrc/Storage.cpp",
"torch/csrc/DynamicTypes.cpp",
"torch/csrc/byte_order.cpp",
"torch/csrc/utils.cpp",
"torch/csrc/utils/object_ptr.cpp",
"torch/csrc/allocators.cpp",
"torch/csrc/serialization.cpp",
"torch/csrc/autograd/init.cpp",
"torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/function.cpp",
"torch/csrc/autograd/engine.cpp",
"torch/csrc/autograd/function.cpp",
"torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/grad_buffer.cpp",
"torch/csrc/autograd/python_function.cpp",
"torch/csrc/autograd/python_cpp_function.cpp",
"torch/csrc/autograd/python_variable.cpp",
"torch/csrc/autograd/python_engine.cpp",
"torch/csrc/autograd/functions/batch_normalization.cpp",
"torch/csrc/autograd/functions/init.cpp",
"torch/csrc/nn/THNN_generic.cpp",
]

Expand Down
1 change: 1 addition & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def bw_hook(inc, grad):
counter[0] += inc

z = x ** 2 + x * 2 + x * y + y
x.register_hook(lambda *args: bw_hook(0, *args))
test = z.register_hook(lambda *args: bw_hook(1, *args))
z.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(counter[0], 1)
Expand Down
12 changes: 5 additions & 7 deletions test/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ def autograd_sharing(queue, ready, master_modified):
is_ok = var.data.equal(expected_var)
var.data[:] = torch.ones(5, 5)

if var.grad is not None:
is_ok &= var.grad.data.equal(torch.ones(5, 5) * 4)
var.grad.data[:] = torch.ones(5, 5)
is_ok &= var.grad.data.equal(torch.zeros(5, 5))
var.grad.data[:] = torch.ones(5, 5)

queue.put(is_ok)

Expand Down Expand Up @@ -357,20 +356,19 @@ def _test_autograd_sharing(self, var):
queue = mp.Queue()
p = mp.Process(target=autograd_sharing, args=(queue, ready, master_modified))
p.start()
var.grad.data.zero_()
queue.put(var)

ready.wait()
var.data[0, 0] = 1000
if var.grad is not None:
var.grad.data[:] = torch.ones(5, 5) * 4
var.grad.data[:] = torch.ones(5, 5) * 4
master_modified.set()

worker_ok = queue.get()
self.assertTrue(worker_ok)

self.assertEqual(var.data, torch.ones(5, 5))
if var.grad is not None:
self.assertEqual(var.grad.data, torch.ones(5, 5))
self.assertEqual(var.grad.data, torch.ones(5, 5) * 4)
p.join()

def test_variable_sharing(self):
Expand Down
5 changes: 2 additions & 3 deletions torch/autograd/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch._C as _C
import torch.utils.hooks as hooks
from collections import OrderedDict
from itertools import chain


class Function(_C._FunctionBase):
Expand Down Expand Up @@ -98,9 +97,9 @@ def mark_non_differentiable(self, *args):
**This should be called at most once, only from inside the**
:func:`forward` **method, and all arguments should be outputs.**
This will mark outputs as non requiring gradient, increasing the
This will mark outputs as not requiring gradients, increasing the
efficiency of backward computation. You still need to accept a gradient
for this output in :meth:`~Function.backward`, but it's always going to
for each output in :meth:`~Function.backward`, but it's always going to
be ``None``.
This is used e.g. for indices returned from a max :class:`Function`.
Expand Down
43 changes: 15 additions & 28 deletions torch/autograd/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,30 +56,6 @@ class Variable(_C._VariableBase):
'is_cuda',
}

@property
def grad(self):
if self.requires_grad and self._grad is None:
# TODO: this won't have to be zeroed in the future
self._grad = Variable(self.data.new(self.data.size()).zero_())
return self._grad

@property
def requires_grad(self):
return self._requires_grad

@requires_grad.setter
def requires_grad(self, value):
if self.creator is not None:
if value is False:
hint = (" If you want to use a computed variable in a subgraph "
"that doesn't require differentiation use "
"var_no_grad = var.detach().")
else:
hint = ''
raise RuntimeError("you can only change requires_grad flags of "
"leaf variables." + hint)
self._requires_grad = value

def __getattr__(self, name):
if name in self._fallthrough_methods:
return getattr(self.data, name)
Expand Down Expand Up @@ -108,19 +84,30 @@ def __deepcopy__(self, memo):
if self.creator is not None:
raise RuntimeError("Only Variables created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment")
result = type(self)(self.data.clone(), requires_grad=self.requires_grad,
volatile=self.volatile)
result = type(self)(self.data.clone())
result.requires_grad = self.requires_grad
result.volatile = self.volatile
memo[id(self)] = result
return result

def __reduce_ex__(self, proto):
state = (self.requires_grad, self.volatile, self._backward_hooks)
if proto > 1:
return super(Variable, self).__reduce_ex__(proto)
return type(self), (self.data,), state
if sys.version_info[0] == 2:
from copy_reg import __newobj__
else:
from copyreg import __newobj__
return __newobj__, (type(self),), self.__getstate__()
return __newobj__, (type(self), self.data), state

def __setstate__(self, state):
# Restores requires_grad, volatile and _backward_hooks from `state`.
# Accepts either the current 3-tuple (requires_grad, volatile, hooks) or
# the older 5-element layout handled below.
if len(state) == 5:
# legacy serialization of Variable
# Legacy layout as read here: data at index 0, hooks at index 2,
# requires_grad at index 3, volatile at index 4. Index 1 is not used.
self.data = state[0]
state = (state[3], state[4], state[2])
# Only leaf variables (no creator) may have their state overwritten.
if self.creator is not None:
raise RuntimeError('__setstate__ can be only called on leaf variables')
self.requires_grad, self.volatile, self._backward_hooks = state

def __repr__(self):
# Fixed header followed by the repr of the wrapped data.
return 'Variable containing:' + self.data.__repr__()
Expand Down
161 changes: 161 additions & 0 deletions torch/csrc/DynamicTypes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#include "DynamicTypes.h"

#include "THP.h"
#include <vector>
#include <unordered_map>
#include <THPP/tensors/THTensor.hpp>
#include <THPP/tensors/THSTensor.hpp>

#ifdef WITH_CUDA
#include <THC/THC.h>
#include <THPP/tensors/THCTensor.hpp>
extern THCState* state;
#endif


using namespace thpp;

namespace torch {

// Key describing one concrete tensor class: its element type plus whether
// it is a CUDA tensor and whether it is a sparse tensor. Used as the key /
// value of the Python-type registries in this file.
struct TensorType {
  Type data_type;
  bool is_cuda;
  bool is_sparse;

  // Two descriptors are equal only when all three fields match.
  friend bool operator==(const TensorType &lhs, const TensorType &rhs)
  {
    if (!(lhs.data_type == rhs.data_type)) {
      return false;
    }
    if (!(lhs.is_cuda == rhs.is_cuda)) {
      return false;
    }
    return lhs.is_sparse == rhs.is_sparse;
  }

  // Defined as the exact negation of operator==.
  friend bool operator!=(const TensorType &lhs, const TensorType &rhs)
  {
    return !(lhs == rhs);
  }
};

// Hash functor that lets TensorType be used as an unordered_map key.
struct TensorTypeHasher
{
  // Packs the three fields into one integer: the element type in the high
  // bits, then the CUDA flag, then the sparse flag in the lowest bit.
  std::size_t operator()(const TensorType& k) const
  {
    const size_t type_bits = static_cast<size_t>(k.data_type);
    const size_t with_cuda = (type_bits << 8) + k.is_cuda;
    return (with_cuda << 1) + k.is_sparse;
  }
};

// Maps the element-type name given to registerPyTypeObject (e.g. "Float")
// to the corresponding thpp element type. Names not listed here are
// rejected by the .at() lookup in registerPyTypeObject.
static std::unordered_map<std::string, Type> type_names = {
{"Float", Type::FLOAT},
{"Double", Type::DOUBLE},
{"Half", Type::HALF},
{"Byte", Type::UCHAR},
{"Char", Type::CHAR},
{"Short", Type::SHORT},
{"Int", Type::INT},
{"Long", Type::LONG},
};
// Two-way registry between Python tensor classes and their TensorType
// descriptor. Both maps are filled by registerPyTypeObject and read by
// createTensor(PyObject*) and getPyTypeObject respectively.
static std::unordered_map<PyTypeObject*, TensorType> pytype_to_tensortype;
static std::unordered_map<TensorType, PyTypeObject*, TensorTypeHasher> tensortype_to_pytype;

// Records `pytype` as the Python class for tensors with the given element
// type name, CUDA flag and sparse flag, in both lookup directions.
// `name` must be a key of type_names; otherwise the .at() lookup throws.
// Registering the same key again overwrites the previous entry.
void registerPyTypeObject(PyTypeObject *pytype, const std::string& name, bool is_cuda, bool is_sparse)
{
  const TensorType key = {type_names.at(name), is_cuda, is_sparse};

  pytype_to_tensortype[pytype] = key;
  tensortype_to_pytype[key] = pytype;
}

// Returns the Python class registered for the tensor's element type,
// CUDA flag and sparse flag. The .at() lookup throws if nothing was
// registered for that combination.
PyTypeObject* getPyTypeObject(const thpp::Tensor& tensor)
{
  TensorType key;
  key.data_type = tensor.type();
  key.is_cuda = tensor.isCuda();
  key.is_sparse = tensor.isSparse();

  PyTypeObject *pytype = tensortype_to_pytype.at(key);
  return pytype;
}

static std::unique_ptr<Tensor> createTensor(void *tensor, Type type, bool is_cuda, bool is_sparse)
{
if (is_cuda) {
#ifdef WITH_CUDA
if (type == Type::UCHAR) {
return std::unique_ptr<Tensor>(new THCTensor<unsigned char>(state, (THCudaByteTensor*)tensor));
} else if (type == Type::CHAR) {
return std::unique_ptr<Tensor>(new THCTensor<char>(state, (THCudaCharTensor*)tensor));
} else if (type == Type::SHORT) {
return std::unique_ptr<Tensor>(new THCTensor<short>(state, (THCudaShortTensor*)tensor));
} else if (type == Type::INT) {
return std::unique_ptr<Tensor>(new THCTensor<int>(state, (THCudaIntTensor*)tensor));
} else if (type == Type::LONG) {
return std::unique_ptr<Tensor>(new THCTensor<long>(state, (THCudaLongTensor*)tensor));
} else if (type == Type::FLOAT) {
return std::unique_ptr<Tensor>(new THCTensor<float>(state, (THCudaTensor*)tensor));
} else if (type == Type::DOUBLE) {
return std::unique_ptr<Tensor>(new THCTensor<double>(state, (THCudaDoubleTensor*)tensor));
} else if (type == Type::HALF) {
return std::unique_ptr<Tensor>(new THCTensor<half>(state, (THCudaHalfTensor*)tensor));
}
#else
throw std::runtime_error("Compiled without CUDA support");
#endif
} else if (is_sparse) {
if (type == Type::UCHAR) {
return std::unique_ptr<Tensor>(new THSTensor<unsigned char>((THSByteTensor*)tensor));
} else if (type == Type::CHAR) {
return std::unique_ptr<Tensor>(new THSTensor<char>((THSCharTensor*)tensor));
} else if (type == Type::SHORT) {
return std::unique_ptr<Tensor>(new THSTensor<short>((THSShortTensor*)tensor));
} else if (type == Type::INT) {
return std::unique_ptr<Tensor>(new THSTensor<int>((THSIntTensor*)tensor));
} else if (type == Type::LONG) {
return std::unique_ptr<Tensor>(new THSTensor<long>((THSLongTensor*)tensor));
} else if (type == Type::FLOAT) {
return std::unique_ptr<Tensor>(new THSTensor<float>((THSFloatTensor*)tensor));
} else if (type == Type::DOUBLE) {
return std::unique_ptr<Tensor>(new THSTensor<double>((THSDoubleTensor*)tensor));
}
} else if (type == Type::UCHAR) {
return std::unique_ptr<Tensor>(new THTensor<unsigned char>((THByteTensor*)tensor));
} else if (type == Type::CHAR) {
return std::unique_ptr<Tensor>(new THTensor<char>((THCharTensor*)tensor));
} else if (type == Type::SHORT) {
return std::unique_ptr<Tensor>(new THTensor<short>((THShortTensor*)tensor));
} else if (type == Type::INT) {
return std::unique_ptr<Tensor>(new THTensor<int>((THIntTensor*)tensor));
} else if (type == Type::LONG) {
return std::unique_ptr<Tensor>(new THTensor<long>((THLongTensor*)tensor));
} else if (type == Type::FLOAT) {
return std::unique_ptr<Tensor>(new THTensor<float>((THFloatTensor*)tensor));
} else if (type == Type::DOUBLE) {
return std::unique_ptr<Tensor>(new THTensor<double>((THDoubleTensor*)tensor));
}
throw std::invalid_argument("Unsupported tensor type");
}

// Builds a thpp wrapper for the tensor held by the Python object `data`.
// The object's Python type must have been registered through
// registerPyTypeObject; otherwise the .at() lookup throws.
std::unique_ptr<Tensor> createTensor(PyObject *data)
{
  const TensorType& info = pytype_to_tensortype.at(Py_TYPE(data));
  auto cdata = ((THPVoidTensor *)data)->cdata;
  std::unique_ptr<Tensor> result =
      createTensor(cdata, info.data_type, info.is_cuda, info.is_sparse);
  // NOTE(review): presumably the wrapper takes over one reference to cdata,
  // so retain() keeps the Python object's own reference valid; confirm
  // against the thpp wrapper constructors.
  result->retain();
  return result;
}

// Allocates a Python tensor object of the class registered for `tensor`
// and points its cdata at the tensor's underlying data, after calling
// retain() on the tensor. Returns NULL when tp_alloc fails.
PyObject* createPyObject(const thpp::Tensor& tensor)
{
  PyTypeObject *pytype = getPyTypeObject(tensor);
  PyObject *result = pytype->tp_alloc(pytype, 0);
  if (!result) {
    return nullptr;
  }
  // retain() is non-const, hence the const_cast on the input reference.
  thpp::Tensor& writable = const_cast<thpp::Tensor&>(tensor);
  ((THPVoidTensor*)result)->cdata = (THVoidTensor *)writable.retain().cdata();
  return result;
}

} // namespace
25 changes: 25 additions & 0 deletions torch/csrc/DynamicTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

// Provides conversions between Python tensor objects and thpp::Tensors.

#include <memory>
#include <Python.h>
#include <THPP/THPP.h>

namespace torch {

// Registers a PyTypeObject* as the Python class for tensors with the given
// attributes. `name` is the element-type name (e.g. "Float", "Long"); the
// implementation looks it up with .at(), so an unknown name throws.
void registerPyTypeObject(
PyTypeObject *pytype, const std::string& name,
bool is_cuda, bool is_sparse);

// Gets the PyTypeObject* corresponding to the Tensor's element type, CUDA
// flag and sparse flag. Throws if no type was registered for that
// combination.
PyTypeObject* getPyTypeObject(const thpp::Tensor& tensor);

// Creates a Tensor from a Python tensor object. The object's type must have
// been registered with registerPyTypeObject. The implementation calls
// retain() on the wrapper before returning it.
std::unique_ptr<thpp::Tensor> createTensor(PyObject *data);

// Creates a Python tensor object from a Tensor, calling retain() on the
// tensor. Returns NULL if the Python object could not be allocated.
PyObject* createPyObject(const thpp::Tensor& tensor);

} // namespace torch
7 changes: 5 additions & 2 deletions torch/csrc/Exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#include <stdexcept>
#include <string>

#include "THP.h"

#define HANDLE_TH_ERRORS \
try {

Expand All @@ -21,6 +19,11 @@
extern PyObject *THPException_FatalError;

#ifdef _THP_CORE

// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.
class python_error : public std::exception {};

struct THException: public std::exception {
THException(const char* msg): msg(msg) {};

Expand Down
Loading

0 comments on commit bd53030

Please sign in to comment.