Hello
This commit is contained in:
parent 6280c303dc
commit 6906a09c8f

1 changed file with 24 additions and 6 deletions
@@ -69,6 +69,13 @@ Examples:
         help="Enable DeepSpeed training"
     )
 
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=-1,
+        help="Local rank for distributed training"
+    )
+
     return parser.parse_args()
 
 
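Note: the legacy torch.distributed.launch helper and the deepspeed launcher pass --local_rank=N to every worker on the command line, so a parser that does not declare the flag fails with "unrecognized arguments"; torchrun instead sets the LOCAL_RANK environment variable and passes no flag, which is what the default of -1 covers. A minimal sketch of that resolution, separate from this commit (only the argument name and default come from the diff above):

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training")
    args, _ = parser.parse_known_args()

    # Legacy launchers pass --local_rank=N explicitly; torchrun sets
    # LOCAL_RANK in the environment instead, so -1 means "ask the env".
    local_rank = args.local_rank if args.local_rank >= 0 else int(os.environ.get("LOCAL_RANK", 0))
    print(f"resolved local_rank = {local_rank}")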
@@ -87,21 +94,32 @@ def load_config(config_path: str) -> dict:
     return config
 
 
-def setup_distributed_training():
+def setup_distributed_training(local_rank=-1):
     """Setup distributed training environment"""
+    import torch
+
     # Check if we're in a distributed environment
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist
-        import torch
 
         # Initialize distributed training
         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")
 
-        local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        # Use local_rank from args or environment
+        if local_rank >= 0:
+            torch.cuda.set_device(local_rank)
+        else:
+            local_rank = int(os.environ.get("LOCAL_RANK", 0))
+            torch.cuda.set_device(local_rank)
 
-        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}")
+        print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
         return True
 
+    # For DeepSpeed, local_rank might be set even without RANK/WORLD_SIZE initially
+    elif local_rank >= 0:
+        torch.cuda.set_device(local_rank)
+        print(f"Set CUDA device to local_rank {local_rank}")
+        return True
+
     return False
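The rewrite gives the explicit local_rank argument precedence over the LOCAL_RANK variable, and the new elif branch lets a DeepSpeed worker pin its GPU before RANK/WORLD_SIZE exist. A condensed sketch of the same precedence rule, assuming one GPU per process and an NCCL-enabled PyTorch build; one deliberate difference: it calls torch.cuda.set_device before init_process_group, the commonly recommended order, whereas the diff initializes the process group first.

    import os
    import torch
    import torch.distributed as dist

    def resolve_local_rank(local_rank: int = -1) -> int:
        # Explicit argument (legacy launchers) wins; LOCAL_RANK (torchrun)
        # is the fallback; 0 covers a bare single-process run.
        return local_rank if local_rank >= 0 else int(os.environ.get("LOCAL_RANK", 0))

    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        local_rank = resolve_local_rank()
        # Pin the device first so each NCCL process binds its own GPU
        # rather than all of them touching cuda:0 during initialization.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")
        print(f"rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")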
@@ -113,7 +131,7 @@ def main():
     # Setup distributed training if requested
     is_distributed = False
     if args.distributed or args.deepspeed:
-        is_distributed = setup_distributed_training()
+        is_distributed = setup_distributed_training(args.local_rank)
 
     print("Progressive LLM Training for 松尾研LLMコンペ2025")
     print("=" * 50)
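For reference, the two launch styles this change has to support, written as comments because the entry-point script is not visible in this diff (train.py is a placeholder):

    # train.py stands in for this repository's actual entry point.
    #
    #   torchrun --nproc_per_node=8 train.py --distributed
    #     sets RANK/WORLD_SIZE/LOCAL_RANK in the environment and passes no
    #     --local_rank flag, so args.local_rank stays -1 (env fallback).
    #
    #   deepspeed --num_gpus=8 train.py --deepspeed
    #     injects --local_rank=N into each worker's argv, which is the
    #     value main() now forwards to setup_distributed_training().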