#!/usr/bin/env python3
"""TUI for comparing original and trained models."""

import sys
import time
from pathlib import Path

import torch
import yaml
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt
from rich.table import Table

# Add the repository root to the import path so that `src` is importable
sys.path.append(str(Path(__file__).parent.parent))

from src.progressive_model import ProgressiveReasoningModel
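
# Note on configuration: this script only dereferences config["experiment"]["output_dir"]
# directly; everything else in config/training_config.yaml is passed through to
# ProgressiveReasoningModel. A minimal sketch of the part this file depends on
# (the value shown is a placeholder, remaining keys are model-specific and omitted):
#
#   experiment:
#     output_dir: outputs/my_experiment   # adapters are looked up at <output_dir>/adapters/basic_cot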

class ModelCompareTUI:
    def __init__(self, config_path: str = "config/training_config.yaml"):
        self.console = Console()

        # Load configuration
        with open(config_path) as f:
            self.config = yaml.safe_load(f)

        # Initialize models
        self.console.print("[yellow]Loading models...[/yellow]")

        # Original model
        self.original_model = ProgressiveReasoningModel(self.config)
        self.original_model.setup_base_model()

        # Trained model (same base model, plus the trained adapter if available)
        self.trained_model = ProgressiveReasoningModel(self.config)
        self.trained_model.setup_base_model()

        # Load the trained adapter if it exists
        adapter_path = Path(self.config["experiment"]["output_dir"]) / "adapters" / "basic_cot"
        if adapter_path.exists():
            self.console.print(f"[green]Loading trained adapter from: {adapter_path}[/green]")
            self.trained_model.load_for_inference(["basic_cot"])
        else:
            self.console.print("[red]No trained adapter found. Please run training first.[/red]")
            self.console.print("[yellow]Both models will show original behavior.[/yellow]")

        self.console.print("[green]Models loaded successfully![/green]\n")

    def generate_response(self, model, prompt: str, with_think_tags: bool = True) -> str:
        """Generate a response from a model."""
        # For the trained model, encourage step-by-step (chain-of-thought) output
        if with_think_tags and model is self.trained_model:
            formatted_prompt = f"{prompt}\n\nPlease think step by step."
        else:
            formatted_prompt = prompt

        inputs = model.tokenizer(formatted_prompt, return_tensors="pt").to(model.model.device)

        with torch.no_grad():
            outputs = model.model.generate(
                **inputs,
                max_new_tokens=512,  # cap generated tokens, independent of prompt length
                temperature=0.7,
                do_sample=True,
                top_p=0.95,
                # Fall back to EOS when the tokenizer has no dedicated pad token
                pad_token_id=(model.tokenizer.pad_token_id
                              if model.tokenizer.pad_token_id is not None
                              else model.tokenizer.eos_token_id),
                eos_token_id=model.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt
        prompt_length = inputs["input_ids"].shape[1]
        response = model.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        return response.strip()

    def create_comparison_panel(self, prompt: str, original_response: str, trained_response: str) -> Panel:
        """Create a panel showing the side-by-side comparison."""
        # Create table
        table = Table(show_header=True, header_style="bold magenta", expand=True)
        table.add_column("Original Model", style="cyan", width=50)
        table.add_column("Trained Model (with CoT)", style="green", width=50)
        table.add_row(original_response, trained_response)

        return Panel(
            table,
            title=f"[bold yellow]Prompt: {prompt}[/bold yellow]",
            border_style="blue",
        )

    def run_interactive_mode(self):
        """Run interactive comparison mode."""
        self.console.print("\n[bold cyan]Model Comparison TUI[/bold cyan]")
        self.console.print("Compare responses from original and trained models\n")
        self.console.print("[dim]Type 'quit' or 'exit' to leave[/dim]\n")

        while True:
            # Get user prompt
            prompt = Prompt.ask("\n[bold yellow]Enter your prompt[/bold yellow]")

            if prompt.lower() in ("quit", "exit"):
                self.console.print("\n[yellow]Goodbye![/yellow]")
                break

            # Generate responses, timing each model separately
            self.console.print("\n[dim]Generating responses...[/dim]")

            start_time = time.time()
            original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
            original_time = time.time() - start_time

            start_time = time.time()
            trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)
            trained_time = time.time() - start_time

            # Display comparison
            panel = self.create_comparison_panel(prompt, original_response, trained_response)
            self.console.print(panel)

            # Show generation times
            self.console.print(
                f"\n[dim]Generation times - Original: {original_time:.2f}s, Trained: {trained_time:.2f}s[/dim]"
            )

    def run_benchmark_mode(self):
        """Run benchmark with predefined prompts."""
        test_prompts = [
            "What is 156 + 389?",
            "If I have 23 apples and buy 17 more, how many do I have?",
            "A store has 145 items. If 38 are sold, how many remain?",
            "What is 45 * 12?",
            "Explain why 2 + 2 = 4",
            "If a train travels 80 km/h for 2.5 hours, how far does it go?",
            "What is the sum of all numbers from 1 to 10?",
            "How many minutes are in 3.5 hours?",
        ]

        self.console.print("\n[bold cyan]Running Benchmark Comparison[/bold cyan]\n")

        for i, prompt in enumerate(test_prompts, 1):
            self.console.print(f"[bold]Test {i}/{len(test_prompts)}[/bold]")

            # Generate responses
            original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
            trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)

            # Display comparison
            panel = self.create_comparison_panel(prompt, original_response, trained_response)
            self.console.print(panel)
            self.console.print("")

        self.console.print("[green]Benchmark completed![/green]")


def main():
    import argparse

    parser = argparse.ArgumentParser(description="Compare original and trained models")
    parser.add_argument("--mode", choices=["interactive", "benchmark"], default="interactive",
                        help="Mode to run the comparison")
    parser.add_argument("--config", default="config/training_config.yaml",
                        help="Path to configuration file")
    args = parser.parse_args()

    # Create TUI
    tui = ModelCompareTUI(args.config)

    # Run in selected mode
    if args.mode == "interactive":
        tui.run_interactive_mode()
    else:
        tui.run_benchmark_mode()


if __name__ == "__main__":
    main()
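
# Example invocations (the script filename below is assumed; adjust to wherever this file lives):
#   python compare_models.py                                  # interactive mode (default)
#   python compare_models.py --mode benchmark                 # run the predefined benchmark prompts
#   python compare_models.py --config config/training_config.yaml --mode interactive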