Compare commits
2 commits
37f1ad9408...3c513fee17
| Author | SHA1 | Date |
|---|---|---|
| | 3c513fee17 | |
| | 01ef446cc9 | |
1 changed file with 95 additions and 15 deletions
```diff
@@ -82,11 +82,28 @@ class ProgressiveReasoningModel:
         if quantization_config:
             model_kwargs["quantization_config"] = quantization_config
 
-        # Add attention implementation
+        # Add attention implementation only if transformers version supports it
+        try:
+            from transformers import __version__ as tf_version
+
+            # Simple version check without packaging dependency
+            version_parts = tf_version.split('.')
+            major = int(version_parts[0])
+            minor = int(version_parts[1]) if len(version_parts) > 1 else 0
+
+            # Check if attn_implementation is supported (transformers >= 4.36)
+            if major > 4 or (major == 4 and minor >= 36):
+                if self.config["model"].get("use_flash_attention_2", False):
+                    model_kwargs["attn_implementation"] = "flash_attention_2"
+                    print("Using flash_attention_2")
+                elif self.config["model"].get("use_eager_attention", False):
+                    model_kwargs["attn_implementation"] = "eager"
+                    print("Using eager attention")
+            else:
+                print(f"Transformers version {tf_version} does not support attn_implementation parameter")
+        except Exception as e:
+            print(f"Warning: Could not check transformers version: {e}")
+            # Skip attention implementation to be safe
 
         # Load model
         print("Loading model with the following kwargs:")
```
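Note: the hunk above gates `attn_implementation` on a hand-rolled major/minor parse of `transformers.__version__`. For comparison, a minimal sketch of the same `>= 4.36` gate using `packaging` (a hypothetical alternative, not part of this commit; `model_kwargs` stands in for the dict built in the surrounding method):

```python
from packaging import version

import transformers

model_kwargs = {}  # stands in for the kwargs dict assembled in the surrounding method

# Same ">= 4.36" threshold the commit checks, but parsed robustly:
# packaging orders pre-release strings such as "4.36.0.dev0" correctly
# instead of relying on int() casts of split('.') components.
if version.parse(transformers.__version__) >= version.parse("4.36"):
    model_kwargs["attn_implementation"] = "flash_attention_2"
else:
    print(f"Transformers {transformers.__version__} does not support attn_implementation")
```

The commit's manual split avoids adding `packaging` as a dependency, at the cost of the try/except guard around the whole check.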
```diff
@@ -96,22 +113,86 @@ class ProgressiveReasoningModel:
             else:
                 print(f" {k}: <BitsAndBytesConfig>")
 
+        # Try loading with progressively simpler configurations
+        model_loaded = False
+        error_messages = []
+
+        # First attempt: with all kwargs
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.base_model_name,
                 **model_kwargs
             )
+            model_loaded = True
         except Exception as e:
+            error_messages.append(f"Full config failed: {e}")
             print(f"Error loading model: {e}")
-            # Try without some problematic kwargs
-            if "offload_folder" in model_kwargs:
-                print("Retrying without offload_folder...")
-                del model_kwargs["offload_folder"]
-                del model_kwargs["offload_state_dict"]
+
+        # Second attempt: remove offload options
+        if not model_loaded and "offload_folder" in model_kwargs:
+            try:
+                print("Retrying without offload options...")
+                model_kwargs_clean = {k: v for k, v in model_kwargs.items()
+                                      if k not in ["offload_folder", "offload_state_dict"]}
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
-                    **model_kwargs
+                    **model_kwargs_clean
+                )
+                model_loaded = True
+                model_kwargs = model_kwargs_clean
+            except Exception as e:
+                error_messages.append(f"Without offload failed: {e}")
+                print(f"Still failed: {e}")
+
+        # Third attempt: remove attention implementation
+        if not model_loaded and "attn_implementation" in model_kwargs:
+            try:
+                print("Retrying without attention implementation...")
+                model_kwargs_clean = {k: v for k, v in model_kwargs.items()
+                                      if k != "attn_implementation"}
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
+                    **model_kwargs_clean
+                )
+                model_loaded = True
+                model_kwargs = model_kwargs_clean
+            except Exception as e:
+                error_messages.append(f"Without attn_implementation failed: {e}")
+                print(f"Still failed: {e}")
+
+        # Fourth attempt: minimal configuration (only essentials)
+        if not model_loaded:
+            try:
+                print("Retrying with minimal configuration...")
+                minimal_kwargs = {
+                    "device_map": model_kwargs.get("device_map", "auto"),
+                    "trust_remote_code": True,
+                    "torch_dtype": model_kwargs.get("torch_dtype", torch.float32)
+                }
+
+                # Keep quantization if it was specified
+                if "quantization_config" in model_kwargs:
+                    minimal_kwargs["quantization_config"] = model_kwargs["quantization_config"]
+
+                # Keep max_memory if it was specified
+                if "max_memory" in model_kwargs:
+                    minimal_kwargs["max_memory"] = model_kwargs["max_memory"]
+
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
+                    **minimal_kwargs
+                )
+                model_loaded = True
+                model_kwargs = minimal_kwargs
+            except Exception as e:
+                error_messages.append(f"Minimal config failed: {e}")
+                print(f"Minimal config also failed: {e}")
+
+        if not model_loaded:
+            print("All loading attempts failed:")
+            for i, msg in enumerate(error_messages, 1):
+                print(f" {i}. {msg}")
+            raise RuntimeError("Could not load model with any configuration")
 
         # Prepare for k-bit training if using quantization
         if quantization_config:
```
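Note: the four attempts above repeat the same load/except shape with progressively fewer kwargs. A hypothetical condensed form, not part of this commit, would iterate over candidate kwargs from most to least specific (`base_model_name` and `model_kwargs` are illustrative stand-ins for the instance attributes, and the last candidate only approximates the commit's `minimal_kwargs`):

```python
from transformers import AutoModelForCausalLM

base_model_name = "gpt2"  # illustrative; the class uses self.base_model_name
model_kwargs = {"device_map": "auto", "trust_remote_code": True}

def without(kwargs, *keys):
    """Copy of kwargs with the given keys removed."""
    return {k: v for k, v in kwargs.items() if k not in keys}

# Most specific first, mirroring the commit's attempt order.
candidates = [
    model_kwargs,
    without(model_kwargs, "offload_folder", "offload_state_dict"),
    without(model_kwargs, "attn_implementation"),
    without(model_kwargs, "offload_folder", "offload_state_dict",
            "attn_implementation"),  # rough stand-in for the minimal config
]

errors = []
for kwargs in candidates:
    try:
        model = AutoModelForCausalLM.from_pretrained(base_model_name, **kwargs)
        break  # success: stop at the first configuration that loads
    except Exception as e:  # broad catch, mirroring the commit
        errors.append(f"{sorted(kwargs)} failed: {e}")
else:
    raise RuntimeError("Could not load model with any configuration:\n"
                       + "\n".join(errors))
```

The for/else keeps the "all attempts failed" error reporting in one place instead of threading a `model_loaded` flag through every block.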
```diff
@@ -300,7 +381,6 @@ class ProgressiveReasoningModel:
         self.model.save_pretrained(self.adapters[stage_name])
         # Also save tokenizer for convenience
         self.tokenizer.save_pretrained(self.adapters[stage_name])
 
     def load_for_inference(self, adapter_names: List[str], weights: Optional[Dict[str, float]] = None):
         """Load model with specific adapters for inference"""
         if len(adapter_names) == 1:
```
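Note: `load_for_inference` above branches on the single-adapter case. A minimal sketch of what that path typically looks like with `peft`, assuming the adapter directory is the one written by `save_pretrained` in this hunk (directory and model names here are illustrative, not taken from the commit):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_name = "gpt2"         # illustrative stand-in for self.base_model_name
adapter_dir = "adapters/stage1"  # directory written by save_pretrained above

base = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(base, adapter_dir)    # attach the saved adapter
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)  # tokenizer saved alongside
model.eval()
```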