Transformer ML-Perf SPR WW04 (#359)
* Changed the attention part so that it can utilize the existing fusion of batchmatmul+mul+addv2, and also use static variables to reduce redundant computation (see the sketch after this list)

* Fixed a minor bug in a static variable

* Changed the model so that the reshape can be moved out of the dense layers, allowing the ops in the dense layers to be fused (see the sketch after the feed-forward diff below)

* Changed the attention depth to a static variable
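
A minimal standalone sketch of the pattern the first and last bullets describe (the class name, method, and shapes below are illustrative assumptions, not the repository's code): the depth and its reciprocal square root are computed once as class-level attributes, and the scale is applied to the logits rather than to q, so the graph contains a BatchMatMul -> Mul -> AddV2 chain that the existing fusion can collapse into one op.

import tensorflow as tf

class ScaledAttention:
  # Class-level ("static") values: computed once in __init__, reused on every call.
  depth = 1
  rsqrtQ = 1.0

  def __init__(self, hidden_size, num_heads):
    ScaledAttention.depth = hidden_size // num_heads
    ScaledAttention.rsqrtQ = ScaledAttention.depth ** -0.5

  def attend(self, q, k, v, bias):
    # Keeping the scale next to the matmul yields a BatchMatMul -> Mul -> AddV2
    # chain that the existing graph fusion can rewrite into a single fused op.
    logits = tf.matmul(q, k, transpose_b=True)
    logits *= ScaledAttention.rsqrtQ
    logits += bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    return tf.matmul(weights, v)

Using class attributes mirrors the diff's Attention.depth / Attention.rsqrtQ; it assumes every Attention instance shares the same hidden_size and num_heads, which is what the static-variable change relies on.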
cuixiaom committed Jan 21, 2022
1 parent a1987e4 commit 65f6f0d
Showing 2 changed files with 17 additions and 8 deletions.
@@ -26,6 +26,11 @@
class Attention(tf.compat.v1.layers.Layer):
"""Multi-headed attention layer."""

# Set these as class-level (static) variables so they are computed only once;
# 1.0 and 1 are placeholder initial values, overwritten in __init__.
rsqrtQ = 1.0
depth = 1

def __init__(self, hidden_size, num_heads, attention_dropout, train):
if hidden_size % num_heads != 0:
raise ValueError("Hidden size must be evenly divisible by the number of "
@@ -53,6 +58,10 @@ def __init__(self, hidden_size, num_heads, attention_dropout, train):
"num_heads": num_heads
})

# Precompute the depth and the scale that prevents the dot product between q and k from growing too large.
Attention.depth = (self.hidden_size // self.num_heads)
Attention.rsqrtQ = Attention.depth ** -0.5

def split_heads(self, x):
"""Split x into different heads, and transpose the resulting value.
@@ -69,11 +78,8 @@ def split_heads(self, x):
batch_size = tf.shape(input=x)[0]
length = tf.shape(input=x)[1]

# Calculate depth of last dimension after it has been split.
depth = (self.hidden_size // self.num_heads)

# Split the last dimension
x = tf.reshape(x, [batch_size, length, self.num_heads, depth])
x = tf.reshape(x, [batch_size, length, self.num_heads, Attention.depth])

# Transpose the result
return tf.transpose(a=x, perm=[0, 2, 1, 3])
@@ -141,12 +147,11 @@ def call(self, x, y, bias, cache=None, encdec_cache=None):
v = self.split_heads(v)

# Scale q to prevent the dot product between q and k from growing too large.
depth = (self.hidden_size // self.num_heads)
q *= depth ** -0.5
# Calculate dot product attention
with tf.compat.v1.tpu.bfloat16_scope():
logits = tf.matmul(q, k, transpose_b=True)
bias = tf.cast(bias, tf.bfloat16)
logits = tf.matmul(q, k, transpose_b=True)
logits *= Attention.rsqrtQ
logits += bias
weights = tf.nn.softmax(logits, name="attention_weights")
if self.train:
@@ -161,7 +166,6 @@ def call(self, x, y, bias, cache=None, encdec_cache=None):

# Run the combined outputs through another linear projection layer.
attention_output = self.output_dense_layer(attention_output)
# attention_output = tf.cast(attention_output, tf.float32)
return attention_output


@@ -59,12 +59,17 @@ def call(self, x, padding=None):
length = tf.shape(input=x)[1]

with tf.compat.v1.tpu.bfloat16_scope():
# Reshape to a 2D tensor
x = tf.reshape(x, [-1, self.hidden_size])
output = self.filter_dense_layer(x)
if self.train:
mlperf_log.transformer_print(
key=mlperf_log.MODEL_HP_RELU_DROPOUT, value=self.relu_dropout)
output = tf.nn.dropout(output, 1 - (1.0 - self.relu_dropout))
output = self.output_dense_layer(output)

# Reshape back to a 3D tensor
output = tf.reshape(output, [batch_size, length, self.hidden_size])

return output
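
For reference, a minimal standalone sketch of the reshape pattern this feed-forward change implements (the Keras layers and sizes below are assumptions standing in for the model's own dense layers): the 3D activations are flattened to 2D before the dense layers, so each dense layer runs as a plain rank-2 matmul+bias(+relu) that is easier to fuse, and the result is reshaped back to 3D afterwards.

import tensorflow as tf

hidden_size, filter_size = 1024, 4096  # illustrative sizes only
filter_dense_layer = tf.keras.layers.Dense(filter_size, activation="relu")
output_dense_layer = tf.keras.layers.Dense(hidden_size)

def feed_forward(x):
  # x: [batch_size, length, hidden_size]
  batch_size = tf.shape(x)[0]
  length = tf.shape(x)[1]
  # Collapse batch and length so the dense layers see a 2D tensor; the
  # matmul/bias/relu inside each layer can then be fused.
  x = tf.reshape(x, [-1, hidden_size])
  output = output_dense_layer(filter_dense_layer(x))
  # Restore the original 3D shape after both dense layers.
  return tf.reshape(output, [batch_size, length, hidden_size])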
