diff --git a/scripts/train_progressive.py b/scripts/train_progressive.py
index 9b4df04..3d2ce54 100755
--- a/scripts/train_progressive.py
+++ b/scripts/train_progressive.py
@@ -69,6 +69,13 @@ Examples:
         help="Enable DeepSpeed training"
     )
 
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=-1,
+        help="Local rank for distributed training"
+    )
+
     return parser.parse_args()
 
 
@@ -87,21 +94,32 @@ def load_config(config_path: str) -> dict:
     return config
 
 
-def setup_distributed_training():
+def setup_distributed_training(local_rank=-1):
     """Setup distributed training environment"""
+    import torch
+
     # Check if we're in a distributed environment
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist
-        import torch
 
         # Initialize distributed training
         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")
 
-        local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        # Use local_rank from args or environment
+        if local_rank >= 0:
+            torch.cuda.set_device(local_rank)
+        else:
+            local_rank = int(os.environ.get("LOCAL_RANK", 0))
+            torch.cuda.set_device(local_rank)
 
-        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}")
+        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
+        return True
+
+    # For DeepSpeed, local_rank might be set even without RANK/WORLD_SIZE initially
+    elif local_rank >= 0:
+        torch.cuda.set_device(local_rank)
+        print(f"Set CUDA device to local_rank {local_rank}")
         return True
 
     return False
@@ -113,7 +131,7 @@ def main():
     # Setup distributed training if requested
     is_distributed = False
     if args.distributed or args.deepspeed:
-        is_distributed = setup_distributed_training()
+        is_distributed = setup_distributed_training(args.local_rank)
 
     print("Progressive LLM Training for 松尾研LLMコンペ2025")
     print("=" * 50)
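
Note (not part of the patch): the two launch paths this change targets differ in how the local rank reaches the script. torchrun exports RANK, WORLD_SIZE and LOCAL_RANK, so the environment branch applies, while the DeepSpeed launcher passes --local_rank on the command line, which the new argument forwards to setup_distributed_training(). The minimal sketch below only illustrates the resulting precedence (explicit flag first, then LOCAL_RANK, then 0); resolve_local_rank is a hypothetical helper and deliberately avoids any CUDA calls.

import os

def resolve_local_rank(arg_local_rank: int = -1) -> int:
    # Mirrors the patched precedence: an explicit --local_rank wins,
    # otherwise fall back to the LOCAL_RANK environment variable, then 0.
    if arg_local_rank >= 0:
        return arg_local_rank
    return int(os.environ.get("LOCAL_RANK", 0))

# torchrun-style launch: LOCAL_RANK is set, no --local_rank flag given
os.environ["LOCAL_RANK"] = "2"
assert resolve_local_rank(-1) == 2

# DeepSpeed-launcher style: --local_rank passed explicitly, the flag wins
assert resolve_local_rank(1) == 1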