progressive-llm/scripts/train_progressive.py
2025-07-10 18:09:14 +09:00

131 lines
No EOL
3.6 KiB
Python
Executable file

#!/usr/bin/env python3
"""
Main training script for progressive reasoning model
"""
import sys
import yaml
import argparse
from pathlib import Path
# Add src to path
sys.path.append(str(Path(__file__).parent.parent))
from src.progressive_model import ProgressiveReasoningModel
from src.training import ProgressiveTrainer
from src.data_utils import prepare_sample_datasets
def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            original behavior; passing an explicit list lets tests or other
            callers drive the parser without touching ``sys.argv``.

    Returns:
        Namespace with ``config`` (str path to the YAML config),
        ``prepare_data`` (bool) and ``dry_run`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Progressive LLM Training for 松尾研LLMコンペ2025",
        # Raw formatter so the epilog's line breaks are printed verbatim.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Use default config
python scripts/train_progressive.py
# Use specific config file
python scripts/train_progressive.py --config config/training_config_large.yaml
# Use config with custom path
python scripts/train_progressive.py --config /path/to/my_config.yaml
# Prepare sample datasets
python scripts/train_progressive.py --prepare-data
"""
    )
    parser.add_argument(
        "--config", "-c",
        type=str,
        default="config/training_config.yaml",
        help="Path to the training configuration file (default: config/training_config.yaml)"
    )
    parser.add_argument(
        "--prepare-data",
        action="store_true",
        help="Prepare sample datasets before training"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Load config and model but skip training (for testing)"
    )
    return parser.parse_args(argv)
def load_config(config_path: str) -> dict:
    """Load the training configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration mapping.

    Raises:
        FileNotFoundError: If ``config_path`` is not an existing regular file
            (a directory is treated as "not found").
        ValueError: If the file's top-level YAML value is not a mapping,
            e.g. the file is empty (``yaml.safe_load`` then returns ``None``).
    """
    path = Path(config_path)
    # is_file() rather than exists(): a directory would otherwise pass this
    # check and fail later inside open() with a less helpful error.
    if not path.is_file():
        raise FileNotFoundError(f"Configuration file not found: {path}")
    print(f"Loading configuration from: {path}")
    # Explicit UTF-8: config text may be non-ASCII and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(path, encoding="utf-8") as f:
        # safe_load (not load) so the YAML cannot construct arbitrary objects.
        config = yaml.safe_load(f)
    # Callers index the result with string keys; reject None/list/scalar
    # documents here instead of failing later with a TypeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {path} must contain a YAML mapping at the "
            f"top level, got {type(config).__name__}"
        )
    return config
def _print_available_configs() -> None:
    """List the ``*.yaml`` files under ``./config`` on stderr to help pick --config."""
    print("Available config files:", file=sys.stderr)
    # Relative to the current working directory, like the default --config path.
    config_dir = Path("config")
    if config_dir.exists():
        # sorted(): glob() yields entries in arbitrary filesystem order.
        for config_file in sorted(config_dir.glob("*.yaml")):
            print(f" {config_file}", file=sys.stderr)


def _print_config_summary(config: dict) -> None:
    """Print the experiment settings this script reads from ``config``.

    Raises:
        KeyError: If ``experiment`` (with ``name``, ``base_model``,
            ``output_dir``) or ``progressive_stages`` is missing.
        TypeError: If ``config`` or one of those values is not subscriptable
            or has no length (e.g. ``config`` is ``None``).
    """
    experiment = config["experiment"]
    print(f"Experiment: {experiment['name']}")
    print(f"Base model: {experiment['base_model']}")
    print(f"Output directory: {experiment['output_dir']}")
    print(f"Stages: {len(config['progressive_stages'])}")


def main():
    """Run progressive training from the command line.

    Steps:
        1. Parse arguments and load the YAML config.
        2. Print a summary of the config.
        3. Optionally prepare sample datasets.
        4. Load the base model and, unless ``--dry-run`` is given, run
           progressive training.

    Exits with status 1 if the config cannot be loaded or lacks a setting
    read here.
    """
    args = parse_args()
    print("Progressive LLM Training for 松尾研LLMコンペ2025")
    print("=" * 50)

    # Load configuration. Errors go to stderr so they are not buried in
    # (or redirected along with) normal stdout progress output.
    try:
        config = load_config(args.config)
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        _print_available_configs()
        sys.exit(1)
    except Exception as e:  # top-level boundary: any read/parse failure is fatal
        print(f"Error loading config: {e}", file=sys.stderr)
        sys.exit(1)

    # A malformed config would otherwise end in a bare KeyError/TypeError
    # traceback; report it and exit like the other config errors.
    try:
        _print_config_summary(config)
    except (KeyError, TypeError) as e:
        print(
            f"Error: invalid configuration ({type(e).__name__}: {e})",
            file=sys.stderr,
        )
        sys.exit(1)

    # Prepare sample datasets if requested
    if args.prepare_data:
        print("\nPreparing sample datasets...")
        prepare_sample_datasets()
        print("Sample datasets prepared.")

    # Initialize model wrapper
    print("\nInitializing model...")
    model_wrapper = ProgressiveReasoningModel(config)
    model_wrapper.setup_base_model()

    if args.dry_run:
        print("\nDry run completed. Model loaded successfully.")
        return

    # Initialize trainer
    print("\nInitializing trainer...")
    trainer = ProgressiveTrainer(model_wrapper, config)

    # Run progressive training
    print("\nStarting progressive training...")
    trainer.run_progressive_training()
    print("\nTraining completed successfully!")


if __name__ == "__main__":
    main()