#!/usr/bin/env python3
"""
Evaluation script for the progressive reasoning model.

Loads the training configuration, sets up the base model, and prints the
model's responses to a fixed set of prompts for each available adapter.
"""

import sys
from pathlib import Path

import yaml

# Put the directory two levels above this file on the import path so that
# the ``src`` package resolves when this script is executed directly.
sys.path.append(str(Path(__file__).parent.parent))

from src.progressive_model import ProgressiveReasoningModel  # noqa: E402


def evaluate_reasoning(model_wrapper, test_prompts):
    """Run each prompt through the model and collect the responses.

    Every prompt and its response are echoed to stdout as they are produced.

    Args:
        model_wrapper: Object exposing ``generate_with_reasoning(prompt)``.
        test_prompts: Iterable of prompts, evaluated in iteration order.

    Returns:
        A list with one ``{"prompt": ..., "response": ...}`` dict per prompt,
        in the same order as ``test_prompts``.
    """
    results = []
    for prompt in test_prompts:
        print(f"\nPrompt: {prompt}")
        response = model_wrapper.generate_with_reasoning(prompt)
        print(f"Response: {response}")
        results.append({"prompt": prompt, "response": response})
    return results


def main():
    """Load the config, set up the model, and evaluate each known adapter.

    Reads ``config/training_config.yaml``, builds a
    ``ProgressiveReasoningModel`` from it, and for every adapter name in a
    fixed list that is present in ``model_wrapper.adapters`` loads that
    adapter for inference and prints the responses to a fixed prompt set.
    Adapter names not present in ``model_wrapper.adapters`` are skipped.

    Raises:
        FileNotFoundError: If the config file is found neither relative to
            the current working directory nor relative to the directory two
            levels above this script.
    """
    # Load config. Try the working-directory-relative path first (the
    # original behaviour); fall back to the same root that is added to
    # ``sys.path`` so the script also works when launched from elsewhere.
    config_path = Path("config/training_config.yaml")
    if not config_path.exists():
        config_path = Path(__file__).resolve().parent.parent / config_path
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Initialize model
    model_wrapper = ProgressiveReasoningModel(config)
    model_wrapper.setup_base_model()

    # Prompts shared by every adapter under test
    test_prompts = [
        "What is 156 + 389?",
        "If a train travels 80 km/h for 2.5 hours, how far does it go?",
        "Explain why the sky is blue.",
    ]

    banner = "=" * 50

    # Test each adapter
    for adapter_name in ["basic_cot", "math_reasoning", "complex_reasoning"]:
        if adapter_name not in model_wrapper.adapters:
            continue

        print(f"\n{banner}")
        print(f"Testing adapter: {adapter_name}")
        print(f"{banner}")

        model_wrapper.load_for_inference([adapter_name])
        # evaluate_reasoning prints each prompt/response itself; the
        # returned list is not used here.
        evaluate_reasoning(model_wrapper, test_prompts)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()