[DCP][BE] Move DCP._state_dict_utils out from DCP (pytorch#115523)
DCP._state_dict_utils is also used by FSDP, which can sometimes cause a circular import. Move it out of DCP to avoid the circular import.

Differential Revision: [D52022440](https://our.internmc.facebook.com/intern/diff/D52022440/)

Pull Request resolved: pytorch#115523
Approved by: https://github.com/wz337
fegin authored and pytorchmergebot committed Dec 13, 2023
1 parent 1500379 commit cc28f61
Showing 9 changed files with 21 additions and 19 deletions.
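
For context: the moved helpers (`_gather_state_dict`, `_all_gather_sharded_tensor`, `_offload_state_dict_to_cpu`) now live directly under `torch.distributed`, so FSDP code can import them without pulling in `torch.distributed.checkpoint`, which itself imports FSDP types. Below is a minimal, hypothetical sketch of why hoisting a shared helper breaks such a cycle; the package names are illustrative only, not the real PyTorch layout.

```python
# Hypothetical layout, for illustration only -- not the actual PyTorch tree.
#
# Before: the shared helper lives inside the checkpoint subpackage.
#   pkg/checkpoint/_helpers.py    defines gather_state_dict()
#   pkg/checkpoint/state_dict.py  imports pkg.fsdp                 (checkpoint -> fsdp)
#   pkg/fsdp/optim_utils.py       imports pkg.checkpoint._helpers  (fsdp -> checkpoint)
#   => importing either subpackage can trigger a circular import.
#
# After: the helper is hoisted to the shared parent package.
#   pkg/_helpers.py               defines gather_state_dict()
#   pkg/checkpoint/state_dict.py  imports pkg._helpers and pkg.fsdp
#   pkg/fsdp/optim_utils.py       imports pkg._helpers             (no cycle)

# pkg/_helpers.py -- toy stand-in for the shared helper module
from typing import Any, Dict


def gather_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Toy stand-in for _gather_state_dict(): return a shallow copy."""
    return dict(state_dict)
```

Both subpackages then depend only on the parent-level module, which depends on neither of them.
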
@@ -9,7 +9,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable import fully_shard
-from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
+from torch.distributed._state_dict_utils import _gather_state_dict
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
 from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
@@ -3,8 +3,8 @@
 import torch.distributed.checkpoint as dist_cp
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 
+from torch.distributed._state_dict_utils import _all_gather_sharded_tensor
 from torch.distributed._tensor import DTensor, init_device_mesh, Replicate
-from torch.distributed.checkpoint._state_dict_utils import _all_gather_sharded_tensor
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
6 changes: 3 additions & 3 deletions test/distributed/checkpoint/test_state_dict_utils.py
@@ -4,12 +4,12 @@
 import torch.distributed as dist
 import torch.distributed._functional_collectives as funcol
 
-from torch.distributed._tensor import DTensor
-from torch.distributed._tensor.placement_types import Shard
-from torch.distributed.checkpoint._state_dict_utils import (
+from torch.distributed._state_dict_utils import (
     _gather_state_dict,
     _offload_state_dict_to_cpu,
 )
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor.placement_types import Shard
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_fsdp_optim_state.py
@@ -10,11 +10,11 @@
 import torch.nn as nn
 from torch import distributed as dist
 from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._state_dict_utils import _gather_state_dict
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     _CHECKPOINT_WRAPPED_MODULE,
     apply_activation_checkpointing,
 )
-from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
8 changes: 4 additions & 4 deletions test/distributed/fsdp/test_fsdp_state_dict.py
@@ -16,15 +16,15 @@
     Shard,
     ShardedTensor,
 )
+from torch.distributed._state_dict_utils import (
+    _all_gather_sharded_tensor,
+    _gather_state_dict,
+)
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
     checkpoint_wrapper,
     CheckpointImpl,
 )
-from torch.distributed.checkpoint._state_dict_utils import (
-    _all_gather_sharded_tensor,
-    _gather_state_dict,
-)
 from torch.distributed.fsdp import (
     CPUOffload,
     FullStateDictConfig,
@@ -1,16 +1,18 @@
 import math
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
 
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-from torch.distributed import distributed_c10d
-from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torch.distributed._tensor import DTensor, Replicate
+
+if dist.is_available() or TYPE_CHECKING:
+    from torch.distributed import distributed_c10d
+    from torch.distributed._shard.sharded_tensor import ShardedTensor
+    from torch.distributed._tensor import DTensor, Replicate
 
 
 def _all_gather_sharded_tensor(
-    sharded_tensor: ShardedTensor,
+    sharded_tensor: "ShardedTensor",
     pg: Optional[dist.ProcessGroup] = None,
     device: Optional[torch.device] = None,
 ) -> torch.Tensor:
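
The hunk above also wraps the distributed-only imports in `if dist.is_available() or TYPE_CHECKING:` and quotes the `ShardedTensor` annotation. Here is a minimal sketch of that guard pattern with stand-in names (nothing below is PyTorch code): the guarded import runs only when the optional feature is present, `TYPE_CHECKING` keeps the names visible to static type checkers, and the string annotation avoids evaluating a possibly missing name at runtime.

```python
# Sketch of the import-guard pattern, with hypothetical stand-ins for the real names.
from typing import TYPE_CHECKING


def feature_available() -> bool:
    # Stand-in for dist.is_available(); pretend the optional backend is absent.
    return False


if feature_available() or TYPE_CHECKING:
    # Stand-in for the guarded torch.distributed imports.
    from decimal import Decimal as OptionalTensorType


def describe(value: "OptionalTensorType") -> str:
    # The quoted annotation is resolved only by type checkers, so this module
    # imports and runs even when the guarded import above was skipped.
    return type(value).__name__


if __name__ == "__main__":
    # At runtime the guarded import was skipped, yet the module loaded fine
    # because the annotation above is a string and is never evaluated here.
    print("OptionalTensorType" in globals())  # False on this minimal "build"
```
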
4 changes: 2 additions & 2 deletions torch/distributed/checkpoint/state_dict.py
@@ -21,11 +21,11 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torch.distributed._tensor import DTensor
-from torch.distributed.checkpoint._state_dict_utils import (
+from torch.distributed._state_dict_utils import (
     _gather_state_dict,
     _offload_state_dict_to_cpu,
 )
+from torch.distributed._tensor import DTensor
 from torch.distributed.fsdp import (
     FullOptimStateDictConfig,
     FullStateDictConfig,
2 changes: 1 addition & 1 deletion torch/distributed/fsdp/_optim_utils.py
@@ -25,8 +25,8 @@
 import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._state_dict_utils import _gather_state_dict
 from torch.distributed._tensor import DTensor, Replicate
-from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
 from torch.distributed.distributed_c10d import _get_pg_default_device
 from torch.distributed.fsdp._common_utils import (
     _apply_to_modules,
2 changes: 1 addition & 1 deletion torch/testing/_internal/distributed/common_state_dict.py
@@ -8,8 +8,8 @@
 import torch.nn as nn
 
 from torch.distributed._sharded_tensor import ShardedTensor
+from torch.distributed._state_dict_utils import _gather_state_dict
 from torch.distributed._tensor import DTensor
-from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
 from torch.distributed.checkpoint.state_dict import (
     PG,
     set_state_dict,
