Optimizer loading of gathered state and the record-delta feature are now supported #184

Merged · 3 commits · Feb 23, 2024
Changes from 1 commit
Optimizer loading of gathered state and the record-delta feature are now supported
MayDomine committed Feb 22, 2024
commit d933ee90ead05449542ccd19a01d2155ef451c76
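For context, a minimal usage sketch of the two features this PR adds (this example is not part of the diff): the record_delta flag with its get_avg_delta / get_var_delta readouts on AdamOffloadOptimizer, and gathered optimizer state via state_dict(gather=True), exposed through OptimManager.state_dict(gather_opt=True). The model, training loop, and the manager-level load_state_dict call are assumed placeholders around the existing BMTrain API.

import torch
import bmtrain as bmt

bmt.init_distributed()
model = MyModel()  # hypothetical bmt-wrapped model, not part of this PR

optimizer = bmt.optim.AdamOffloadOptimizer(
    model.parameters(), lr=1e-3,
    record_delta=True,  # new flag added by this PR
)
optim_manager = bmt.optim.OptimManager(loss_scale=1024)
optim_manager.add_optimizer(optimizer)

# ... forward / backward / optim_manager.step() ...

# Record-delta feature: statistics of the last parameter update.
print("avg delta:", optimizer.get_avg_delta(), "var delta:", optimizer.get_var_delta())

# Gathered state: every rank assembles the full, unpartitioned optimizer state
# (marked with is_whole=True), so it can be reloaded under a different partitioning.
state = optim_manager.state_dict(gather_opt=True)
if bmt.rank() == 0:
    torch.save(state, "optimizer.pt")

# On load, load_state_dict slices the whole state back to each rank's partition
# using the new _start_partition / _end_partition attributes set in block_layer.py.
optim_manager.load_state_dict(torch.load("optimizer.pt"))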
4 changes: 4 additions & 0 deletions bmtrain/block_layer.py
@@ -210,11 +210,15 @@ def init_param_storage(self):
                param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))
                self._param_info[-1]["begin"] = to_offset_st
                self._param_info[-1]["end"] = (to_offset_end - to_offset_st,)
                setattr(param, "_start_partition", offset_st)
                setattr(param, "_end_partition", offset_end)
                param.data[:] = \
                    torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
                del contiguous_param
            else:
                param.data = torch.tensor([], dtype=param.dtype, device=param.device)
                setattr(param, "_start_partition", None)
                setattr(param, "_end_partition", 0)
            # clear parameter data, but keep the dtype and device
            setattr(param, "_in_block", True)

29 changes: 29 additions & 0 deletions bmtrain/optim/_distributed.py
@@ -0,0 +1,29 @@
import torch
from ..distributed import all_reduce, all_gather

def state_dict_gather(state_dict):
    param_key = [p for param_group in state_dict['param_groups'] for p in param_group['params']]
    for k, v in state_dict['state'].items():
        if "step" in v:
            step = v['step']

    for k in param_key:
        if k not in state_dict['state']:
            state_dict['state'][k] = {
                'exp_avg' : torch.tensor([], device="cuda", dtype=torch.float32),
                'exp_avg_sq' : torch.tensor([], device="cuda", dtype=torch.float32),
                '_param_fp32' : torch.tensor([], device="cuda", dtype=torch.float32),
                'step' : step
            }
        v = state_dict['state'][k]
        for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
            if name in v:
                with torch.no_grad():
                    numel = torch.tensor(v[name].numel(), device="cuda", dtype=torch.long)
                    max_numel = all_reduce(numel, op="max")
                    v_p = torch.nn.functional.pad(v[name], (0, max_numel - numel), value=-1e15)
                    if max_numel > 0:
                        whole_state = all_gather(v_p.cuda()).flatten()
                        whole_state = whole_state[whole_state != -1e15]
                        v[name] = whole_state.contiguous().cpu()
    return state_dict
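Aside (not part of the diff): state_dict_gather handles unevenly sized shards by padding each rank's tensor to the globally largest length with a sentinel value, running a fixed-size all-gather, and then dropping the sentinel entries. Below is a standalone sketch of the same pattern using plain torch.distributed, assuming an initialized process group; the function name and sentinel are illustrative, and, as in the diff, it assumes no real element equals the sentinel.

import torch
import torch.distributed as dist

def gather_uneven(shard: torch.Tensor, pad_value: float = -1e15) -> torch.Tensor:
    """Gather 1-D shards of different lengths from all ranks onto every rank."""
    numel = shard.numel()
    max_numel = torch.tensor([numel], device=shard.device, dtype=torch.long)
    dist.all_reduce(max_numel, op=dist.ReduceOp.MAX)            # longest shard across ranks
    padded = torch.nn.functional.pad(shard, (0, int(max_numel.item()) - numel), value=pad_value)
    chunks = [torch.empty_like(padded) for _ in range(dist.get_world_size())]
    dist.all_gather(chunks, padded)                             # fixed-size collective
    whole = torch.cat(chunks)
    return whole[whole != pad_value]                            # strip sentinel padding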
15 changes: 14 additions & 1 deletion bmtrain/optim/_function.py
@@ -1,7 +1,18 @@
from .. import C
import torch
CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp32: torch.Tensor,

def bf16_from_fp32(param_fp32):
    param_bf16 = torch.empty_like(param_fp32, dtype=torch.bfloat16)
    C.to_bf16_from_fp32(param_fp32.numel(), param_fp32.data_ptr(), param_bf16.data_ptr())
    return param_bf16

def fp16_from_fp32(param_fp32):
    param_fp16 = torch.empty_like(param_fp32, dtype=torch.float16)
    C.to_fp16_from_fp32(param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr())
    return param_fp16

def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, delta_info: torch.Tensor, g_fp16: torch.Tensor, m_fp32: torch.Tensor,
             v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float,
             weight_decay: float, step: int) -> None:
    assert param_fp32.is_contiguous(), "param_fp32 must be contiguous"
@@ -19,6 +30,7 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T
    assert g_fp16.device == torch.device("cpu"), "g_fp16 must be a cpu tensor"
    assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor"
    assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor"
    #TODO check avg_delta and var_delta
    assert param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements"
    assert param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements"
    assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements"
@@ -35,6 +47,7 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T
        param_fp32.numel(),
        param_fp32.data_ptr(),
        param_fp16.data_ptr(),
        delta_info.data_ptr() if delta_info is not None else 0,
        g_fp16.data_ptr(),
        m_fp32.data_ptr(),
        v_fp32.data_ptr(),
8 changes: 8 additions & 0 deletions bmtrain/optim/adam.py
@@ -131,6 +131,14 @@ def step(self, closure=None, scale=1):

        return loss

    def get_avg_delta():
        raise NotImplementedError("get delta info is not supported in Adam optimizer, try bmt.optim.AdamOffloadOptimizer")

    def get_var_delta():
        raise NotImplementedError("get delta info is not supported in Adam optimizer, try bmt.optim.AdamOffloadOptimizer")

    def load_state_dict(self, state_dict: dict) -> None:
        r"""Loads the optimizer state.

66 changes: 58 additions & 8 deletions bmtrain/optim/adam_offload.py
@@ -7,14 +7,15 @@
from copy import deepcopy
from itertools import chain
from collections import defaultdict
from ._distributed import state_dict_gather

class AdamOffloadOptimizer(torch.optim.Optimizer):
"""
Adam optimizer
"""
_bmtrain_optimizer = True

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, hold_steps=0):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, hold_steps=0, record_delta=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -25,12 +26,17 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

self.avg_delta = 0
self.var_delta = 0
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults)

self._hold_steps = hold_steps
self._events = {}
self.record_delta = record_delta
if self.record_delta:
for group in self.param_groups:
for p in group['params']:
setattr(p, "_delta_info", ( torch.tensor([0 for i in range(4)], dtype=torch.float32, device="cpu") ))

    @torch.no_grad()
    def step(self, closure=None, scale=1):
@@ -92,7 +98,9 @@ def step(self, closure=None, scale=1):
            else:
                state["_grad_fp16"].copy_(param.grad, non_blocking=True)
            torch.cuda.current_stream().record_event(event)

        sum_delta = 0
        sum_sq_delta = 0
        total_numel = 0
        for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params:
            # wait for transfer to host
            event.synchronize()
@@ -135,6 +143,7 @@ def step(self, closure=None, scale=1):
                F.adam_cpu(
                    state["_param_fp32"].view(-1),
                    state["_param_fp16"].view(-1),
                    param._delta_info if self.record_delta else None,
                    grad.view(-1),
                    state["exp_avg"].view(-1),
                    state["exp_avg_sq"].view(-1),
@@ -144,12 +153,25 @@ def step(self, closure=None, scale=1):
                    weight_decay,
                    state["step"]
                )
                total_numel += state["_param_fp16"].numel()
                if self.record_delta:
                    sum_delta += param._delta_info[2].item()
                    sum_sq_delta += param._delta_info[3].item()
                # transfer parameters back to device asynchronously
                param.copy_(state["_param_fp16"], non_blocking=True)
        if self.record_delta:
            self.avg_delta = sum_delta / total_numel
            self.var_delta = sum_sq_delta / total_numel - self.avg_delta ** 2

        return loss

    def get_avg_delta(self) -> None:
        return self.avg_delta if self.record_delta else 0

    def get_var_delta(self) -> None:
        return self.var_delta if self.record_delta else 0

    def load_state_dict(self, state_dict: dict) -> None:
        r"""Loads the optimizer state.

@@ -158,6 +180,9 @@ def load_state_dict(self, state_dict: dict) -> None:
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
@@ -177,13 +202,27 @@ def load_state_dict(self, state_dict: dict) -> None:
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}

        # _param_start_end = chain.from_iterable((g["params_start_end"] for g in saved_groups))
        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        is_whole = False if "is_whole" not in state_dict else state_dict['is_whole']
        pop_key = []
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                if is_whole and param._start_partition is not None:
                    for key in ['_param_fp32', 'exp_avg_sq', 'exp_avg']:
                        if key in v:
                            v[key] = v[key][param._start_partition:param._end_partition]
                elif is_whole and param._start_partition is None:
                    pop_key.append(param)

                if "_param_fp32" not in v:
                    with torch.no_grad():
                        v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu")
                        v["_param_fp32"].copy_(param)

                if "_param_fp32" not in v:
                    v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu")
@@ -204,16 +243,19 @@ def load_state_dict(self, state_dict: dict) -> None:
state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=param.dtype, pin_memory=True) # on host
else:
state[k] = v

for k in pop_key:
state.pop(k)
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})



def state_dict(self) -> dict:
def state_dict(self, gather=False) -> dict:
r"""Returns the state of the optimizer as a :class:`dict`.

It contains two entries:
@@ -223,6 +265,7 @@ def state_dict(self) -> dict:
        * param_groups - a list containing all parameter groups where each
            parameter group is a dict
        """

        # Save order indices instead of Tensors
        param_mappings = {}
        start_index = 0
@@ -247,11 +290,18 @@ def cut_states(state):
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v)
                        for k, v in self.state.items()}
        return {
        states = {
            'state': packed_state,
            'param_groups': param_groups,
        }
        if gather:
            states = state_dict_gather(states)
            states['is_whole'] = True
        else:
            states['is_whole'] = False

        return states

    #TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu
    def zero_grad(self, set_to_none: bool = False):
        super().zero_grad(set_to_none=set_to_none)
4 changes: 2 additions & 2 deletions bmtrain/optim/optim_manager.py
@@ -203,9 +203,9 @@ def _justify_scale(self, scale):
        self.loss_scale = scale
        self.steps_since_last_scale = 0

    def state_dict(self) -> dict:
    def state_dict(self, gather_opt=False) -> dict:
        return {
            "optimizers": [opt.state_dict() for opt in self.optimizers],
            "optimizers": [opt.state_dict(gather_opt) for opt in self.optimizers],
            "lr_schedulers": [lrs.state_dict() if lrs else None for lrs in self.lr_schedulers],
            "loss_scale": self.loss_scale,
            "loss_scale_enabled": self.loss_scale_enabled,
25 changes: 18 additions & 7 deletions csrc/bind.cpp
@@ -1,9 +1,11 @@
#include "include/bind.hpp"

PYBIND11_MODULE(C, m) {
m.def("is_bf16_supported",&is_bf16_supported,"whether bf16 supported");
m.def("has_nan_inf_fp16_launcher",&has_nan_inf_fp16_launcher,"has nan inf");
m.def("has_nan_inf_bf16_launcher",&has_nan_inf_bf16_launcher,"has nan inf bf16");
m.def("to_fp16_from_fp32", &fp16_from_fp32_value_launcher, "convert");
m.def("to_bf16_from_fp32", &bf16_from_fp32_value_launcher, "convert");
m.def("is_bf16_supported", &is_bf16_supported, "whether bf16 supported");
m.def("has_nan_inf_fp16_launcher", &has_nan_inf_fp16_launcher, "has nan inf");
m.def("has_nan_inf_bf16_launcher", &has_nan_inf_bf16_launcher, "has nan inf bf16");
m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu");
m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu");
m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu");
@@ -26,8 +28,17 @@ PYBIND11_MODULE(C, m) {
m.def("ncclReduceScatter", &pyNCCLReduceScatter, "nccl reduce scatter");
m.def("ncclGroupStart", &pyNCCLGroupStart, "nccl group start");
m.def("ncclGroupEnd", &pyNCCLGroupEnd, "nccl group end");
m.def("ncclSend",&pyNCCLSend,"nccl send");
m.def("ncclRecv",&pyNCCLRecv,"nccl recv");
m.def("ncclCommCount",&pyNCCLCommCount,"nccl comm count");
m.def("ncclCommUserRank",&pyNCCLCommUserRank,"nccl comm user rank");
m.def("ncclSend", &pyNCCLSend, "nccl send");
m.def("ncclRecv", &pyNCCLRecv, "nccl recv");
m.def("ncclCommCount", &pyNCCLCommCount, "nccl comm count");
m.def("ncclCommUserRank", &pyNCCLCommUserRank, "nccl comm user rank");

py::class_<CUDAEventScope>(m, "CUDAEventScope")
.def(py::init(&CUDAEventScope::create))
.def("recordStart", &CUDAEventScope::recordStart)
.def("recordEnd", &CUDAEventScope::recordEnd);

py::class_<PyWatchDog>(m, "WatchDog")
.def(py::init(&PyWatchDog::create))
.def("watch", &PyWatchDog::watch);
}