From 6906a09c8faf2e6be12470a62e634ee59364487c Mon Sep 17 00:00:00 2001
From: Soma Nakamura
Date: Thu, 10 Jul 2025 22:51:29 +0900
Subject: [PATCH] Add --local_rank argument for distributed/DeepSpeed training
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/train_progressive.py | 30 ++++++++++++++++++++++++------
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/scripts/train_progressive.py b/scripts/train_progressive.py
index 9b4df04..3d2ce54 100755
--- a/scripts/train_progressive.py
+++ b/scripts/train_progressive.py
@@ -69,6 +69,13 @@ Examples:
         help="Enable DeepSpeed training"
     )
 
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=-1,
+        help="Local rank for distributed training"
+    )
+
     return parser.parse_args()
 
 
@@ -87,21 +94,32 @@ def load_config(config_path: str) -> dict:
     return config
 
 
-def setup_distributed_training():
+def setup_distributed_training(local_rank=-1):
     """Setup distributed training environment"""
+    import torch
+
     # Check if we're in a distributed environment
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist
-        import torch
 
         # Initialize distributed training
         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")
-            local_rank = int(os.environ.get("LOCAL_RANK", 0))
-            torch.cuda.set_device(local_rank)
+            # Use local_rank from args or environment
+            if local_rank >= 0:
+                torch.cuda.set_device(local_rank)
+            else:
+                local_rank = int(os.environ.get("LOCAL_RANK", 0))
+                torch.cuda.set_device(local_rank)
 
-        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}")
+        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
+        return True
+
+    # For DeepSpeed, local_rank might be set even without RANK/WORLD_SIZE initially
+    elif local_rank >= 0:
+        torch.cuda.set_device(local_rank)
+        print(f"Set CUDA device to local_rank {local_rank}")
         return True
 
     return False
@@ -113,7 +131,7 @@ def main():
     # Setup distributed training if requested
     is_distributed = False
     if args.distributed or args.deepspeed:
-        is_distributed = setup_distributed_training()
+        is_distributed = setup_distributed_training(args.local_rank)
 
     print("Progressive LLM Training for 松尾研LLMコンペ2025")
     print("=" * 50)
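
Note on the change: the patch makes an explicit --local_rank value (the DeepSpeed launcher appends one to the script arguments) take precedence over the LOCAL_RANK environment variable that torchrun provides, falling back to device 0. The standalone sketch below illustrates that precedence outside the training script; resolve_local_rank is a hypothetical helper for illustration only, not part of train_progressive.py, and it assumes nothing beyond the standard library:

    import os

    def resolve_local_rank(local_rank: int = -1) -> int:
        # Mirrors setup_distributed_training(): CLI value wins,
        # then the LOCAL_RANK environment variable, then device 0.
        if local_rank >= 0:
            return local_rank
        return int(os.environ.get("LOCAL_RANK", 0))

    if __name__ == "__main__":
        os.environ["LOCAL_RANK"] = "2"
        assert resolve_local_rank(-1) == 2  # environment fallback
        assert resolve_local_rank(1) == 1   # explicit argument takes precedence
        print("precedence check passed")

Handling both paths is what lets the same script run under torchrun (environment-based) and under the DeepSpeed launcher (argument-based) without further changes.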