Hello

Soma Nakamura 2025-07-10 22:51:29 +09:00
parent 6280c303dc
commit 6906a09c8f


@@ -69,6 +69,13 @@ Examples:
         help="Enable DeepSpeed training"
     )
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=-1,
+        help="Local rank for distributed training"
+    )
     return parser.parse_args()
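The new --local_rank option follows the convention used by distributed launchers: the DeepSpeed launcher typically invokes each worker with an explicit --local_rank argument, so the parser has to accept that flag or argparse would reject the launch command. A minimal, self-contained sketch of just this flag, with the same default of -1 meaning "not set by a launcher"; everything outside --local_rank is illustrative:

import argparse

# Minimal sketch: only the --local_rank flag is taken from the commit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--local_rank",
    type=int,
    default=-1,
    help="Local rank for distributed training",
)

print(parser.parse_args([]).local_rank)                     # -1 (no launcher)
print(parser.parse_args(["--local_rank", "3"]).local_rank)  # 3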
@@ -87,21 +94,32 @@ def load_config(config_path: str) -> dict:
     return config
-def setup_distributed_training():
+def setup_distributed_training(local_rank=-1):
     """Setup distributed training environment"""
+    import torch
     # Check if we're in a distributed environment
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist
-        import torch
         # Initialize distributed training
         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")
-            local_rank = int(os.environ.get("LOCAL_RANK", 0))
-            torch.cuda.set_device(local_rank)
+            # Use local_rank from args or environment
+            if local_rank >= 0:
+                torch.cuda.set_device(local_rank)
+            else:
+                local_rank = int(os.environ.get("LOCAL_RANK", 0))
+                torch.cuda.set_device(local_rank)
-        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}")
+        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
         return True
+    # For DeepSpeed, local_rank might be set even without RANK/WORLD_SIZE initially
+    elif local_rank >= 0:
+        torch.cuda.set_device(local_rank)
+        print(f"Set CUDA device to local_rank {local_rank}")
+        return True
     return False
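The reworked setup_distributed_training() prefers an explicitly passed local_rank and only falls back to the LOCAL_RANK environment variable when the argument is left at -1. A short standalone sketch of that resolution order follows; the helper name resolve_local_rank is hypothetical, and the CUDA call is guarded so the snippet also runs on a CPU-only machine:

import os
import torch

def resolve_local_rank(local_rank: int = -1) -> int:
    # Prefer the value handed in from the command line; otherwise fall back
    # to the LOCAL_RANK variable set by launchers such as torchrun.
    if local_rank < 0:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # Guarded so the sketch is safe without a GPU or with too few GPUs.
    if torch.cuda.is_available() and local_rank < torch.cuda.device_count():
        torch.cuda.set_device(local_rank)
    return local_rank

print(resolve_local_rank())    # 0 unless LOCAL_RANK is set
print(resolve_local_rank(2))   # 2, taken from the argument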
@@ -113,7 +131,7 @@ def main():
     # Setup distributed training if requested
     is_distributed = False
     if args.distributed or args.deepspeed:
-        is_distributed = setup_distributed_training()
+        is_distributed = setup_distributed_training(args.local_rank)
     print("Progressive LLM Training for 松尾研LLMコンペ2025")
     print("=" * 50)