#!/usr/bin/env python3
"""
TUI for comparing original and trained models
"""

import sys
from pathlib import Path
import yaml
import torch
from rich.console import Console
from rich.panel import Panel
from rich.columns import Columns
from rich.prompt import Prompt
from rich.text import Text
from rich.layout import Layout
from rich.live import Live
from rich.table import Table
import time

# Add src to path
sys.path.append(str(Path(__file__).parent.parent))

from src.progressive_model import ProgressiveReasoningModel
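# ProgressiveReasoningModel (see src/progressive_model.py) is expected to expose
# setup_base_model(), load_for_inference(), and .model / .tokenizer attributes,
# which is how it is used throughout this script.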


class ModelCompareTUI:
    def __init__(self, config_path: str = "config/training_config.yaml"):
        self.console = Console()

        # Load configuration
        with open(config_path) as f:
            self.config = yaml.safe_load(f)

        # Initialize models
        self.console.print("[yellow]Loading models...[/yellow]")
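
        # Note: two separate ProgressiveReasoningModel instances are built below, so
        # the base model weights are loaded twice and memory use roughly doubles.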
        # Original model
        self.original_model = ProgressiveReasoningModel(self.config)
        self.original_model.setup_base_model()

        # Trained model
        self.trained_model = ProgressiveReasoningModel(self.config)
        self.trained_model.setup_base_model()

        # Load the trained adapter if it exists
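        # The adapter name "basic_cot" is hard-coded here and is assumed to match the
        # first training stage saved under <output_dir>/adapters/.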
        adapter_path = Path(self.config["experiment"]["output_dir"]) / "adapters" / "basic_cot"
        if adapter_path.exists():
            self.console.print(f"[green]Loading trained adapter from: {adapter_path}[/green]")
            self.trained_model.load_for_inference(["basic_cot"])
        else:
            self.console.print("[red]No trained adapter found. Please run training first.[/red]")
            self.console.print("[yellow]Both models will show original behavior.[/yellow]")

        self.console.print("[green]Models loaded successfully![/green]\n")

    def generate_response(self, model, prompt: str, with_think_tags: bool = True) -> str:
        """Generate a response from a model"""
        # For the trained model, encourage think tags
        if with_think_tags and model is self.trained_model:
            formatted_prompt = f"{prompt}\n\nPlease think step by step."
        else:
            formatted_prompt = prompt

        inputs = model.tokenizer(formatted_prompt, return_tensors="pt").to(model.model.device)
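
        # Sampling settings: temperature 0.7 with nucleus (top_p=0.95) sampling.
        # Note that max_length caps prompt + generated tokens combined, so very long
        # prompts leave less room for the generated answer.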
        with torch.no_grad():
            outputs = model.model.generate(
                **inputs,
                max_length=512,
                temperature=0.7,
                do_sample=True,
                top_p=0.95,
                pad_token_id=model.tokenizer.pad_token_id,
                eos_token_id=model.tokenizer.eos_token_id
            )
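
        # For a causal LM, generate() returns the prompt tokens followed by the newly
        # generated tokens, so the decoded text is trimmed by the prompt length below.
        # This assumes decoding reproduces the prompt verbatim; re-tokenization can
        # occasionally shift whitespace.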
        response = model.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract response after prompt
        response = response[len(formatted_prompt):].strip()

        return response

    def create_comparison_panel(self, prompt: str, original_response: str, trained_response: str) -> Panel:
        """Create a panel showing the comparison"""
        # Create table
        table = Table(show_header=True, header_style="bold magenta", expand=True)
        table.add_column("Original Model", style="cyan", width=50)
        table.add_column("Trained Model (with CoT)", style="green", width=50)
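
        # A single row holds both responses side by side; rich wraps long text within
        # each fixed-width column.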
        table.add_row(original_response, trained_response)

        return Panel(
            table,
            title=f"[bold yellow]Prompt: {prompt}[/bold yellow]",
            border_style="blue"
        )

    def run_interactive_mode(self):
        """Run interactive comparison mode"""
        self.console.print("\n[bold cyan]Model Comparison TUI[/bold cyan]")
        self.console.print("Compare responses from original and trained models\n")
        self.console.print("[dim]Type 'quit' or 'exit' to leave[/dim]\n")

        while True:
            # Get user prompt
            prompt = Prompt.ask("\n[bold yellow]Enter your prompt[/bold yellow]")

            if prompt.lower() in ['quit', 'exit']:
                self.console.print("\n[yellow]Goodbye![/yellow]")
                break

            # Generate responses
            self.console.print("\n[dim]Generating responses...[/dim]")

            start_time = time.time()
            original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
            original_time = time.time() - start_time

            start_time = time.time()
            trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)
            trained_time = time.time() - start_time

            # Display comparison
            panel = self.create_comparison_panel(prompt, original_response, trained_response)
            self.console.print(panel)

            # Show generation times
            self.console.print(f"\n[dim]Generation times - Original: {original_time:.2f}s, Trained: {trained_time:.2f}s[/dim]")

    def run_benchmark_mode(self):
        """Run benchmark with predefined prompts"""
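        # Short arithmetic and word problems where step-by-step (CoT) prompting tends
        # to change the model's output most visibly.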
        test_prompts = [
            "What is 156 + 389?",
            "If I have 23 apples and buy 17 more, how many do I have?",
            "A store has 145 items. If 38 are sold, how many remain?",
            "What is 45 * 12?",
            "Explain why 2 + 2 = 4",
            "If a train travels 80 km/h for 2.5 hours, how far does it go?",
            "What is the sum of all numbers from 1 to 10?",
            "How many minutes are in 3.5 hours?",
        ]

        self.console.print("\n[bold cyan]Running Benchmark Comparison[/bold cyan]\n")

        for i, prompt in enumerate(test_prompts, 1):
            self.console.print(f"[bold]Test {i}/{len(test_prompts)}[/bold]")

            # Generate responses
            original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
            trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)

            # Display comparison
            panel = self.create_comparison_panel(prompt, original_response, trained_response)
            self.console.print(panel)
            self.console.print("")

        self.console.print("[green]Benchmark completed![/green]")


def main():
    import argparse

    parser = argparse.ArgumentParser(description="Compare original and trained models")
    parser.add_argument("--mode", choices=["interactive", "benchmark"], default="interactive",
                        help="Mode to run the comparison")
    parser.add_argument("--config", default="config/training_config.yaml",
                        help="Path to configuration file")

    args = parser.parse_args()

    # Create TUI
    tui = ModelCompareTUI(args.config)

    # Run in selected mode
    if args.mode == "interactive":
        tui.run_interactive_mode()
    else:
        tui.run_benchmark_mode()
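

# Example invocation (the script path below is a placeholder; adjust it to wherever
# this file lives in the repo):
#   python <path/to/this_script>.py --mode benchmark --config config/training_config.yaml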
if __name__ == "__main__":
    main()