progressive-llm/scripts/train_ddp_simple.py
2025-07-10 22:47:07 +09:00

42 lines
No EOL
1 KiB
Python
Executable file

#!/usr/bin/env python3
"""
Simple single-node DDP launcher: runs scripts/train_progressive.py under torchrun
"""
import os
import shlex
import subprocess
import sys
from pathlib import Path
def main():
    """Launch ``train_progressive.py`` under ``torchrun`` for single-node DDP.

    Command-line usage::

        python train_ddp_simple.py <config_file> [num_gpus]

    ``num_gpus`` defaults to 8 and is passed to torchrun as ``--nproc_per_node``.
    The training script is located next to this file rather than relative to
    the current working directory, so the launcher works wherever it is invoked
    from. ``config_file`` is forwarded unchanged, so it is resolved relative to
    the current working directory (the child inherits the same CWD).

    Exits with torchrun's return code, or with status 1 on a usage error,
    a missing config file, or when ``torchrun`` cannot be found on PATH.
    """
    usage = "Usage: python train_ddp_simple.py <config_file> [num_gpus]"
    if len(sys.argv) < 2:
        print(usage)
        print("Example: python train_ddp_simple.py config/training_config_gemma3_1b_8gpu_ddp.yaml 8")
        sys.exit(1)

    config_file = sys.argv[1]

    # Validate the GPU count up front: int() on junk would otherwise raise a
    # bare ValueError traceback, and 0 / negative values would reach torchrun.
    num_gpus = 8
    if len(sys.argv) > 2:
        raw_num_gpus = sys.argv[2]
        try:
            num_gpus = int(raw_num_gpus)
        except ValueError:
            num_gpus = 0
        if num_gpus < 1:
            print(f"Error: num_gpus must be a positive integer, got {raw_num_gpus!r}")
            print(usage)
            sys.exit(1)

    # Fail fast here instead of after spawning num_gpus worker processes.
    if not Path(config_file).is_file():
        print(f"Error: config file not found: {config_file}")
        sys.exit(1)

    # Resolve the trainer relative to this launcher so the command does not
    # depend on being run from the repository root.
    train_script = Path(__file__).resolve().parent / "train_progressive.py"

    # Single source of truth for the port: it is exported to the environment
    # and passed to torchrun explicitly, so the two can never disagree.
    master_port = "12355"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = master_port

    cmd = [
        "torchrun",
        "--nproc_per_node", str(num_gpus),
        "--master_port", master_port,
        str(train_script),
        "--config", config_file,
        "--distributed",
    ]

    print(f"Running DDP training with {num_gpus} GPUs")
    # Quote each argument so the printed command can be copy-pasted verbatim
    # even when paths contain spaces.
    print(f"Command: {' '.join(shlex.quote(arg) for arg in cmd)}")

    try:
        result = subprocess.run(cmd)
    except FileNotFoundError:
        # subprocess.run raises this when the executable itself is missing.
        print("Error: 'torchrun' was not found on PATH; activate the environment where PyTorch is installed.")
        sys.exit(1)
    sys.exit(result.returncode)
# Run the launcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()