こんにちは

This commit is contained in:
Soma Nakamura 2025-07-10 22:47:07 +09:00
parent 6d823eb371
commit 4799392e24
4 changed files with 50 additions and 14 deletions

View file

@ -14,7 +14,7 @@ pip install -r requirements.txt
# Start training # Start training
python scripts/train_progressive.py --config config/training_config_gemma3_1b.yaml python scripts/train_progressive.py --config config/training_config_gemma3_1b.yaml
./scripts/train_gemma3_1b_8gpu.sh --strategy deepspeed ./scripts/train_gemma3_1b_8gpu.sh --strategy ddp
``` ```
## Training Stages ## Training Stages
@ -28,7 +28,8 @@ python scripts/train_progressive.py --config config/training_config_gemma3_1b.ya
```bash ```bash
pip install -r requirements.txt # Install dependencies pip install -r requirements.txt # Install dependencies
python scripts/train_progressive.py --config config/training_config_gemma3_1b.yaml # Single GPU python scripts/train_progressive.py --config config/training_config_gemma3_1b.yaml # Single GPU
./scripts/train_gemma3_1b_8gpu.sh --strategy deepspeed # 8 GPUs ./scripts/train_gemma3_1b_8gpu.sh --strategy ddp # 8 GPUs (DDP)
python scripts/train_ddp_simple.py config/training_config_gemma3_1b_8gpu_ddp.yaml # 8 GPUs (Simple)
pytest # Run tests pytest # Run tests
``` ```

42
scripts/train_ddp_simple.py Executable file
View file

@ -0,0 +1,42 @@
#!/usr/bin/env python3
"""
Simple DDP training script without complex launcher
"""
import os
import sys
import subprocess
from pathlib import Path
def main():
    """Launch single-node DDP training by delegating to ``torchrun``.

    Command line: ``train_ddp_simple.py <config_file> [num_gpus]``.

    ``config_file`` is forwarded unchanged to ``train_progressive.py`` via
    ``--config``. ``num_gpus`` is optional, defaults to 8, and must be a
    positive integer; it becomes ``--nproc_per_node``.

    Exits with status 1 on a usage error, 127 when ``torchrun`` cannot be
    found, and otherwise with the return code of the ``torchrun`` process.
    """
    if len(sys.argv) < 2:
        print("Usage: python train_ddp_simple.py <config_file> [num_gpus]")
        print("Example: python train_ddp_simple.py config/training_config_gemma3_1b_8gpu_ddp.yaml 8")
        sys.exit(1)

    config_file = sys.argv[1]

    # Validate the GPU count up front so a typo yields a clear message
    # instead of a ValueError traceback or a confusing torchrun failure.
    raw_num_gpus = sys.argv[2] if len(sys.argv) > 2 else "8"
    try:
        num_gpus = int(raw_num_gpus)
    except ValueError:
        num_gpus = 0
    if num_gpus < 1:
        print(
            f"Error: num_gpus must be a positive integer, got {raw_num_gpus!r}",
            file=sys.stderr,
        )
        sys.exit(1)

    # Single source of truth for the rendezvous port: it is exported to the
    # environment and also passed to torchrun explicitly.
    master_port = "12355"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = master_port

    # train_progressive.py lives next to this script; resolving it from
    # __file__ lets the launcher work from any working directory, not only
    # from the repository root.
    train_script = Path(__file__).resolve().parent / "train_progressive.py"

    # Build torchrun command
    cmd = [
        "torchrun",
        "--nproc_per_node", str(num_gpus),
        "--master_port", master_port,
        str(train_script),
        "--config", config_file,
        "--distributed",
    ]

    print(f"Running DDP training with {num_gpus} GPUs")
    print(f"Command: {' '.join(cmd)}")

    # Execute and propagate the child's exit status to the caller.
    try:
        result = subprocess.run(cmd)
    except FileNotFoundError:
        print("Error: 'torchrun' was not found on PATH", file=sys.stderr)
        sys.exit(127)
    sys.exit(result.returncode)


if __name__ == "__main__":
    main()

View file

@ -14,7 +14,7 @@ echo "=================================================="
UV_PREFIX="python" UV_PREFIX="python"
# Default values # Default values
STRATEGY="deepspeed" STRATEGY="ddp"
CONFIG="" CONFIG=""
NUM_GPUS=8 NUM_GPUS=8
DRY_RUN=false DRY_RUN=false
@ -48,8 +48,8 @@ while [[ $# -gt 0 ]]; do
echo " --dry-run Show command without executing" echo " --dry-run Show command without executing"
echo "" echo ""
echo "Examples:" echo "Examples:"
echo " # Use DeepSpeed (recommended)" echo " # Use DDP (recommended)"
echo " $0 --strategy deepspeed" echo " $0 --strategy ddp"
echo "" echo ""
echo " # Use DDP" echo " # Use DDP"
echo " $0 --strategy ddp" echo " $0 --strategy ddp"
@ -91,6 +91,7 @@ if [ -z "$CONFIG" ]; then
;; ;;
*) *)
echo -e "${RED}Error: Invalid strategy '$STRATEGY'. Choose from: ddp, fsdp, deepspeed${NC}" echo -e "${RED}Error: Invalid strategy '$STRATEGY'. Choose from: ddp, fsdp, deepspeed${NC}"
echo -e "${YELLOW}Note: DDP is recommended for single-node training${NC}"
exit 1 exit 1
;; ;;
esac esac

View file

@ -128,16 +128,11 @@ def launch_deepspeed_training(config_path, num_gpus):
setup_environment_for_strategy("deepspeed") setup_environment_for_strategy("deepspeed")
# Create DeepSpeed hostfile
hostfile = Path(__file__).parent.parent / "hostfile"
with open(hostfile, "w") as f:
f.write(f"localhost slots={num_gpus}\n")
python_cmd = ["python", "scripts/train_progressive.py"] python_cmd = ["python", "scripts/train_progressive.py"]
# Use --num_gpus without hostfile for single node
cmd = [ cmd = [
"deepspeed", "deepspeed",
"--hostfile", str(hostfile),
"--num_gpus", str(num_gpus), "--num_gpus", str(num_gpus),
] + python_cmd + [ ] + python_cmd + [
"--config", config_path, "--config", config_path,
@ -147,9 +142,6 @@ def launch_deepspeed_training(config_path, num_gpus):
print(f"Running command: {' '.join(cmd)}") print(f"Running command: {' '.join(cmd)}")
result = subprocess.run(cmd, cwd=Path(__file__).parent.parent) result = subprocess.run(cmd, cwd=Path(__file__).parent.parent)
# Clean up hostfile
hostfile.unlink(missing_ok=True)
return result return result