Merge pull request pytorch#881 from colesbury/parallelize_backwards
Parallelize autograd backwards
colesbury committed Mar 6, 2017
2 parents 761d679 + 6336300 commit 15a9fbd
Showing 40 changed files with 1,940 additions and 670 deletions.
3 changes: 3 additions & 0 deletions setup.py
@@ -249,6 +249,7 @@ def run(self):
"torch/csrc/byte_order.cpp",
"torch/csrc/utils.cpp",
"torch/csrc/utils/object_ptr.cpp",
"torch/csrc/utils/tuple_parser.cpp",
"torch/csrc/allocators.cpp",
"torch/csrc/serialization.cpp",
"torch/csrc/autograd/init.cpp",
@@ -260,7 +261,9 @@
"torch/csrc/autograd/python_cpp_function.cpp",
"torch/csrc/autograd/python_variable.cpp",
"torch/csrc/autograd/python_engine.cpp",
"torch/csrc/autograd/python_hook.cpp",
"torch/csrc/autograd/functions/batch_normalization.cpp",
"torch/csrc/autograd/functions/convolution.cpp",
"torch/csrc/autograd/functions/init.cpp",
"torch/csrc/nn/THNN_generic.cpp",
]
48 changes: 46 additions & 2 deletions test/test_autograd.py
@@ -70,6 +70,51 @@ def bw_hook_modify(grad):
z.backward(torch.ones(5, 5))
self.assertEqual(y.grad.data, (x.data + 1) * 4)

def test_hooks_cpp(self):
# Tests hooks for autograd function implemented in C++
bn = torch.nn.BatchNorm1d(5, affine=False)
bn.eval()

counter = [0]

def bw_hook(grad):
counter[0] += 1
return grad * 2

x = Variable(torch.ones(5, 5), requires_grad=True)
z = bn(x)
z.register_hook(bw_hook)
z.sum().backward()

self.assertEqual(counter[0], 1, 'bw_hook not called')
self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)

@unittest.skipIf(sys.version_info[0] == 2, "Python 2 doesn't collect cycles involving __del__")
def test_hooks_cycle(self):
import gc
counter = [0]

class GradHook(object):
def __init__(self, var):
self.var = var

def __del__(self):
counter[0] += 1

def __call__(self, *args):
pass

def run_test():
x = Variable(torch.ones(5, 5), requires_grad=True)
y = x * 2
x.register_hook(GradHook(x))
y.register_hook(GradHook(y))
y._backward_hooks[1] = GradHook(y)

run_test()
gc.collect()
self.assertEqual(counter[0], 3)

def test_hook_none(self):
# WARNING: this is a test for autograd internals.
# You should never have to use such things in your code.
@@ -84,7 +129,6 @@ def backward(self, grad_x, grad_y):
return grad_x, None

fn = NoneGradientFunction()
fn._backward_hooks = OrderedDict()
was_called = [False]

def hook(grad_input, grad_output):
@@ -95,7 +139,7 @@ def hook(grad_input, grad_output):
self.assertIsNotNone(grad_output[0])
self.assertIsNotNone(grad_output[1])
was_called[0] = True
fn._backward_hooks[id(hook)] = hook
fn.register_hook(hook)

x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5))
14 changes: 14 additions & 0 deletions test/test_nn.py
@@ -301,6 +301,20 @@ def bw_hook(inc, h_module, grad_input, grad_output):
test_fwd.remove()
test_bwd.remove()

def test_hook_cpp(self):
counter = [0]
bn = nn.BatchNorm1d(5)

def hook(module, grad_inputs, grad_outputs):
counter[0] += 1
self.assertEqual(len(grad_inputs), 3)
self.assertEqual(len(grad_outputs), 1)
self.assertEqual(module, bn)

bn.register_backward_hook(hook)
output = bn(Variable(torch.randn(5, 5), requires_grad=True))
output.sum().backward()

def test_hook_fail(self):
module = nn.Sigmoid()
input = Variable(torch.randn(5, 5), requires_grad=True)
13 changes: 7 additions & 6 deletions torch/autograd/function.py
@@ -106,12 +106,13 @@ def mark_non_differentiable(self, *args):
"""
self.non_differentiable = args

def register_hook(self, hook):
if self._backward_hooks is None:
self._backward_hooks = OrderedDict()
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[id(handle)] = hook
return handle
@staticmethod
def _register_hook(backward_hooks, hook):
if backward_hooks is None:
backward_hooks = OrderedDict()
handle = hooks.RemovableHandle(backward_hooks)
backward_hooks[handle.id] = hook
return backward_hooks, handle

def forward(self, *input):
"""Performs the operation.
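Note on the refactor above: Function's register_hook becomes a static _register_hook helper that lazily creates the OrderedDict, keys it by handle.id instead of id(handle), and returns the (possibly newly created) dict together with the handle, so a caller that stores its hooks elsewhere (presumably including the new C++ function wrappers backed by python_hook.cpp) can keep a reference to it. A self-contained, hypothetical mimic of that pattern for illustration only; none of the names below come from this diff:

from collections import OrderedDict

class FakeRemovableHandle(object):
    """Loose stand-in for torch's removable handle (illustration only)."""
    _next_id = 0

    def __init__(self, hooks_dict):
        self._hooks_dict = hooks_dict
        self.id = FakeRemovableHandle._next_id
        FakeRemovableHandle._next_id += 1

    def remove(self):
        self._hooks_dict.pop(self.id, None)

def register_hook_helper(backward_hooks, hook):
    # Create the dict lazily and hand it back along with the handle,
    # mirroring the shape of Function._register_hook above.
    if backward_hooks is None:
        backward_hooks = OrderedDict()
    handle = FakeRemovableHandle(backward_hooks)
    backward_hooks[handle.id] = hook
    return backward_hooks, handle

# The caller keeps whatever dict comes back, since it may be brand new.
hooks = None
hooks, h = register_hook_helper(hooks, lambda grad: grad * 2)
h.remove()
assert len(hooks) == 0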
17 changes: 2 additions & 15 deletions torch/autograd/variable.py
@@ -151,7 +151,7 @@ def register_hook(self, hook):
The hook will be called every time a gradient with respect to the
variable is computed. The hook should have the following signature::
hook(grad) -> Tensor or None
hook(grad) -> Variable or None
The hook should not modify its argument, but it can optionally return
a new gradient which will be used in place of :attr:`grad`.
@@ -180,22 +180,9 @@ def register_hook(self, hook):
if self.creator is not None:
self.creator._register_hook_dict(self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[id(handle)] = hook
self._backward_hooks[handle.id] = hook
return handle

def _do_backward(self, grad_output, retain_variables):
assert len(grad_output) == 1
assert self._version == 0 and self.creator is None, \
"leaf variable was used in an inplace operation"
unpacked_grad = grad_output[0]
if self._backward_hooks:
for hook in self._backward_hooks.values():
result = hook(unpacked_grad)
if result is not None:
unpacked_grad = result
self.grad.data.add_(unpacked_grad)
return tuple()

def reinforce(self, reward):
"""Registers a reward obtained as a result of a stochastic process.
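For reference, a minimal usage sketch of the Variable.register_hook API documented above (the hook receives the gradient, may return a replacement, and the returned handle removes it), consistent with the tests earlier in this diff:

import torch
from torch.autograd import Variable

x = Variable(torch.ones(5, 5), requires_grad=True)
y = x * 3

# The hook sees the gradient w.r.t. y; returning a value replaces that gradient.
handle = y.register_hook(lambda grad: grad * 2)

y.sum().backward()
print(x.grad.data)   # every element is 3 * 2 = 6

handle.remove()      # detaches the hook again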
45 changes: 41 additions & 4 deletions torch/csrc/Exceptions.h
@@ -4,10 +4,8 @@
#include <exception>
#include <stdexcept>
#include <string>

// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.
class python_error : public std::exception {};
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/utils/auto_gil.h"

#define HANDLE_TH_ERRORS \
try {
@@ -24,6 +22,45 @@ class python_error : public std::exception {};

extern PyObject *THPException_FatalError;

// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.
struct python_error : public std::exception {
python_error() : type(nullptr), value(nullptr), traceback(nullptr) {}

~python_error() {
if (type || value || traceback) {
AutoGIL gil;
Py_XDECREF(type);
Py_XDECREF(value);
Py_XDECREF(traceback);
}
}

/** Saves the exception so that it can be re-thrown on a different thread */
inline void persist() {
// PyErr_Fetch overwrites the pointers
AutoGIL gil;
Py_XDECREF(type);
Py_XDECREF(value);
Py_XDECREF(traceback);
PyErr_Fetch(&type, &value, &traceback);
}

/** Sets the current Python error from this exception */
inline void restore() {
// PyErr_Restore steals references
AutoGIL gil;
Py_XINCREF(type);
Py_XINCREF(value);
Py_XINCREF(traceback);
PyErr_Restore(type, value, traceback);
}

PyObject* type;
PyObject* value;
PyObject* traceback;
};

#ifdef _THP_CORE

struct THException: public std::exception {
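The persist()/restore() pair above is presumably what lets the now-multithreaded backward engine carry a Python exception from the engine thread on which a hook or Python function raised it back to the thread that called backward(): the raising side catches python_error and calls persist() to take ownership of the error state via PyErr_Fetch, and the calling side later calls restore() so that, per the comment on the struct, control returns to the interpreter with the error flags set. From Python, the intended observable behaviour is simply that the exception surfaces from backward(); a hedged sketch, assuming a hook that raises:

import torch
from torch.autograd import Variable

x = Variable(torch.ones(5, 5), requires_grad=True)
y = x * 2

def failing_hook(grad):
    # Raised on whichever engine thread runs this hook ...
    raise RuntimeError("hook failed")

y.register_hook(failing_hook)

try:
    y.sum().backward()
except RuntimeError as e:
    # ... but re-raised on the thread that called backward().
    print("caught:", e)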