progressive-llm/scripts/simple_compare.py

#!/usr/bin/env python3
"""
Simple comparison script without rich TUI
"""
import argparse
import sys
from pathlib import Path

import torch
import yaml

# Add src to path
sys.path.append(str(Path(__file__).parent.parent))

from src.progressive_model import ProgressiveReasoningModel


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Compare original and trained models")
    parser.add_argument(
        "--config", "-c",
        type=str,
        default="config/training_config_gemma2_small.yaml",
        help="Path to configuration file"
    )
    parser.add_argument(
        "--adapter", "-a",
        type=str,
        default="basic_cot",
        help="Adapter name to load for comparison"
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum generation length"
    )
    return parser.parse_args()


def load_config(config_path):
    """Load configuration from file"""
    config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    with open(config_path) as f:
        config = yaml.safe_load(f)

    return config


def generate_response(model, tokenizer, prompt, max_length=512):
    """Generate response using the model"""
    # Format prompt for Gemma
    formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

    # Tokenize
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Generate
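    # `max_length` is treated as a budget of newly generated tokens; the prompt
    # length is added so generate()'s total max_length leaves room for the reply.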
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=len(inputs["input_ids"][0]) + max_length,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )
    # Decode only the newly generated tokens so the echoed prompt is dropped.
    # (Splitting the full decode on "<start_of_turn>model" is unreliable here,
    # because skip_special_tokens=True typically strips those turn markers.)
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    return response.strip()


def main():
    args = parse_args()

    try:
        config = load_config(args.config)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return

    print("Progressive Model Comparison")
    print(f"Config: {args.config}")
    print(f"Base model: {config['experiment']['base_model']}")
    print(f"Adapter: {args.adapter}")
    print("="*60)

    print("Loading models...")

    # Original model (no adapter)
    print("Loading original model...")
    original_model = ProgressiveReasoningModel(config)
    original_model.setup_base_model()

    # Trained model (with adapter)
    print("Loading trained model...")
    trained_model = ProgressiveReasoningModel(config)
    trained_model.setup_base_model()
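    # Note: two separate base-model instances are created above, so memory use is
    # roughly double that of loading a single model.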

    # Load the trained adapter if it exists
    adapter_path = Path(config["experiment"]["output_dir"]) / "adapters" / args.adapter
    if adapter_path.exists():
        print(f"Loading trained adapter from: {adapter_path}")
        try:
            trained_model.load_for_inference([args.adapter])
            print("Adapter loaded successfully!")
        except Exception as e:
            print(f"Error loading adapter: {e}")
            print("Will compare with base model instead.")
    else:
        print(f"No trained adapter found at: {adapter_path}")
        print("Available adapters:")
        adapters_dir = Path(config["experiment"]["output_dir"]) / "adapters"
        if adapters_dir.exists():
            for adapter_dir in adapters_dir.iterdir():
                if adapter_dir.is_dir():
                    print(f" - {adapter_dir.name}")
        else:
            print(" No adapters directory found.")
        print("Both models will show original behavior.")

    print("\nModels loaded! Enter prompts to compare (type 'quit' to exit)")
    print("Examples:")
    print(" - What is 25 + 17?")
    print(" - Explain why the sky is blue")
    print(" - Solve this step by step: If I have 10 apples and give away 3, how many do I have left?")
    print()

    while True:
        try:
            prompt = input("\nPrompt: ").strip()
            if prompt.lower() in ['quit', 'exit', 'q']:
                break
            if not prompt:
                continue

            print(f"\n{'='*60}")
            print("ORIGINAL MODEL (No fine-tuning)")
            print("="*60)
            try:
                original_response = generate_response(
                    original_model.model,
                    original_model.tokenizer,
                    prompt,
                    args.max_length
                )
                print(original_response)
            except Exception as e:
                print(f"Error generating original response: {e}")

            print(f"\n{'='*60}")
            print(f"TRAINED MODEL (With {args.adapter} adapter)")
            print("="*60)
            try:
                # Add CoT prompt for trained model
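                # Only the trained model receives this explicit instruction, so the
                # printed outputs reflect both the adapter and the extra prompt text.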
                cot_prompt = f"{prompt}\n\nPlease think step by step using <think> tags."
                trained_response = generate_response(
                    trained_model.model,
                    trained_model.tokenizer,
                    cot_prompt,
                    args.max_length
                )
                print(trained_response)
            except Exception as e:
                print(f"Error generating trained response: {e}")

        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")
            continue


if __name__ == "__main__":
    main()