
Add Bf16 Support #136

Merged (31 commits, Aug 29, 2023)
Changes from 1 commit

Commits (31):
3210632  modified: bmtrain/optim/adam.py (Aug 2, 2023)
4e4d32a  modified: add bf16 in adam.py (Aug 2, 2023)
ef038ae  modified: bmtrain/optim/_function.py (Aug 2, 2023)
a1e0e42  modified: bmtrain/optim/_function.py (JerryYin777, Aug 2, 2023)
efe3cdc  modified: add bf16.h to csrc/cuda/cross_entropy.cu (JerryYin777, Aug 3, 2023)
6125f39  modified: bmtrain/optim/_function.py (JerryYin777, Aug 3, 2023)
e2fdcc5  modified: add adam_fp32_accum_bf16 function (JerryYin777, Aug 3, 2023)
00eee39  modified: add adam_fp32_accum_bf16 function (JerryYin777, Aug 3, 2023)
2eb6d72  modified: add adam_fp32_accum_bf16 function (JerryYin777, Aug 3, 2023)
6dfd739  modified: bmtrain/loss/_function.py (JerryYin777, Aug 7, 2023)
8a6f686  modified: add bf16 to is_nan_inf() (JerryYin777, Aug 7, 2023)
0b22ed1  FIX: csrc/bind.cpp (JerryYin777, Aug 7, 2023)
f8885cf  modified: tests/test_has_inf_nan.py (JerryYin777, Aug 7, 2023)
9a9d526  modified: bmtrain/optim/_function.py (JerryYin777, Aug 7, 2023)
77c3585  add pybind11 in Update other_requirements.txt (JerryYin777, Aug 7, 2023)
5cc3611  Update adam_cuda.cu (JerryYin777, Aug 7, 2023)
870c613  Merge branch 'OpenBMB:main' into main (JerryYin777, Aug 8, 2023)
2b414b8  Update test_optim_bf16.py (JerryYin777, Aug 8, 2023)
40441f8  Merge branch 'OpenBMB:main' into main (JerryYin777, Aug 8, 2023)
c5f7e49  FIX (JerryYin777, Aug 9, 2023)
148ed85  Merge branch 'main' of https://github.com/JerryYin777/BMTrain (JerryYin777, Aug 9, 2023)
55839be  refactor has_inf_nan_bf16 (Achazwl, Aug 10, 2023)
4d44dce  refactor has_inf_nan_bf16 (Achazwl, Aug 10, 2023)
d008d7f  refactor adam_offload (Achazwl, Aug 10, 2023)
145d90f  refactor adam (Achazwl, Aug 10, 2023)
da14e7e  fix adam_cuda (Achazwl, Aug 11, 2023)
21d5218  test nccl (Achazwl, Aug 11, 2023)
d29b22a  fix optim state test (Achazwl, Aug 11, 2023)
4aeb638  fix cuda version if; refactor cross_entropy (Achazwl, Aug 11, 2023)
53d8f8b  fix bf16 not support info (Achazwl, Aug 11, 2023)
eef5542  Merge branch 'dev' into bf16 (Achazwl, Aug 29, 2023)
commit 321063284ff9462b77f8f56a81383574d4db46e7: modified: bmtrain/optim/adam.py
ycr0776 committed Aug 2, 2023
bmtrain/optim/adam.py: 15 changes (14 additions, 1 deletion)

@@ -88,7 +88,7 @@ def step(self, closure=None, scale=1):
                 grad = p.grad

                 if p.dtype == torch.half:
-                    F.adam(
+                    C.f_adam(
                         state["_param_fp32"],   # fp32
                         p,                      # fp16
                         grad,                   # fp16
@@ -101,6 +101,19 @@ def step(self, closure=None, scale=1):
                         group['weight_decay'],
                         state['step']
                     )
+                elif p.dtype == torch.bfloat16:
+                    C.f_adam_bf16(
+                        state["_param_fp32"],   # fp32
+                        p,                      # bf16
+                        grad,                   # bf16
+                        state['exp_avg'],       # fp32: m
+                        group['betas'][0], group['betas'][1],
+                        group['eps'],
+                        0.0 if state["step"] <= self._hold_steps else group['lr'],
+                        scale,
+                        group['weight_decay'],
+                        state['step']
+                    )
                 else:
                     other_kwargs = {}
                     if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters:
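
For context, here is a minimal pure-PyTorch sketch of the update the new C.f_adam_bf16 call above presumably asks the fused CUDA kernel to perform: an Adam step that keeps fp32 master weights and fp32 moment buffers, touching the bf16 parameter and gradient tensors only at the boundaries. This is an illustration under stated assumptions, not the kernel itself; the real signature differs (the call site at this commit passes only exp_avg, while the sketch also carries the standard second-moment buffer), and the decoupled weight-decay form is an assumption.

import torch

def adam_step_bf16_sketch(param_fp32, p_bf16, grad_bf16, exp_avg, exp_avg_sq,
                          beta1, beta2, eps, lr, scale, weight_decay, step):
    # Hypothetical un-fused equivalent of a bf16 Adam step with fp32 accumulation.
    g = grad_bf16.float() / scale                            # un-scale the gradient in fp32
    exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)             # m_t = b1*m + (1-b1)*g
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)   # v_t = b2*v + (1-b2)*g^2
    m_hat = exp_avg / (1 - beta1 ** step)                    # bias-corrected first moment
    v_hat = exp_avg_sq / (1 - beta2 ** step)                 # bias-corrected second moment
    param_fp32.mul_(1 - lr * weight_decay)                   # decoupled weight decay (assumed)
    param_fp32.addcdiv_(m_hat, v_hat.sqrt_().add_(eps), value=-lr)
    p_bf16.copy_(param_fp32)                                 # round master weights back to bf16

Note that the call site passes 0.0 instead of group['lr'] while state["step"] <= self._hold_steps; in this sketch a zero learning rate still updates the moment buffers but leaves the weights unchanged, which matches the apparent intent of the hold period.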