progressive-llm/scripts/compare_models_tui.py
#!/usr/bin/env python3
"""
TUI for comparing original and trained models
"""

import sys
import time
from pathlib import Path

import torch
import yaml
from rich.console import Console
from rich.panel import Panel
from rich.columns import Columns
from rich.prompt import Prompt
from rich.text import Text
from rich.layout import Layout
from rich.live import Live
from rich.table import Table

# Add src to path
sys.path.append(str(Path(__file__).parent.parent))

from src.progressive_model import ProgressiveReasoningModel
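
# ProgressiveReasoningModel (src/progressive_model.py) is used here only through
# setup_base_model() and load_for_inference(); see that module for model/adapter details.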


class ModelCompareTUI:
    def __init__(self, config_path: str = "config/training_config.yaml"):
        self.console = Console()

        # Load configuration
        with open(config_path) as f:
            self.config = yaml.safe_load(f)

        # Initialize models
        self.console.print("[yellow]Loading models...[/yellow]")

        # Original model
        self.original_model = ProgressiveReasoningModel(self.config)
        self.original_model.setup_base_model()

        # Trained model
        self.trained_model = ProgressiveReasoningModel(self.config)
        self.trained_model.setup_base_model()

        # Load the trained adapter if it exists
        adapter_path = Path(self.config["experiment"]["output_dir"]) / "adapters" / "basic_cot"
        if adapter_path.exists():
            self.console.print(f"[green]Loading trained adapter from: {adapter_path}[/green]")
            self.trained_model.load_for_inference(["basic_cot"])
        else:
            self.console.print("[red]No trained adapter found. Please run training first.[/red]")
            self.console.print("[yellow]Both models will show original behavior.[/yellow]")

        self.console.print("[green]Models loaded successfully![/green]\n")

    def generate_response(self, model, prompt: str, with_think_tags: bool = True) -> str:
        """Generate a response from a model."""
        # For the trained model, encourage step-by-step (chain-of-thought) output
        if with_think_tags and model == self.trained_model:
            formatted_prompt = f"{prompt}\n\nPlease think step by step."
        else:
            formatted_prompt = prompt

        inputs = model.tokenizer(formatted_prompt, return_tensors="pt").to(model.model.device)

        with torch.no_grad():
            outputs = model.model.generate(
                **inputs,
                max_length=512,  # caps prompt + generated tokens at 512
                temperature=0.7,
                do_sample=True,
                top_p=0.95,
                pad_token_id=model.tokenizer.pad_token_id,
                eos_token_id=model.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens so the prompt is not echoed back
        prompt_length = inputs["input_ids"].shape[1]
        response = model.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return response.strip()

    def create_comparison_panel(self, prompt: str, original_response: str, trained_response: str) -> Panel:
        """Create a panel showing the comparison."""
        # Create table
        table = Table(show_header=True, header_style="bold magenta", expand=True)
        table.add_column("Original Model", style="cyan", width=50)
        table.add_column("Trained Model (with CoT)", style="green", width=50)
        table.add_row(original_response, trained_response)

        return Panel(
            table,
            title=f"[bold yellow]Prompt: {prompt}[/bold yellow]",
            border_style="blue",
        )

    def run_interactive_mode(self):
        """Run interactive comparison mode."""
        self.console.print("\n[bold cyan]Model Comparison TUI[/bold cyan]")
        self.console.print("Compare responses from original and trained models\n")
        self.console.print("[dim]Type 'quit' or 'exit' to leave[/dim]\n")

        while True:
            # Get user prompt
            prompt = Prompt.ask("\n[bold yellow]Enter your prompt[/bold yellow]")

            if prompt.lower() in ['quit', 'exit']:
                self.console.print("\n[yellow]Goodbye![/yellow]")
                break

            # Generate responses
            self.console.print("\n[dim]Generating responses...[/dim]")

            start_time = time.time()
            original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
            original_time = time.time() - start_time

            start_time = time.time()
            trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)
            trained_time = time.time() - start_time

            # Display comparison
            panel = self.create_comparison_panel(prompt, original_response, trained_response)
            self.console.print(panel)

            # Show generation times
            self.console.print(
                f"\n[dim]Generation times - Original: {original_time:.2f}s, Trained: {trained_time:.2f}s[/dim]"
            )

    def run_benchmark_mode(self):
        """Run benchmark with predefined prompts."""
        test_prompts = [
            "What is 156 + 389?",
            "If I have 23 apples and buy 17 more, how many do I have?",
            "A store has 145 items. If 38 are sold, how many remain?",
            "What is 45 * 12?",
            "Explain why 2 + 2 = 4",
            "If a train travels 80 km/h for 2.5 hours, how far does it go?",
            "What is the sum of all numbers from 1 to 10?",
            "How many minutes are in 3.5 hours?",
        ]

        self.console.print("\n[bold cyan]Running Benchmark Comparison[/bold cyan]\n")

        for i, prompt in enumerate(test_prompts, 1):
            self.console.print(f"[bold]Test {i}/{len(test_prompts)}[/bold]")

            # Generate responses
            original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
            trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)

            # Display comparison
            panel = self.create_comparison_panel(prompt, original_response, trained_response)
            self.console.print(panel)
            self.console.print("")

        self.console.print("[green]Benchmark completed![/green]")


def main():
    import argparse

    parser = argparse.ArgumentParser(description="Compare original and trained models")
    parser.add_argument("--mode", choices=["interactive", "benchmark"], default="interactive",
                        help="Mode to run the comparison")
    parser.add_argument("--config", default="config/training_config.yaml",
                        help="Path to configuration file")
    args = parser.parse_args()

    # Create TUI
    tui = ModelCompareTUI(args.config)

    # Run in selected mode
    if args.mode == "interactive":
        tui.run_interactive_mode()
    else:
        tui.run_benchmark_mode()


if __name__ == "__main__":
    main()