progressive-llm/scripts/debug_model_loading.py
2025-07-10 22:25:11 +09:00

201 lines
No EOL
5.8 KiB
Python

#!/usr/bin/env python3
"""
Debug script to identify model loading issues
"""
import sys
import os
import torch
from pathlib import Path
# Make the directory two levels above this file (the project root containing
# scripts/) importable, presumably so the ``src`` package can be imported.
sys.path.append(os.fspath(Path(__file__).parent.parent))
def clear_accelerate_env():
    """Delete every environment variable whose name contains ``ACCELERATE``.

    Each matching variable is echoed as ``name=value`` just before removal.
    """
    print("Clearing ACCELERATE environment variables...")
    # Take a snapshot of the matching names first: os.environ must not
    # change size while it is being iterated.
    doomed = [name for name in os.environ if "ACCELERATE" in name]
    for name in doomed:
        print(f" Removing {name}={os.environ[name]}")
        del os.environ[name]
def test_basic_model_loading(model_name="google/gemma-2-2b-it"):
    """Try loading ``model_name`` with a minimal ``from_pretrained`` call.

    Only ``trust_remote_code=True`` and ``torch_dtype=torch.float32`` are
    passed: no device map and no quantization.

    Args:
        model_name: Hugging Face model id or local path to load. Defaults to
            the model this script was originally written against.

    Returns:
        True if the model loaded. False if anything raised, including the
        import of ``transformers`` itself.
    """
    print("Testing basic model loading...")
    try:
        # Import inside the try so that a missing or broken transformers
        # install is reported as a FAIL instead of aborting the whole run.
        from transformers import AutoModelForCausalLM

        print("Testing with absolutely minimal config...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
    except Exception as e:  # deliberately broad: any failure is a diagnostic result
        print(f"❌ Basic loading failed: {e}")
        return False
    print("✅ Basic loading successful!")
    # Drop the local reference so the weights can be garbage-collected
    # before the next check loads another copy.
    del model
    return True
def test_with_device_map(model_name="google/gemma-2-2b-it"):
    """Try loading ``model_name`` with ``device_map="auto"``.

    Same call as the basic check (``trust_remote_code=True``,
    ``torch_dtype=torch.float32``) plus ``device_map="auto"``. That delegates
    weight placement to transformers/accelerate.

    Args:
        model_name: Hugging Face model id or local path to load. Defaults to
            the model this script was originally written against.

    Returns:
        True if the model loaded. False if anything raised, including the
        import of ``transformers`` itself.
    """
    print("Testing with device_map='auto'...")
    try:
        # Import inside the try so that a missing or broken transformers
        # install is reported as a FAIL instead of aborting the whole run.
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map="auto",
        )
    except Exception as e:  # deliberately broad: any failure is a diagnostic result
        print(f"❌ Device map loading failed: {e}")
        return False
    print("✅ Device map loading successful!")
    # Drop the local reference so the weights can be garbage-collected
    # before the next check loads another copy.
    del model
    return True
def test_with_quantization(model_name="google/gemma-2-2b-it"):
    """Try loading ``model_name`` with 4-bit bitsandbytes quantization.

    Uses NF4 quantization with double quantization and a bfloat16 compute
    dtype. This path typically needs the ``bitsandbytes`` package and a
    supported GPU, so a FAIL here on a CPU-only machine may be expected.

    Args:
        model_name: Hugging Face model id or local path to load. Defaults to
            the model this script was originally written against.

    Returns:
        True if the model loaded. False if anything raised, including the
        import of ``transformers`` itself.
    """
    print("Testing with 4-bit quantization...")
    try:
        # Import inside the try so that a missing or broken transformers
        # install is reported as a FAIL instead of aborting the whole run.
        from transformers import AutoModelForCausalLM, BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            quantization_config=bnb_config,
        )
    except Exception as e:  # deliberately broad: any failure is a diagnostic result
        print(f"❌ Quantization loading failed: {e}")
        return False
    print("✅ Quantization loading successful!")
    # Drop the local reference so the weights can be garbage-collected.
    del model
    return True
def print_environment_info():
    """Print Python/library versions, CUDA devices and relevant env vars.

    Each probe runs in its own ``try``, so one missing or broken package does
    not hide the others. It reports the error and moves on. Environment
    variables are filtered by substring (``CUDA``, ``TORCH``, ``HF_``,
    ``ACCELERATE``, ``TRANSFORMERS``).

    Values of variables whose names look secret (``TOKEN``, ``SECRET``,
    ``PASSWORD``, ``KEY``) are printed as ``<redacted>``. ``HF_TOKEN``
    matches the ``HF_`` filter, and this output is meant to be shared when
    debugging.
    """
    import importlib

    print("\n" + "=" * 50)
    print("ENVIRONMENT INFORMATION")
    print("=" * 50)
    # Python version
    print(f"Python version: {sys.version}")
    # PyTorch info (imported locally so a broken install is reported, not raised)
    try:
        import torch
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device count: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                print(f" Device {i}: {torch.cuda.get_device_name(i)}")
            print(f"CUDA version: {torch.version.cuda}")
    except Exception as e:
        print(f"PyTorch info error: {e}")
    # Versions of the other libraries, in a fixed reporting order.
    for label, module_name in (
        ("Transformers", "transformers"),
        ("Accelerate", "accelerate"),
        ("PEFT", "peft"),
        ("BitsAndBytes", "bitsandbytes"),
    ):
        try:
            module = importlib.import_module(module_name)
            print(f"{label} version: {module.__version__}")
        except Exception as e:
            print(f"{label} info error: {e}")
    # Environment variables
    watched = ("CUDA", "TORCH", "HF_", "ACCELERATE", "TRANSFORMERS")
    sensitive = ("TOKEN", "SECRET", "PASSWORD", "KEY")
    print("\nRelevant environment variables:")
    for key, value in sorted(os.environ.items()):
        if not any(marker in key for marker in watched):
            continue
        # Never echo credentials (e.g. HF_TOKEN) into logs or bug reports.
        if any(marker in key.upper() for marker in sensitive):
            value = "<redacted>"
        print(f" {key}={value}")
def main():
    """Run every model-loading check and print a PASS/FAIL summary."""
    print("Progressive LLM Training - Model Loading Debug")
    print("=" * 60)
    # Report the environment before modifying it, so the ACCELERATE variables
    # that are about to be removed still show up in the printed output.
    print_environment_info()
    clear_accelerate_env()

    print("\n" + "=" * 50)
    print("TESTING MODEL LOADING")
    print("=" * 50)
    # Checks run in this order; each returns True on success.
    checks = (
        ("Basic loading", test_basic_model_loading),
        ("Device map", test_with_device_map),
        ("Quantization", test_with_quantization),
    )
    results = [(label, check()) for label, check in checks]

    print("\n" + "=" * 50)
    print("SUMMARY")
    print("=" * 50)
    for label, passed in results:
        print(f"{label}: {'✅ PASS' if passed else '❌ FAIL'}")

    if not any(passed for _, passed in results):
        print("\n❌ All loading methods failed!")
        print("This indicates a fundamental environment issue.")
        print("Consider:")
        print("1. Reinstalling transformers, accelerate, torch")
        print("2. Checking CUDA installation")
        print("3. Using a different model")
        return
    print("\n✅ At least one loading method works!")
    print("Use the successful method in your configuration.")
# Run the diagnostics only when executed as a script, not when imported.
if __name__ == "__main__":
    main()