
Commit c41421b
Match LayerNorm and InstanceNorm layers to PyTorch (LLNL#2024)
* Modified LayerNorm and InstanceNorm to match PyTorch (variance ddof changed from 1 to 0)

* Made the same changes to the GPU code

* Updated Bamboo tests to match the new LayerNorm and InstanceNorm expected values

* Fixed the gradient calculation

* Fixed the CPU InstanceNorm gradient calculation
mrwyattii committed Jan 11, 2022
1 parent 086f443 commit c41421b
Showing 6 changed files with 16 additions and 22 deletions.
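
Every change in this commit is the same fix: drop Bessel's correction, switching the variance estimator from unbiased (ddof=1, divide by n-1) to biased (ddof=0, divide by n), which is what PyTorch's LayerNorm and InstanceNorm compute. A minimal sketch of the distinction, assuming NumPy and PyTorch are available:

    import numpy as np
    import torch

    x = np.random.randn(8)
    n = len(x)

    var_unbiased = np.var(x, ddof=1)  # old LBANN behavior: divide by n - 1
    var_biased = np.var(x, ddof=0)    # new behavior: divide by n
    assert np.isclose(var_unbiased, var_biased * n / (n - 1))

    # PyTorch's normalization layers use the biased estimator internally;
    # torch.var_mean(..., unbiased=False) exposes the same statistic.
    var_t, _ = torch.var_mean(torch.from_numpy(x), unbiased=False)
    assert np.isclose(var_t.item(), var_biased)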
bamboo/unit_tests/test_unit_layer_instance_norm.py (2 changes: 1 addition & 1 deletion)
@@ -41,7 +41,7 @@ def numpy_instance_norm(x, epsilon=1e-5):
         x = x.astype(np.float64)
     axes = tuple(range(1,x.ndim))
     mean = np.mean(x, axis=axes, keepdims=True)
-    var = np.var(x, ddof=1, axis=axes, keepdims=True)
+    var = np.var(x, ddof=0, axis=axes, keepdims=True)
     return (x - mean) / np.sqrt(var + epsilon)

 # ==============================================
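
As a quick sanity check (not part of the test), the updated helper now agrees with PyTorch when x holds a single sample laid out as (channels, spatial), so a batch dimension can be prepended for torch.nn.functional.instance_norm; a sketch under that assumption:

    import numpy as np
    import torch
    import torch.nn.functional as F

    x = np.random.randn(4, 16)           # one sample: (channels, spatial)
    expected = numpy_instance_norm(x)    # helper from the test above
    # instance_norm wants (N, C, ...), so add a batch dimension of 1.
    actual = F.instance_norm(torch.from_numpy(x)[None], eps=1e-5)[0].numpy()
    assert np.allclose(expected, actual)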
bamboo/unit_tests/test_unit_layer_layer_norm.py (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@ def numpy_layer_norm(x, epsilon=1e-5):
     if x.dtype is not np.float64:
         x = x.astype(np.float64)
     mean = np.mean(x)
-    var = np.var(x, ddof=1)
+    var = np.var(x, ddof=0)
     return (x - mean) / np.sqrt(var + epsilon)

 # ==============================================
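
The same cross-check works here: with normalized_shape set to the full tensor shape, torch.nn.functional.layer_norm normalizes over every element, matching the helper's global mean and variance; a sketch:

    import numpy as np
    import torch
    import torch.nn.functional as F

    x = np.random.randn(16)
    expected = numpy_layer_norm(x)       # helper from the test above
    actual = F.layer_norm(torch.from_numpy(x), x.shape, eps=1e-5).numpy()
    assert np.allclose(expected, actual)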
src/layers/regularizers/instance_norm.cpp (10 changes: 4 additions & 6 deletions)
@@ -93,15 +93,14 @@ void fp_impl(lbann_comm& comm,
   //   var = ( sum(x_i^2)/n - mean^2 ) * n/(n-1)
   //   y_i = (x_i - mean) / sqrt(var + epsilon)
   const TensorDataType mean_scale = 1. / channel_size;
-  const TensorDataType var_correction = double(channel_size) / (channel_size - 1);
   LBANN_OMP_PARALLEL_FOR_COLLAPSE2
   for (El::Int k = 0; k < local_mini_batch_size; ++k) {
     for (El::Int j = 0; j < num_channels; ++j) {
       const auto& sum = local_sums(j,k);
       const auto& sqsum = local_sqsums(j,k);
       const auto mean = sum * mean_scale;
       const auto sqmean = sqsum * mean_scale;
-      auto var = (sqmean - mean * mean) * var_correction;
+      auto var = (sqmean - mean * mean);
       var = std::max(var, TensorDataType{0.});
       const TensorDataType inv_stdev
         = TensorDataType{1.} / std::sqrt(var + epsilon);
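
With the correction gone, var is simply E[x^2] - E[x]^2, computed in one pass from the accumulated sums. The std::max clamp (and gpu_lib::max in the CUDA kernels below) guards against the small negative values this formula can produce through catastrophic cancellation; a sketch of the failure mode:

    import numpy as np

    rng = np.random.default_rng(0)
    # Large mean, tiny variance: the worst case for the one-pass formula.
    x = (1.0e4 + 1.0e-3 * rng.standard_normal(1000)).astype(np.float32)
    mean = np.mean(x)
    sqmean = np.mean(x * x)
    var = sqmean - mean * mean       # cancellation: may be negative or wildly off
    var = max(var, np.float32(0.0))  # mirrors the std::max clamp above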
@@ -184,15 +183,14 @@ void bp_impl(lbann_comm& comm,
     El::IR(num_channels, 2*num_channels),
     El::ALL);
   const TensorDataType mean_scale = 1. / channel_size;
-  const TensorDataType var_correction = double(channel_size) / (channel_size - 1);
   LBANN_OMP_PARALLEL_FOR_COLLAPSE2
   for (El::Int k = 0; k < local_mini_batch_size; ++k) {
     for (El::Int j = 0; j < num_channels; ++j) {
       const auto& sum = local_sums(j,k);
       const auto& sqsum = local_sqsums(j,k);
       const auto mean = sum * mean_scale;
       const auto sqmean = sqsum * mean_scale;
-      auto var = (sqmean - mean * mean) * var_correction;
+      auto var = (sqmean - mean * mean);
       const TensorDataType inv_stdev
         = TensorDataType{1.} / std::sqrt(var + epsilon);
       auto& dmean = local_means_grad(j,k);
@@ -219,7 +217,7 @@
       const auto& sqsum = local_sqsums(j,k);
       const auto mean = sum * mean_scale;
       const auto sqmean = sqsum * mean_scale;
-      auto var = (sqmean - mean * mean) * var_correction;
+      auto var = (sqmean - mean * mean);
       const TensorDataType inv_stdev
         = TensorDataType{1.} / std::sqrt(var + epsilon);
       const auto& dmean = local_means_grad(j,k);
@@ -230,7 +228,7 @@
           auto& dx = local_input_grad(i+j*channel_size,k);
           dx = (dy * inv_stdev
                 + dmean / channel_size
-                + dvar * (x - mean) * 2 / (channel_size - 1));
+                + dvar * (x - mean) * 2 / channel_size);
         }
       }
     }
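
The "fixed gradient calculation" items in the commit message follow from re-deriving backprop with the biased variance: since var = sum((x_i - mean)^2) / n, the partial dvar/dx_i is 2 (x_i - mean) / n, so the n-1 denominator in the dvar term becomes n. The formula can be sanity-checked against PyTorch autograd; a sketch for a single channel of n values:

    import torch

    n, eps = 8, 1e-5
    x = torch.randn(n, dtype=torch.float64, requires_grad=True)
    mean = x.mean()
    var = ((x - mean) ** 2).mean()        # biased variance (divide by n)
    y = (x - mean) / torch.sqrt(var + eps)
    dy = torch.randn(n, dtype=torch.float64)
    y.backward(dy)

    # Hand-coded gradient, mirroring the structure of the updated bp_impl.
    inv_stdev = 1.0 / torch.sqrt(var + eps)
    dvar = (dy * (x - mean)).sum() * (-0.5) * inv_stdev ** 3
    dmean = -dy.sum() * inv_stdev
    dx = dy * inv_stdev + dmean / n + dvar * (x - mean) * 2 / n
    assert torch.allclose(x.grad, dx.detach())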
src/layers/regularizers/instance_norm.cu (11 changes: 4 additions & 7 deletions)
@@ -146,14 +146,13 @@ __global__ void fp_output_kernel(
   const size_t nthreadsz = blockDim.z * gridDim.z;

   const TensorDataType mean_scale = 1. / channel_size;
-  const TensorDataType var_correction = double(channel_size) / (channel_size - 1);
   for (size_t k = gidz; k < mini_batch_size; k += nthreadsz) {
     for (size_t j = gidy; j < num_channels; j += nthreadsy) {
       const auto& sum = sums[j+k*sums_ldim];
       const auto& sqsum = sqsums[j+k*sqsums_ldim];
       const auto& mean = sum * mean_scale;
       const auto& sqmean = sqsum * mean_scale;
-      auto var = (sqmean - mean*mean) * var_correction;
+      auto var = (sqmean - mean*mean);
       var = gpu_lib::max(var, TensorDataType{0.});
       const auto& inv_stdev = gpu_lib::rsqrt(var + epsilon);
       for (size_t i = gidx; i < channel_size; i += nthreadsx) {
@@ -310,7 +309,6 @@ __global__ void bp_statistics_grad_kernel(
   const size_t nthreadsz = blockDim.z * gridDim.z;

   const TensorDataType mean_scale = 1. / channel_size;
-  const TensorDataType var_correction = double(channel_size) / (channel_size - 1);
   for (size_t k = gidz; k < mini_batch_size; k += nthreadsz) {
     for (size_t j = gidy; j < num_channels; j += nthreadsy) {

@@ -319,7 +317,7 @@
       const auto& sqsum = sqsums[j+k*sqsums_ldim];
       const auto& mean = sum * mean_scale;
       const auto& sqmean = sqsum * mean_scale;
-      auto var = (sqmean - mean*mean) * var_correction;
+      auto var = (sqmean - mean*mean);
       var = gpu_lib::max(var, TensorDataType{0.});
       const auto& inv_stdev = gpu_lib::rsqrt(var + epsilon);

@@ -388,14 +386,13 @@ __global__ void bp_input_grad_kernel(
   const size_t nthreadsz = blockDim.z * gridDim.z;

   const TensorDataType mean_scale = 1. / channel_size;
-  const TensorDataType var_correction = double(channel_size) / (channel_size - 1);
   for (size_t k = gidz; k < mini_batch_size; k += nthreadsz) {
     for (size_t j = gidy; j < num_channels; j += nthreadsy) {
       const auto& sum = sums[j+k*sums_ldim];
       const auto& sqsum = sqsums[j+k*sqsums_ldim];
       const auto& mean = sum * mean_scale;
       const auto& sqmean = sqsum * mean_scale;
-      auto var = (sqmean - mean*mean) * var_correction;
+      auto var = (sqmean - mean*mean);
       var = gpu_lib::max(var, TensorDataType{0.});
       const auto& inv_stdev = gpu_lib::rsqrt(var + epsilon);
       const auto& dmean = means_grad[j+k*means_grad_ldim];
@@ -406,7 +403,7 @@
         auto& dx = input_grad[i + j*channel_size + k*input_grad_ldim];
         dx = (dy * inv_stdev
               + dmean * mean_scale
-              + dvar * (x - mean) * 2 * mean_scale * var_correction);
+              + dvar * (x - mean) * 2 * mean_scale);
       }
     }
   }
src/layers/regularizers/layer_norm.cpp (7 changes: 3 additions & 4 deletions)
@@ -69,7 +69,7 @@ void fp_impl(lbann_comm& comm,

   // Compute statistics from sums
   //   mean = sum(x_i) / n
-  //   var = ( sum(x_i^2)/n - mean^2 ) * n/(n-1)
+  //   var = ( sum(x_i^2)/n - mean^2 )
   if (sample_size <= 1) {
     // local_means already has correct values
     El::Fill(local_vars, El::TypeTraits<TensorDataType>::One());
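
The sample_size <= 1 guard predates this commit but still holds under the new definition: for a single-element sample the numerator x - mean is exactly zero, so the output is zero whether var is pinned to 1 (the guard's convention) or takes its ddof=0 value of 0. A one-line illustration:

    x = 3.7
    mean = x                    # the mean of one element is the element
    for var in (1.0, 0.0):      # guard convention vs. ddof=0 value
        assert (x - mean) / (var + 1e-5) ** 0.5 == 0.0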
@@ -82,8 +82,7 @@
       auto sample_size_dt = El::To<TensorDataType>(sample_size);
       const auto& mean = sum / sample_size_dt;
       const auto& sqmean = sqsum / sample_size_dt;
-      const auto& var = (sqmean - mean*mean) * sample_size_dt
-        / (sample_size_dt-El::TypeTraits<TensorDataType>::One());
+      const auto& var = (sqmean - mean*mean);
       local_means(0,i) = mean;
       local_vars(0,i) = std::max(var, El::TypeTraits<TensorDataType>::Zero());
     }
@@ -179,7 +178,7 @@
         auto& dx = local_input_grad(j,i);
         dx = (dy * inv_stdev
               + dmean / sample_size
-              + dvar * (x - mean) * 2 / (sample_size - 1));
+              + dvar * (x - mean) * 2 / sample_size);
       }
     }

src/layers/regularizers/layer_norm.cu (6 changes: 3 additions & 3 deletions)
@@ -99,7 +99,7 @@ __global__ void fp_sums_kernel(
  *
  *   mean = sum(x_i) / n
  *
- *   var = ( sum(x_i^2)/n - mean^2 ) * n/(n-1)
+ *   var = ( sum(x_i^2)/n - mean^2 )
  *
  *  On input, means contains per-sample sums and vars contains
  *  per-sample sums of squares.
@@ -125,7 +125,7 @@
     const TensorDataType sample_size_dt = TensorDataType(sample_size);
     const auto& mean = sum / sample_size_dt;
     const auto& sqmean = sqsum / sample_size_dt;
-    const auto& var = (sqmean - mean*mean) * sample_size_dt / TensorDataType(sample_size-1);
+    const auto& var = (sqmean - mean*mean);
     means[i*means_stride] = mean;
     vars[i*vars_stride] = gpu_lib::max(var, TensorDataType(0.0));
   }
@@ -371,7 +371,7 @@ __global__ void bp_input_grad_kernel(
       auto& dx = input_grad[i*input_grad_ldim + j];
       dx = (dy * inv_stdev
             + dmean / TensorDataType(sample_size)
-            + dvar * (x - mean) * TensorDataType(2) / TensorDataType(sample_size - 1));
+            + dvar * (x - mean) * TensorDataType(2) / TensorDataType(sample_size));
     }
   }

