#!/usr/bin/env python3
"""
Simple comparison script without rich TUI
"""

import argparse
import sys
from pathlib import Path

import torch
import yaml

# Add src to path
sys.path.append(str(Path(__file__).parent.parent))

from src.progressive_model import ProgressiveReasoningModel
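
# Example invocation (the filename below is illustrative; use this script's actual
# name in your checkout):
#   python compare_simple.py --config config/training_config_gemma2_small.yaml --adapter basic_cot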


def parse_args():
    parser = argparse.ArgumentParser(description="Compare original and trained models")
    parser.add_argument(
        "--config", "-c",
        type=str,
        default="config/training_config_gemma2_small.yaml",
        help="Path to configuration file",
    )
    parser.add_argument(
        "--adapter", "-a",
        type=str,
        default="basic_cot",
        help="Adapter name to load for comparison",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum generation length",
    )
    return parser.parse_args()


def load_config(config_path):
    """Load configuration from file"""
    config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    with open(config_path) as f:
        config = yaml.safe_load(f)
    return config
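
# The config only needs the keys this script reads; a minimal YAML sketch
# (values are placeholders, not the project's actual settings):
#   experiment:
#     base_model: google/gemma-2-2b-it   # any HF model id your setup supports
#     output_dir: outputs/run1           # directory containing the adapters/ folder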


def generate_response(model, tokenizer, prompt, max_length=512):
    """Generate response using the model"""
    # Format prompt for Gemma
    formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

    # Tokenize
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # Decode only the newly generated tokens so the prompt is not echoed back
    # (skip_special_tokens=True strips the <start_of_turn> markers, so splitting
    # on them after decoding would not reliably isolate the model's turn)
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True)

    return response.strip()


def main():
    args = parse_args()

    try:
        config = load_config(args.config)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return
print(f"Progressive Model Comparison")
|
|
print(f"Config: {args.config}")
|
|
print(f"Base model: {config['experiment']['base_model']}")
|
|
print(f"Adapter: {args.adapter}")
|
|
print("="*60)
|
|
|
|
print("Loading models...")
|
|
|
|
# Original model (no adapter)
|
|
print("Loading original model...")
|
|
original_model = ProgressiveReasoningModel(config)
|
|
original_model.setup_base_model()
|
|
|
|
# Trained model (with adapter)
|
|
print("Loading trained model...")
|
|
trained_model = ProgressiveReasoningModel(config)
|
|
trained_model.setup_base_model()
|
|
|
|

    # Load the trained adapter if it exists
    adapter_path = Path(config["experiment"]["output_dir"]) / "adapters" / args.adapter
    if adapter_path.exists():
        print(f"Loading trained adapter from: {adapter_path}")
        try:
            trained_model.load_for_inference([args.adapter])
            print("Adapter loaded successfully!")
        except Exception as e:
            print(f"Error loading adapter: {e}")
            print("Will compare with base model instead.")
    else:
        print(f"No trained adapter found at: {adapter_path}")
        print("Available adapters:")
        adapters_dir = Path(config["experiment"]["output_dir"]) / "adapters"
        if adapters_dir.exists():
            for adapter_dir in adapters_dir.iterdir():
                if adapter_dir.is_dir():
                    print(f"  - {adapter_dir.name}")
        else:
            print("  No adapters directory found.")
        print("Both models will show original behavior.")
print("\nModels loaded! Enter prompts to compare (type 'quit' to exit)")
|
|
print("Examples:")
|
|
print(" - What is 25 + 17?")
|
|
print(" - Explain why the sky is blue")
|
|
print(" - Solve this step by step: If I have 10 apples and give away 3, how many do I have left?")
|
|
print()
|
|
|
|

    while True:
        try:
            prompt = input("\nPrompt: ").strip()
            if prompt.lower() in ['quit', 'exit', 'q']:
                break

            if not prompt:
                continue

            print(f"\n{'='*60}")
            print("ORIGINAL MODEL (No fine-tuning)")
            print("="*60)

            try:
                original_response = generate_response(
                    original_model.model,
                    original_model.tokenizer,
                    prompt,
                    args.max_length,
                )
                print(original_response)
            except Exception as e:
                print(f"Error generating original response: {e}")

            print(f"\n{'='*60}")
            print(f"TRAINED MODEL (With {args.adapter} adapter)")
            print("="*60)

            try:
                # Add CoT prompt for trained model
                cot_prompt = f"{prompt}\n\nPlease think step by step using <think> tags."
                trained_response = generate_response(
                    trained_model.model,
                    trained_model.tokenizer,
                    cot_prompt,
                    args.max_length,
                )
                print(trained_response)
            except Exception as e:
                print(f"Error generating trained response: {e}")

        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")
            continue


if __name__ == "__main__":
    main()