42 lines
No EOL
1 KiB
Python
Executable file
42 lines
No EOL
1 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
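"""
Launch distributed data-parallel (DDP) training with torchrun.

Thin wrapper that runs scripts/train_progressive.py under torchrun on the
local machine, forwarding the given config file and the --distributed flag.
"""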
|
|
|
|
import os
|
|
import sys
|
|
import subprocess
|
|
from pathlib import Path
|
|
|
|
# Defaults used when the caller does not override them.
DEFAULT_NUM_GPUS = 8
# Rendezvous port; can be overridden via the MASTER_PORT environment variable
# (e.g. to run two jobs on the same machine without a port clash).
DEFAULT_MASTER_PORT = "12355"
# Training entry point, resolved relative to the current working directory.
TRAIN_SCRIPT = "scripts/train_progressive.py"


def _usage_error(message=None):
    """Print an optional error message plus usage text to stderr, then exit(1)."""
    if message:
        print(f"Error: {message}", file=sys.stderr)
    print("Usage: python train_ddp_simple.py <config_file> [num_gpus]", file=sys.stderr)
    print(
        "Example: python train_ddp_simple.py config/training_config_gemma3_1b_8gpu_ddp.yaml 8",
        file=sys.stderr,
    )
    sys.exit(1)


def _parse_num_gpus(raw):
    """Return *raw* converted to a positive int, or exit with a usage error."""
    try:
        num_gpus = int(raw)
    except ValueError:
        _usage_error(f"num_gpus must be an integer, got {raw!r}")
    if num_gpus < 1:
        _usage_error(f"num_gpus must be at least 1, got {num_gpus}")
    return num_gpus


def _build_command(config_file, num_gpus, master_port):
    """Return the torchrun argument list that launches the training script."""
    return [
        "torchrun",
        "--nproc_per_node", str(num_gpus),
        "--master_port", str(master_port),
        TRAIN_SCRIPT,
        "--config", config_file,
        "--distributed",
    ]


def main(argv=None):
    """Validate CLI arguments and run DDP training under torchrun.

    Args:
        argv: Argument list excluding the program name. Defaults to
            ``sys.argv[1:]``. Expected form: ``<config_file> [num_gpus]``.
            Any extra arguments are ignored.

    Never returns normally. It always calls ``sys.exit`` with one of:
        * 1 on a usage error (missing config argument, non-integer or
          non-positive num_gpus, config file not found);
        * 127 if ``torchrun`` cannot be executed;
        * 130 if interrupted with Ctrl-C;
        * otherwise torchrun's exit status. A child killed by signal N is
          reported as 128 + N, following the usual shell convention.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        _usage_error()

    config_file = args[0]
    num_gpus = _parse_num_gpus(args[1]) if len(args) > 1 else DEFAULT_NUM_GPUS

    # Fail fast here rather than after torchrun has spawned every worker.
    if not Path(config_file).is_file():
        _usage_error(f"config file not found: {config_file}")

    # Single source of truth for the port. Respect a preset MASTER_PORT so
    # concurrent jobs can choose distinct ports.
    master_port = os.environ.get("MASTER_PORT", DEFAULT_MASTER_PORT)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = master_port

    cmd = _build_command(config_file, num_gpus, master_port)

    print(f"Running DDP training with {num_gpus} GPUs")
    print(f"Command: {' '.join(cmd)}")

    try:
        result = subprocess.run(cmd)
    except FileNotFoundError:
        print(
            "Error: 'torchrun' was not found on PATH; is PyTorch installed in this environment?",
            file=sys.stderr,
        )
        sys.exit(127)
    except KeyboardInterrupt:
        sys.exit(130)

    code = result.returncode
    # subprocess reports "killed by signal N" as -N. Convert it to the shell
    # convention 128 + N instead of letting sys.exit wrap the negative value.
    if code < 0:
        code = 128 - code
    sys.exit(code)


if __name__ == "__main__":
    main()