Hello
This commit is contained in:
parent 6280c303dc
commit 6906a09c8f
1 changed file with 24 additions and 6 deletions
@@ -69,6 +69,13 @@ Examples:
         help="Enable DeepSpeed training"
     )
+
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=-1,
+        help="Local rank for distributed training"
+    )
 
     return parser.parse_args()
 
 
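As a side note, here is a minimal sketch (not from this repository) of how a launcher-supplied `--local_rank` flag like the one added above is typically consumed; the standalone script shape and the fallback to the LOCAL_RANK environment variable are assumptions for illustration.

```python
# Minimal sketch, assuming a standalone script: parse the --local_rank flag that a
# launcher such as `deepspeed` passes to each process, and fall back to the LOCAL_RANK
# environment variable that `torchrun` exports instead. Not taken from the repository.
import argparse
import os


def parse_args():
    parser = argparse.ArgumentParser(description="local_rank parsing sketch (hypothetical)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Prefer the explicit flag; otherwise read the launcher-provided environment variable.
    local_rank = args.local_rank if args.local_rank >= 0 else int(os.environ.get("LOCAL_RANK", -1))
    print(f"resolved local_rank: {local_rank}")
```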
@@ -87,21 +94,32 @@ def load_config(config_path: str) -> dict:
     return config
 
 
-def setup_distributed_training():
+def setup_distributed_training(local_rank=-1):
     """Setup distributed training environment"""
+    import torch
+
     # Check if we're in a distributed environment
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist
-        import torch
 
         # Initialize distributed training
         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")
 
-        local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        # Use local_rank from args or environment
+        if local_rank >= 0:
+            torch.cuda.set_device(local_rank)
+        else:
+            local_rank = int(os.environ.get("LOCAL_RANK", 0))
+            torch.cuda.set_device(local_rank)
 
-        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}")
+        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
         return True
 
+    # For DeepSpeed, local_rank might be set even without RANK/WORLD_SIZE initially
+    elif local_rank >= 0:
+        torch.cuda.set_device(local_rank)
+        print(f"Set CUDA device to local_rank {local_rank}")
+        return True
+
     return False
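The precedence the new parameter introduces (explicit argument first, then the LOCAL_RANK environment variable) can be isolated as follows; this is a sketch for illustration only, and `resolve_local_rank` is a hypothetical helper, not part of the diff.

```python
# Sketch of the device-index precedence used inside setup_distributed_training above:
# an explicit local_rank argument wins, otherwise fall back to LOCAL_RANK (default 0).
# resolve_local_rank is a hypothetical helper, not part of the repository.
import os


def resolve_local_rank(local_rank: int = -1) -> int:
    if local_rank >= 0:
        return local_rank
    return int(os.environ.get("LOCAL_RANK", 0))


if __name__ == "__main__":
    os.environ["LOCAL_RANK"] = "3"            # illustrative value
    print(resolve_local_rank())               # -> 3 (environment fallback)
    print(resolve_local_rank(local_rank=1))   # -> 1 (explicit argument wins)
```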
@@ -113,7 +131,7 @@ def main():
     # Setup distributed training if requested
     is_distributed = False
     if args.distributed or args.deepspeed:
-        is_distributed = setup_distributed_training()
+        is_distributed = setup_distributed_training(args.local_rank)
 
     print("Progressive LLM Training for 松尾研LLMコンペ2025")
     print("=" * 50)
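For reference, a self-contained sketch of the main() wiring this hunk changes, with stubs standing in for the real parse_args and setup_distributed_training; the stub bodies and the launch commands in the comments are assumptions, not taken from the repository.

```python
# Self-contained sketch of the call flow after this change. The stubs below only
# mimic the real functions; launch commands in the comments are illustrative, e.g.
#   deepspeed train.py --deepspeed                        (passes --local_rank per process)
#   torchrun --nproc_per_node=2 train.py --distributed    (exports LOCAL_RANK instead)
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed", action="store_true")
    parser.add_argument("--deepspeed", action="store_true")
    parser.add_argument("--local_rank", type=int, default=-1)
    return parser.parse_args()


def setup_distributed_training(local_rank=-1):
    # Stub: the real function initialises the NCCL process group and pins the CUDA device.
    print(f"would set up distributed training with local_rank={local_rank}")
    return local_rank >= 0


def main():
    args = parse_args()
    is_distributed = False
    if args.distributed or args.deepspeed:
        is_distributed = setup_distributed_training(args.local_rank)
    print(f"is_distributed={is_distributed}")


if __name__ == "__main__":
    main()
```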