Logsumexp dim (clab#1001)
* Added dimension option to tensortools logsumexp

* Initial addition of logsumexp_dim

* Added binding and bugfix

* Added python bindings

* Fixed build on windows?
neubig committed Oct 19, 2017
1 parent c48031d commit 6efe3b7
Showing 9 changed files with 119 additions and 33 deletions.
1 change: 1 addition & 0 deletions dynet/expr.cc
@@ -80,6 +80,7 @@ Expression hinge_dim(const Expression& x, const std::vector<std::vector<unsigned
Expression hinge_dim(const Expression& x, const std::vector<std::vector<unsigned> > * pindices, unsigned d, float m) { return Expression(x.pg, x.pg->add_function<HingeDim>({x.i}, pindices, d, m)); }
Expression log_softmax(const Expression& x) { return Expression(x.pg, x.pg->add_function<LogSoftmax>({x.i})); }
Expression log_softmax(const Expression& x, const vector<unsigned>& d) { return Expression(x.pg, x.pg->add_function<RestrictedLogSoftmax>({x.i}, d)); }
Expression logsumexp_dim(const Expression& x, unsigned d) { return Expression(x.pg, x.pg->add_function<LogSumExpDimension>({x.i}, d)); }
Expression sparsemax(const Expression& x) { return Expression(x.pg, x.pg->add_function<Sparsemax>({x.i})); }
Expression sparsemax_loss(const Expression& x, const vector<unsigned>& target_support) { return Expression(x.pg, x.pg->add_function<SparsemaxLoss>({x.i}, target_support)); }
Expression sparsemax_loss(const Expression& x, const vector<unsigned>* ptarget_support) { return Expression(x.pg, x.pg->add_function<SparsemaxLoss>({x.i}, ptarget_support)); }
13 changes: 13 additions & 0 deletions dynet/expr.h
@@ -1155,6 +1155,19 @@ Expression log_softmax(const Expression& x);
*/
Expression log_softmax(const Expression& x, const std::vector<unsigned>& restriction);

/**
* \ingroup lossoperations
* \brief Log, sum, exp by dimension
* \details The "logsumexp" function \f$ln(\sum_i e^{x_i})\f$ calculated over a
*          particular dimension, used in adding probabilities in the log domain.
*
* \param x Expression with respect to which to calculate the logsumexp.
* \param d The dimension along which to do the logsumexp.
*
* \return The result.
*/
Expression logsumexp_dim(const Expression& x, unsigned d);

/**
* \ingroup lossoperations
* \brief Log, sum, exp
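For context, a minimal usage sketch of the operation documented above (illustrative only, not part of this commit): it assumes a 2x3 column-major input, the usual DyNet initialization, and that the operations are visible in namespace dynet (on some DyNet versions they live in dynet::expr instead).

// Illustrative usage sketch (not from this diff); assumes the standard DyNet C++ setup.
// Depending on the DyNet version, the operations below may live in namespace dynet::expr.
#include "dynet/dynet.h"
#include "dynet/expr.h"
#include "dynet/tensor.h"
#include <iostream>
#include <vector>

using namespace dynet;

int main(int argc, char** argv) {
  initialize(argc, argv);
  ComputationGraph cg;
  // A 2x3 input, stored column-major: columns are {1,4}, {2,5}, {3,6}.
  std::vector<float> vals = {1, 4, 2, 5, 3, 6};
  Expression x = input(cg, Dim({2, 3}), vals);
  Expression y = logsumexp_dim(x, 0);        // reduce over dimension 0 (rows) -> 3 values
  for (float v : as_vector(cg.forward(y)))
    std::cout << v << " ";                   // each value is log(e^a + e^b) for one column
  std::cout << std::endl;
  return 0;
}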
61 changes: 48 additions & 13 deletions dynet/nodes-logsumexp.cc
@@ -12,19 +12,6 @@ namespace dynet {

#ifndef __CUDACC__

// template <class T>
// EIGEN_STRONG_INLINE real logsumexp(const T& x, const vector<unsigned>& denom) {
// real m = x(denom[0],0);
// for (auto i : denom) {
// real r = x(i,0);
// if (r > m) m = r;
// }
// real z = 0;
// for (auto i : denom)
// z += expf(x(i,0) - m);
// return m + logf(z);
// }

string LogSumExp::as_string(const vector<string>& arg_names) const {
ostringstream s;
s << "log(exp " << arg_names[0];
@@ -109,4 +96,52 @@ void LogSumExp::backward_dev_impl(const MyDevice & dev,
}
DYNET_NODE_INST_DEV_IMPL(LogSumExp)

// ************* LogSumExpDimension *************

#define MAX_LOG_SUM_EXP 65536

#ifndef __CUDACC__

string LogSumExpDimension::as_string(const vector<string>& arg_names) const {
ostringstream s;
s << "logsumexp_dim(" << arg_names[0] << ", " << dimension << ")";
return s.str();
}

Dim LogSumExpDimension::dim_forward(const vector<Dim>& xs) const {
DYNET_ARG_CHECK(xs.size() == 1, "LogSumExpDimension takes only one argument" << xs);
DYNET_ARG_CHECK(xs[0].nd <= 2, "LogSumExpDimension, expects 2 or fewer dimensions" << xs);
DYNET_ARG_CHECK(xs[0].nd > dimension, "LogSumExpDimension, expects its dimension argument (" <<
dimension << ") to be smaller than the number of elements in the input " << xs);
Dim d = xs[0];
if(dimension < d.nd)
d.delete_dim(dimension);
return d;
}

#endif

template<class MyDevice>
void LogSumExpDimension::forward_dev_impl(const MyDevice & dev, const vector<const Tensor*>& xs, Tensor& fx) const {
Tensor ms(fx.d, nullptr, fx.device, fx.mem_pool), zs(fx.d, nullptr, fx.device, fx.mem_pool);
AlignedMemoryPool* scratch_allocator = fx.device->pools[(int)DeviceMempool::SCS];
ms.v = static_cast<float*>(scratch_allocator->allocate(ms.d.size() * sizeof(float)));
TensorTools::logsumexp_dev(dev, *xs[0], ms, fx, dimension);
scratch_allocator->free();
}

template<class MyDevice>
void LogSumExpDimension::backward_dev_impl(const MyDevice & dev,
const vector<const Tensor*>& xs,
const Tensor& fx,
const Tensor& dEdf,
unsigned i,
Tensor& dEdxi) const {
unsigned other_dim = dimension ^ 1;
Eigen::array<int, 3> bcast = {1, 1, 1}; bcast[dimension] = xs[0]->d[dimension];
Eigen::array<int, 3> morph = {1, 1, (int)fx.d.bd}; morph[other_dim] = fx.d[0];
dEdxi.tb<2>().device(*dev.edevice) += (xs[0]->tb<2>() - fx.tb<1>().reshape(morph).broadcast(bcast)).exp() * dEdf.tb<1>().reshape(morph).broadcast(bcast);
}
DYNET_NODE_INST_DEV_IMPL(LogSumExpDimension)

}
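To make the reshape/broadcast expression in backward_dev_impl above easier to follow, here is a plain C++ sketch (illustrative only; no Eigen, no batching) of the gradient it computes when dimension == 0: each input element receives the softmax weight exp(x_ij - fx_j) times the incoming gradient for its column.

// Illustrative sketch of the LogSumExpDimension gradient for dimension == 0
// (not part of this diff). The broadcast expression above reduces to:
//   dE/dx_ij += exp(x_ij - fx_j) * dE/dfx_j
#include <cmath>
#include <vector>

void logsumexp_dim0_backward(const std::vector<std::vector<float>>& x,  // input, rows x cols
                             const std::vector<float>& fx,              // forward output, one per column
                             const std::vector<float>& dEdf,            // incoming gradient, one per column
                             std::vector<std::vector<float>>& dEdx) {   // gradient w.r.t. x, accumulated
  for (size_t i = 0; i < x.size(); ++i)
    for (size_t j = 0; j < x[i].size(); ++j)
      dEdx[i][j] += std::exp(x[i][j] - fx[j]) * dEdf[j];
}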
8 changes: 8 additions & 0 deletions dynet/nodes-logsumexp.h
@@ -14,6 +14,14 @@ struct LogSumExp : public Node {
virtual bool supports_multibatch() const override { return true; }
};

struct LogSumExpDimension : public Node {
template <typename T> explicit LogSumExpDimension(const T& a, unsigned d = 0) : Node(a), dimension(d) {}
DYNET_NODE_DEFINE_DEV_IMPL()
virtual bool supports_multibatch() const override { return true; }
private:
unsigned dimension;
};

} // namespace dynet

#endif
38 changes: 20 additions & 18 deletions dynet/tensor.cc
@@ -376,57 +376,59 @@ void TensorTools::clip(Tensor& d, float left, float right) {
#endif

template <class MyDevice>
void TensorTools::logsumexp_dev(const MyDevice & dev, const Tensor& x, Tensor & m, Tensor& z) {
if(x.d.bd == 1 && x.d[1] == 1) {
m.t<0>().device(*dev.edevice) = x.t<1>().maximum();
void TensorTools::logsumexp_dev(const MyDevice & dev, const Tensor& x, Tensor & m, Tensor& z, unsigned axis) {
DYNET_ARG_CHECK(x.d.nd <= 2, "TensorTools::logsumexp currently only supports tensors of dimension <= 2");
unsigned other_axis = axis ^ 1;
if(x.d.bd == 1 && x.d[other_axis] == 1) {
m.t<0>().device(*dev.edevice) = x.tvec().maximum();
#ifdef __CUDACC__
Eigen::array<int, 1> bcast;
bcast[0] = x.d[0];
// This needs to be split into two lines to prevent memory allocation
// TODO? Here and in logsoftmax: Is there a better way to subtract a scalar that is already on the GPU without using broadcasting (and without copying the scalar back to the host first)
z.t<0>().device(*dev.edevice) = (x.t<1>() - m.t<1>().broadcast(bcast)).exp().sum();
z.t<0>().device(*dev.edevice) = z.t<0>().log() + m.t<0>();
z.t<0>().device(*dev.edevice) = (x.tvec() - m.tvec().broadcast(bcast)).exp().sum();
z.t<0>().device(*dev.edevice) = z.tvec().log() + m.t<0>();
#else
float mval = as_scalar(m);
// This needs to be split into two lines to prevent memory allocation
z.t<0>().device(*dev.edevice) = (x.t<1>() - mval).exp().sum();
z.t<0>().device(*dev.edevice) = (x.tvec() - mval).exp().sum();
z.t<0>().device(*dev.edevice) = z.t<0>().log() + mval;
#endif
} else {
Eigen::array<int, 1> red_axis; red_axis[0] = 0;
Eigen::array<int, 1> red_axis; red_axis[0] = axis;
m.tb<1>().device(*dev.edevice) = x.tb<2>().maximum(red_axis);
// TODO: Currently, the first version is slower on CPU, hence the switch
#ifdef __CUDACC__
Eigen::array<int, 3> bcast({(int)x.d.rows(), 1, 1});
Eigen::array<int, 3> morph({1, (int)m.d[0], (int)m.d.bd});
Eigen::array<int, 3> bcast = {1, 1, 1}; bcast[axis] = (int)x.d[axis];
Eigen::array<int, 3> morph = {1, 1, (int)m.d.bd}; morph[other_axis] = (int)m.d[0];
// This needs to be split into two lines to prevent memory allocation
z.tb<1>().device(*dev.edevice) = (x.tb<2>() - m.tb<2>().reshape(morph).broadcast(bcast)).exp().sum(red_axis);
z.tb<1>().device(*dev.edevice) = z.tb<1>().log() + m.tb<1>();
#else
auto miter = m.v;
for(size_t b = 0; b < x.d.bd; ++b) {
for(size_t i = 0; i < x.d[1]; ++i, ++miter) {
z.tb<1>().chip<1>(b).chip<0>(i).device(*dev.edevice) = (x.tb<2>().chip<2>(b).chip<1>(i) - *miter).exp().sum();
z.tb<1>().chip<1>(b).chip<0>(i).device(*dev.edevice) = (x.tb<2>().chip<2>(b).chip(i,other_axis) - *miter).exp().sum();
z.tb<1>().chip<1>(b).chip<0>(i).device(*dev.edevice) = z.tb<1>().chip<1>(b).chip<0>(i).log() + *miter;
}
}
#endif
}
}
#ifdef __CUDACC__
template void TensorTools::logsumexp_dev<Device_GPU>(const Device_GPU & dev, const Tensor &x, Tensor &m, Tensor &z);
template void TensorTools::logsumexp_dev<Device_GPU>(const Device_GPU & dev, const Tensor &x, Tensor &m, Tensor &z, unsigned d);
#else
template void TensorTools::logsumexp_dev<Device_CPU>(const Device_CPU & dev, const Tensor &x, Tensor &m, Tensor &z);
template void TensorTools::logsumexp_dev<Device_CPU>(const Device_CPU & dev, const Tensor &x, Tensor &m, Tensor &z, unsigned d);
#ifdef HAVE_CUDA
extern template void TensorTools::logsumexp_dev<Device_GPU>(const Device_GPU & dev, const Tensor &x, Tensor &m, Tensor &z);
void TensorTools::logsumexp(const Tensor &x, Tensor &m, Tensor &z) {
if (x.device->type == DeviceType::CPU) { return logsumexp_dev(*(const Device_CPU*)x.device, x, m, z); }
else if (x.device->type == DeviceType::GPU) { return logsumexp_dev(*(const Device_GPU*)x.device, x, m, z); }
extern template void TensorTools::logsumexp_dev<Device_GPU>(const Device_GPU & dev, const Tensor &x, Tensor &m, Tensor &z, unsigned d);
void TensorTools::logsumexp(const Tensor &x, Tensor &m, Tensor &z, unsigned d) {
if (x.device->type == DeviceType::CPU) { return logsumexp_dev(*(const Device_CPU*)x.device, x, m, z, d); }
else if (x.device->type == DeviceType::GPU) { return logsumexp_dev(*(const Device_GPU*)x.device, x, m, z, d); }
else { throw std::runtime_error("Bad device type"); }
}
#else
void TensorTools::logsumexp(const Tensor &x, Tensor &m, Tensor &z) {
if (x.device->type == DeviceType::CPU) { return logsumexp_dev(*(const Device_CPU*)x.device, x, m, z); }
void TensorTools::logsumexp(const Tensor &x, Tensor &m, Tensor &z, unsigned d) {
if (x.device->type == DeviceType::CPU) { return logsumexp_dev(*(const Device_CPU*)x.device, x, m, z, d); }
else { throw std::runtime_error("Bad device type"); }
}
#endif
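The generalized TensorTools::logsumexp_dev above reduces over an arbitrary axis of an at-most-2-D tensor using the usual max-shift for numerical stability; other_axis = axis ^ 1 simply flips 0 and 1, which is why only tensors with two or fewer dimensions are supported. A plain C++ sketch of the same reduction (illustrative only; unbatched, no Eigen):

// Illustrative, unbatched sketch of a max-shifted logsumexp over one axis of a
// 2-D array (not part of this diff).
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

std::vector<float> logsumexp_axis(const std::vector<std::vector<float>>& x, unsigned axis) {
  const size_t rows = x.size(), cols = x[0].size();
  const size_t out_size = (axis == 0) ? cols : rows;
  std::vector<float> m(out_size, -std::numeric_limits<float>::infinity());
  std::vector<float> z(out_size, 0.f);
  // Pass 1: maximum along the reduced axis (the "m" scratch tensor above).
  for (size_t i = 0; i < rows; ++i)
    for (size_t j = 0; j < cols; ++j) {
      const size_t k = (axis == 0) ? j : i;   // index along the kept axis
      m[k] = std::max(m[k], x[i][j]);
    }
  // Pass 2: sum of shifted exponentials, then add the max back in log space.
  for (size_t i = 0; i < rows; ++i)
    for (size_t j = 0; j < cols; ++j) {
      const size_t k = (axis == 0) ? j : i;
      z[k] += std::exp(x[i][j] - m[k]);
    }
  for (size_t k = 0; k < out_size; ++k) z[k] = m[k] + std::log(z[k]);
  return z;
}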
4 changes: 2 additions & 2 deletions dynet/tensor.h
@@ -687,7 +687,7 @@ struct TensorTools {
* \param m A tensor of scratch memory to hold the maximum values of each column
* \param z The output tensor
*/
static void logsumexp(const Tensor& x, Tensor &m, Tensor &z);
static void logsumexp(const Tensor& x, Tensor &m, Tensor &z, unsigned d = 0);

/**
* \brief Calculate the index of the maximum value
@@ -725,7 +725,7 @@ template<class MyDevice>
template<class MyDevice>
static IndexTensor categorical_sample_log_prob_dev(const MyDevice & dev, const Tensor& v, unsigned dim = 0, unsigned num = 1);
template <class MyDevice>
static void logsumexp_dev(const MyDevice & dev, const Tensor& x, Tensor &m, Tensor &z);
static void logsumexp_dev(const MyDevice & dev, const Tensor& x, Tensor &m, Tensor &z, unsigned d = 0);

};

1 change: 1 addition & 0 deletions python/_dynet.pxd
@@ -403,6 +403,7 @@ cdef extern from "dynet/expr.h" namespace "dynet":

CExpression c_max_dim "dynet::max_dim" (CExpression& x, unsigned d) except + #
CExpression c_min_dim "dynet::min_dim" (CExpression& x, unsigned d) except + #
CExpression c_logsumexp_dim "dynet::logsumexp_dim" (CExpression& x, unsigned d) except +

CExpression c_layer_norm "dynet::layer_norm" (CExpression& x, CExpression& g, CExpression& b) except + #
CExpression c_weight_norm "dynet::weight_norm" (CExpression& w, CExpression& g) except + #
15 changes: 15 additions & 0 deletions python/_dynet.pyx
@@ -3865,6 +3865,21 @@ cpdef Expression logsumexp(list xs):
#print(cvec.size(), file=sys.stderr)
return Expression.from_cexpr(x.cg_version, c_logsumexp(cvec))

cpdef Expression logsumexp_dim(Expression x, unsigned d=0):
"""Logsumexp along an arbitrary dimension
The "logsumexp" function that calculates :math:`\ln(\sum_i e^{xs_i})`, used in adding probabilities in the log domain.
This is performed along a certain dimension.
Args:
x (dynet.Expression): Input expression
d (unsigned): Dimension along which to reduce
Returns:
dynet.Expression: An expression with one less dimension representing the result
"""
return Expression.from_cexpr(x.cg_version, c_logsumexp_dim(x.c(), d))

cpdef Expression average(list xs):
"""Average
11 changes: 11 additions & 0 deletions tests/test-nodes.cc
@@ -335,6 +335,17 @@ BOOST_AUTO_TEST_CASE( logsumexp_inequal_batch_gradient ) {
BOOST_CHECK(check_grad(mod, z, 0));
}

// Expression logsumexp_dim(x, d);
BOOST_AUTO_TEST_CASE( logsumexp_dim_gradient ) {
dynet::ComputationGraph cg;
Expression x = parameter(cg, param_square1);
vector<Expression> exps;
for (int d = 1; d >= 0; d--)
exps.push_back(logsumexp_dim(x, d));
Expression z = sum_elems(sum(exps));
BOOST_CHECK(check_grad(mod, z, 1));
}

// Expression operator+(const Expression& x, real y);
BOOST_AUTO_TEST_CASE( addscalar_gradient ) {
dynet::ComputationGraph cg;
