201 lines
No EOL
5.8 KiB
Python
201 lines
No EOL
5.8 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Debug script to identify model loading issues
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import torch
|
|
from pathlib import Path
|
|
|
|
# Put the directory one level above this script's folder on the import path
# (presumably the project source root) so project modules import when run directly.
sys.path.append(os.fspath(Path(__file__).parent.parent))
|
|
|
|
def clear_accelerate_env():
    """Unset every environment variable whose name contains "ACCELERATE".

    Each variable is echoed with its current value just before it is removed,
    so the console shows what was in effect.
    """
    print("Clearing ACCELERATE environment variables...")

    # Snapshot the matching names up front; removing entries from os.environ
    # while iterating over it directly would mutate the mapping mid-loop.
    doomed = [name for name in os.environ if 'ACCELERATE' in name]

    for name in doomed:
        print(f" Removing {name}={os.environ[name]}")
        os.environ.pop(name)
|
|
|
|
def test_basic_model_loading(model_name="google/gemma-2-2b-it"):
    """Try loading the model with a minimal configuration (float32, no device_map).

    Args:
        model_name: Model identifier passed to ``AutoModelForCausalLM.from_pretrained``.
            NOTE: ``trust_remote_code=True`` is passed, so only use models you trust.

    Returns:
        True if loading succeeds, False if any exception is raised -- including a
        missing or broken ``transformers`` install, so the caller's summary still runs.
    """
    import traceback

    print("Testing basic model loading...")

    try:
        # Imported inside the try: an absent/broken transformers install is exactly
        # the kind of environment problem this script should report, not crash on.
        from transformers import AutoModelForCausalLM

        print("Testing with absolutely minimal config...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
        print("✅ Basic loading successful!")
        del model
        return True
    except Exception as e:
        # Include the exception type: some errors stringify to an empty message.
        print(f"❌ Basic loading failed: {type(e).__name__}: {e}")
        traceback.print_exc()
        return False
|
|
|
|
def test_with_device_map(model_name="google/gemma-2-2b-it"):
    """Try loading the model in float32 with ``device_map="auto"``.

    Args:
        model_name: Model identifier passed to ``AutoModelForCausalLM.from_pretrained``.
            NOTE: ``trust_remote_code=True`` is passed, so only use models you trust.

    Returns:
        True if loading succeeds, False if any exception is raised -- including a
        missing or broken ``transformers`` install.
    """
    import traceback

    print("Testing with device_map='auto'...")

    try:
        # Imported inside the try so a broken transformers install is reported
        # as a failed check instead of aborting the whole script.
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map="auto",
        )
        print("✅ Device map loading successful!")
        del model
        return True
    except Exception as e:
        # Include the exception type: some errors stringify to an empty message.
        print(f"❌ Device map loading failed: {type(e).__name__}: {e}")
        traceback.print_exc()
        return False
|
|
|
|
def test_with_quantization(model_name="google/gemma-2-2b-it"):
    """Try loading the model with 4-bit NF4 quantization (``BitsAndBytesConfig``).

    Uses double quantization and a bfloat16 compute dtype.

    Args:
        model_name: Model identifier passed to ``AutoModelForCausalLM.from_pretrained``.
            NOTE: ``trust_remote_code=True`` is passed, so only use models you trust.

    Returns:
        True if loading succeeds, False if any exception is raised -- including a
        missing or broken ``transformers`` install.
    """
    import traceback

    print("Testing with 4-bit quantization...")

    try:
        # Imported inside the try so a broken transformers install is reported
        # as a failed check instead of aborting the whole script.
        from transformers import AutoModelForCausalLM, BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            quantization_config=bnb_config,
        )
        print("✅ Quantization loading successful!")
        del model
        return True
    except Exception as e:
        # Include the exception type: some errors stringify to an empty message.
        print(f"❌ Quantization loading failed: {type(e).__name__}: {e}")
        traceback.print_exc()
        return False
|
|
|
|
def _redact_env_value(key, value):
    """Return ``value``, or ``"<redacted>"`` when ``key`` looks like a credential.

    A key counts as a credential when its upper-cased name contains TOKEN,
    SECRET, PASSWORD or KEY (e.g. ``HF_TOKEN``). Empty values are returned
    unchanged so "set but empty" remains visible.
    """
    name = key.upper()
    if value and any(marker in name for marker in ("TOKEN", "SECRET", "PASSWORD", "KEY")):
        return "<redacted>"
    return value


def print_environment_info():
    """Print Python, PyTorch/CUDA, library-version and environment-variable info.

    Each probe catches its own exception and prints an "... info error" line,
    so a missing library does not stop the report. Values of credential-like
    variables (e.g. ``HF_TOKEN``) are masked so the output is safe to share.
    """
    import importlib

    print("\n" + "="*50)
    print("ENVIRONMENT INFORMATION")
    print("="*50)

    # Python version
    print(f"Python version: {sys.version}")

    # PyTorch / CUDA info
    try:
        import torch
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device count: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                print(f" Device {i}: {torch.cuda.get_device_name(i)}")
            print(f"CUDA version: {torch.version.cuda}")
    except Exception as e:
        print(f"PyTorch info error: {e}")

    # Versions of the libraries the loading tests rely on.
    for label, module_name in (
        ("Transformers", "transformers"),
        ("Accelerate", "accelerate"),
        ("PEFT", "peft"),
        ("BitsAndBytes", "bitsandbytes"),
    ):
        try:
            version = importlib.import_module(module_name).__version__
            print(f"{label} version: {version}")
        except Exception as e:
            print(f"{label} info error: {e}")

    # Environment variables (substring match on the name).
    print("\nRelevant environment variables:")
    markers = ('CUDA', 'TORCH', 'HF_', 'ACCELERATE', 'TRANSFORMERS')
    for key, value in sorted(os.environ.items()):
        if any(marker in key for marker in markers):
            # Mask secrets such as HF_TOKEN: this output often ends up in bug reports.
            print(f" {key}={_redact_env_value(key, value)}")
|
|
|
|
def main():
    """Run the diagnostics: environment report, then each model-loading strategy.

    Returns:
        0 if at least one loading strategy succeeded, 1 if all of them failed.
        The ``__main__`` guard below passes this to ``sys.exit`` so shell/CI
        callers can detect a broken environment.
    """
    print("Progressive LLM Training - Model Loading Debug")
    print("=" * 60)

    # Report the environment first, before clear_accelerate_env() removes the
    # variables whose names contain ACCELERATE, so the report shows what was set.
    print_environment_info()

    # Presumably cleared so stale accelerate settings don't influence the
    # loading tests below -- confirm.
    clear_accelerate_env()

    print("\n" + "="*50)
    print("TESTING MODEL LOADING")
    print("="*50)

    # Evaluated in order: basic loading, device map, quantization.
    results = [
        ("Basic loading", test_basic_model_loading()),
        ("Device map", test_with_device_map()),
        ("Quantization", test_with_quantization()),
    ]

    print("\n" + "="*50)
    print("SUMMARY")
    print("="*50)

    for test_name, success in results:
        status = "✅ PASS" if success else "❌ FAIL"
        print(f"{test_name}: {status}")

    if any(success for _, success in results):
        print("\n✅ At least one loading method works!")
        print("Use the successful method in your configuration.")
        return 0

    print("\n❌ All loading methods failed!")
    print("This indicates a fundamental environment issue.")
    print("Consider:")
    print("1. Reinstalling transformers, accelerate, torch")
    print("2. Checking CUDA installation")
    print("3. Using a different model")
    return 1


if __name__ == "__main__":
    sys.exit(main())