#!/usr/bin/env python3
"""
Simple comparison script without rich TUI
"""

import argparse
import sys
from pathlib import Path

import torch
import yaml

# Add src to path
sys.path.append(str(Path(__file__).parent.parent))

from src.progressive_model import ProgressiveReasoningModel


def parse_args():
    parser = argparse.ArgumentParser(description="Compare original and trained models")
    parser.add_argument(
        "--config", "-c",
        type=str,
        default="config/training_config_gemma2_small.yaml",
        help="Path to configuration file"
    )
    parser.add_argument(
        "--adapter", "-a",
        type=str,
        default="basic_cot",
        help="Adapter name to load for comparison"
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum generation length"
    )
    return parser.parse_args()


def load_config(config_path):
    """Load configuration from file"""
    config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    with open(config_path) as f:
        config = yaml.safe_load(f)
    return config


def generate_response(model, tokenizer, prompt, max_length=512):
    """Generate a response from the model for a single prompt"""
    # Format prompt with the Gemma chat turn markers
    formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

    # Tokenize
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=len(inputs["input_ids"][0]) + max_length,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # Decode (special tokens are stripped, so the literal turn labels remain)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the text after the model turn label
    if "model" in response:
        response = response.split("model")[-1].strip()

    return response


def main():
    args = parse_args()

    try:
        config = load_config(args.config)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return

    print("Progressive Model Comparison")
    print(f"Config: {args.config}")
    print(f"Base model: {config['experiment']['base_model']}")
    print(f"Adapter: {args.adapter}")
    print("=" * 60)

    print("Loading models...")

    # Original model (no adapter)
    print("Loading original model...")
    original_model = ProgressiveReasoningModel(config)
    original_model.setup_base_model()

    # Trained model (with adapter)
    print("Loading trained model...")
    trained_model = ProgressiveReasoningModel(config)
    trained_model.setup_base_model()

    # Load the trained adapter if it exists
    adapter_path = Path(config["experiment"]["output_dir"]) / "adapters" / args.adapter
    if adapter_path.exists():
        print(f"Loading trained adapter from: {adapter_path}")
        try:
            trained_model.load_for_inference([args.adapter])
            print("Adapter loaded successfully!")
        except Exception as e:
            print(f"Error loading adapter: {e}")
            print("Will compare with base model instead.")
    else:
        print(f"No trained adapter found at: {adapter_path}")
        print("Available adapters:")
        adapters_dir = Path(config["experiment"]["output_dir"]) / "adapters"
        if adapters_dir.exists():
            for adapter_dir in adapters_dir.iterdir():
                if adapter_dir.is_dir():
                    print(f"  - {adapter_dir.name}")
        else:
            print("  No adapters directory found.")
        print("Both models will show original behavior.")

    print("\nModels loaded! Enter prompts to compare (type 'quit' to exit)")
    print("Examples:")
    print("  - What is 25 + 17?")
    print("  - Explain why the sky is blue")
    print("  - Solve this step by step: If I have 10 apples and give away 3, how many do I have left?")
    print()

    while True:
        try:
            prompt = input("\nPrompt: ").strip()

            if prompt.lower() in ['quit', 'exit', 'q']:
                break
            if not prompt:
                continue

            print(f"\n{'=' * 60}")
            print("ORIGINAL MODEL (No fine-tuning)")
            print("=" * 60)
            try:
                original_response = generate_response(
                    original_model.model,
                    original_model.tokenizer,
                    prompt,
                    args.max_length
                )
                print(original_response)
            except Exception as e:
                print(f"Error generating original response: {e}")

            print(f"\n{'=' * 60}")
            print(f"TRAINED MODEL (With {args.adapter} adapter)")
            print("=" * 60)
            try:
                # Add CoT instruction for the trained model
                cot_prompt = f"{prompt}\n\nPlease think step by step using tags."
                trained_response = generate_response(
                    trained_model.model,
                    trained_model.tokenizer,
                    cot_prompt,
                    args.max_length
                )
                print(trained_response)
            except Exception as e:
                print(f"Error generating trained response: {e}")

        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")
            continue


if __name__ == "__main__":
    main()