Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bf16 Support to Adam #134

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions csrc/cuda/adam_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ __global__ void adam_fp32_accum_bf16(
int32_t col = blockIdx.x * blockDim.x + threadIdx.x;

if (col < n) {
float local_g = __nv_bfloat162float(g[col]) / scale; // real_g
float local_m = beta1 * __nv_bfloat162float(m[col]) + (1 - beta1) * local_g; // real_m
float local_g = __bfloat162float(g[col]) / scale; // real_g
float local_m = beta1 * __bfloat162float(m[col]) + (1 - beta1) * local_g; // real_m
float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g; // real_v
float local_p = param[col];
local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2 / scale) + eps) - lr * weight_decay * local_p;
Expand Down Expand Up @@ -122,4 +122,4 @@ void adam_bf16_launcher(
dim3 block_size = dim3(threads, 1, 1);
dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
adam_fp32_accum_bf16<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
}
}
3 changes: 2 additions & 1 deletion other_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ cpm_kernels>=1.0.11
jieba
tensorboard
setuptools_rust
transformers
transformers
pybind11
10 changes: 5 additions & 5 deletions tests/test_optim_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def main():

model1 = model1.cuda().to(dtype=torch.bfloat16)
model2 = model2.cuda().to(dtype=torch.bfloat16)
model3 = model3.cuda().to(dtype=torch.bfloat16)
model3 = model3.cuda()

opt1 = bmt.optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3)
opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3)
Expand All @@ -46,10 +46,10 @@ def main():
opt3.zero_grad()

for p1, p2, p3 in zip(model1.parameters(), model2.parameters(), model3.parameters()):
grad_bf16 = torch.randn_like(p1).to(dtype=torch.bfloat16)
p1.grad = grad_bf16
p2.grad = grad_bf16
p3.grad = grad_bf16
grad = torch.randn_like(p1)
p1.grad = grad.to(dtype=torch.bfloat16)
p2.grad = grad.to(dtype=torch.bfloat16)
p3.grad = grad.float()

opt1.step()
opt2.step()
Expand Down