131 lines
No EOL
3.6 KiB
Python
Executable file
131 lines
No EOL
3.6 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
"""
|
|
Main training script for progressive reasoning model
|
|
"""
|
|
|
|
import sys
|
|
import yaml
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
# Make the project root (the parent of this script's directory) importable so the
# `src.*` imports below resolve regardless of the current working directory.
# Prepended rather than appended so that an unrelated installed top-level package
# named `src` cannot shadow this project's sources.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from src.progressive_model import ProgressiveReasoningModel
|
|
from src.training import ProgressiveTrainer
|
|
from src.data_utils import prepare_sample_datasets
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for the progressive training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            callers behave exactly as before; passing a list makes the parser
            usable from tests or other Python code.

    Returns:
        argparse.Namespace with attributes ``config`` (str), ``prepare_data``
        (bool) and ``dry_run`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Progressive LLM Training for 松尾研LLMコンペ2025",
        # Raw formatter preserves the hand-formatted examples in the epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Use default config
  python scripts/train_progressive.py

  # Use specific config file
  python scripts/train_progressive.py --config config/training_config_large.yaml

  # Use config with custom path
  python scripts/train_progressive.py --config /path/to/my_config.yaml

  # Prepare sample datasets
  python scripts/train_progressive.py --prepare-data
        """,
    )

    parser.add_argument(
        "--config", "-c",
        type=str,
        default="config/training_config.yaml",
        help="Path to the training configuration file (default: config/training_config.yaml)",
    )

    parser.add_argument(
        "--prepare-data",
        action="store_true",
        help="Prepare sample datasets before training",
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Load config and model but skip training (for testing)",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
def load_config(config_path: str) -> dict:
    """Load the YAML training configuration from ``config_path``.

    Args:
        config_path: Path (str or path-like) to a YAML configuration file.

    Returns:
        The parsed configuration as a dict.

    Raises:
        FileNotFoundError: If ``config_path`` does not name an existing regular
            file (a directory counts as "not found").
        ValueError: If the file is empty or its top level is not a mapping.
        yaml.YAMLError: If the file is not valid YAML.
    """
    config_path = Path(config_path)

    # is_file() rather than exists(): a directory would otherwise pass this check
    # and then fail inside open() with IsADirectoryError instead of a clear
    # "not found" message.
    if not config_path.is_file():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    print(f"Loading configuration from: {config_path}")

    # Explicit UTF-8: config files may contain non-ASCII text (this project's own
    # strings are Japanese) and the platform default encoding is not UTF-8 on
    # every system.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document and can return a list or a
    # scalar; callers index config["experiment"], so fail early and clearly here.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a YAML mapping at the "
            f"top level, got {type(config).__name__}"
        )

    return config
|
|
|
|
|
|
def _print_available_configs(config_dir):
    """Print the ``*.yaml`` files found directly in ``config_dir`` to stderr as a hint.

    Args:
        config_dir: pathlib.Path of the directory to scan (not recursive).
    """
    print("Available config files:", file=sys.stderr)
    if not config_dir.is_dir():
        print(f" (directory not found: {config_dir})", file=sys.stderr)
        return
    # sorted() gives a stable listing; Path.glob does not guarantee any order.
    config_files = sorted(config_dir.glob("*.yaml"))
    if not config_files:
        print(" (none)", file=sys.stderr)
    for config_file in config_files:
        print(f" {config_file}", file=sys.stderr)


def main():
    """Run the progressive training pipeline from the command line.

    Steps: parse arguments, load the YAML config, print a short summary,
    optionally prepare sample datasets (``--prepare-data``), set up the base
    model, and, unless ``--dry-run`` is given, run progressive training.

    Exits with status 1 (via ``sys.exit``) if the configuration cannot be loaded
    or lacks the ``experiment`` / ``progressive_stages`` entries summarized below.
    Error messages are written to stderr.
    """
    args = parse_args()

    print("Progressive LLM Training for 松尾研LLMコンペ2025")
    print("=" * 50)

    # Load configuration
    try:
        config = load_config(args.config)
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        # Relative to the current working directory, like the default --config path.
        _print_available_configs(Path("config"))
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: YAML syntax errors, permission errors, etc.
        print(f"Error loading config: {e}", file=sys.stderr)
        sys.exit(1)

    # Print configuration info; a config missing these keys (or one that is not a
    # mapping) is reported cleanly instead of crashing with a traceback.
    try:
        print(f"Experiment: {config['experiment']['name']}")
        print(f"Base model: {config['experiment']['base_model']}")
        print(f"Output directory: {config['experiment']['output_dir']}")
        print(f"Stages: {len(config['progressive_stages'])}")
    except (KeyError, TypeError) as e:
        print(
            f"Error: invalid configuration in {args.config} ({type(e).__name__}: {e})",
            file=sys.stderr,
        )
        sys.exit(1)

    # Prepare sample datasets if requested
    if args.prepare_data:
        print("\nPreparing sample datasets...")
        prepare_sample_datasets()
        print("Sample datasets prepared.")

    # Initialize model wrapper
    print("\nInitializing model...")
    model_wrapper = ProgressiveReasoningModel(config)
    model_wrapper.setup_base_model()

    if args.dry_run:
        print("\nDry run completed. Model loaded successfully.")
        return

    # Initialize trainer
    print("\nInitializing trainer...")
    trainer = ProgressiveTrainer(model_wrapper, config)

    # Run progressive training
    print("\nStarting progressive training...")
    trainer.run_progressive_training()

    print("\nTraining completed successfully!")
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns None when it finishes, and SystemExit(None) exits with status 0.
    # That matches simply falling off the end of the script.
    raise SystemExit(main())