Compare commits

..

No commits in common. "3c513fee17491f6cb87fda36e93b7871a210fb97" and "37f1ad94086c52c23ffe00433912b2d567d9f28a" have entirely different histories.

View file

@ -82,28 +82,11 @@ class ProgressiveReasoningModel:
if quantization_config: if quantization_config:
model_kwargs["quantization_config"] = quantization_config model_kwargs["quantization_config"] = quantization_config
# Add attention implementation only if transformers version supports it # Add attention implementation
try: if self.config["model"].get("use_flash_attention_2", False):
from transformers import __version__ as tf_version model_kwargs["attn_implementation"] = "flash_attention_2"
elif self.config["model"].get("use_eager_attention", False):
# Simple version check without packaging dependency model_kwargs["attn_implementation"] = "eager"
version_parts = tf_version.split('.')
major = int(version_parts[0])
minor = int(version_parts[1]) if len(version_parts) > 1 else 0
# Check if attn_implementation is supported (transformers >= 4.36)
if major > 4 or (major == 4 and minor >= 36):
if self.config["model"].get("use_flash_attention_2", False):
model_kwargs["attn_implementation"] = "flash_attention_2"
print("Using flash_attention_2")
elif self.config["model"].get("use_eager_attention", False):
model_kwargs["attn_implementation"] = "eager"
print("Using eager attention")
else:
print(f"Transformers version {tf_version} does not support attn_implementation parameter")
except Exception as e:
print(f"Warning: Could not check transformers version: {e}")
# Skip attention implementation to be safe
# Load model # Load model
print("Loading model with the following kwargs:") print("Loading model with the following kwargs:")
@ -113,86 +96,22 @@ class ProgressiveReasoningModel:
else: else:
print(f" {k}: <BitsAndBytesConfig>") print(f" {k}: <BitsAndBytesConfig>")
# Try loading with progressively simpler configurations
model_loaded = False
error_messages = []
# First attempt: with all kwargs
try: try:
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.base_model_name, self.base_model_name,
**model_kwargs **model_kwargs
) )
model_loaded = True
except Exception as e: except Exception as e:
error_messages.append(f"Full config failed: {e}")
print(f"Error loading model: {e}") print(f"Error loading model: {e}")
# Try without some problematic kwargs
# Second attempt: remove offload options if "offload_folder" in model_kwargs:
if not model_loaded and "offload_folder" in model_kwargs: print("Retrying without offload_folder...")
try: del model_kwargs["offload_folder"]
print("Retrying without offload options...") del model_kwargs["offload_state_dict"]
model_kwargs_clean = {k: v for k, v in model_kwargs.items() self.model = AutoModelForCausalLM.from_pretrained(
if k not in ["offload_folder", "offload_state_dict"]} self.base_model_name,
self.model = AutoModelForCausalLM.from_pretrained( **model_kwargs
self.base_model_name, )
**model_kwargs_clean
)
model_loaded = True
model_kwargs = model_kwargs_clean
except Exception as e:
error_messages.append(f"Without offload failed: {e}")
print(f"Still failed: {e}")
# Third attempt: remove attention implementation
if not model_loaded and "attn_implementation" in model_kwargs:
try:
print("Retrying without attention implementation...")
model_kwargs_clean = {k: v for k, v in model_kwargs.items()
if k != "attn_implementation"}
self.model = AutoModelForCausalLM.from_pretrained(
self.base_model_name,
**model_kwargs_clean
)
model_loaded = True
model_kwargs = model_kwargs_clean
except Exception as e:
error_messages.append(f"Without attn_implementation failed: {e}")
print(f"Still failed: {e}")
# Fourth attempt: minimal configuration (only essentials)
if not model_loaded:
try:
print("Retrying with minimal configuration...")
minimal_kwargs = {
"device_map": model_kwargs.get("device_map", "auto"),
"trust_remote_code": True,
"torch_dtype": model_kwargs.get("torch_dtype", torch.float32)
}
# Keep quantization if it was specified
if "quantization_config" in model_kwargs:
minimal_kwargs["quantization_config"] = model_kwargs["quantization_config"]
# Keep max_memory if it was specified
if "max_memory" in model_kwargs:
minimal_kwargs["max_memory"] = model_kwargs["max_memory"]
self.model = AutoModelForCausalLM.from_pretrained(
self.base_model_name,
**minimal_kwargs
)
model_loaded = True
model_kwargs = minimal_kwargs
except Exception as e:
error_messages.append(f"Minimal config failed: {e}")
print(f"Minimal config also failed: {e}")
if not model_loaded:
print("All loading attempts failed:")
for i, msg in enumerate(error_messages, 1):
print(f" {i}. {msg}")
raise RuntimeError("Could not load model with any configuration")
# Prepare for k-bit training if using quantization # Prepare for k-bit training if using quantization
if quantization_config: if quantization_config:
@ -381,6 +300,7 @@ class ProgressiveReasoningModel:
self.model.save_pretrained(self.adapters[stage_name]) self.model.save_pretrained(self.adapters[stage_name])
# Also save tokenizer for convenience # Also save tokenizer for convenience
self.tokenizer.save_pretrained(self.adapters[stage_name]) self.tokenizer.save_pretrained(self.adapters[stage_name])
def load_for_inference(self, adapter_names: List[str], weights: Optional[Dict[str, float]] = None): def load_for_inference(self, adapter_names: List[str], weights: Optional[Dict[str, float]] = None):
"""Load model with specific adapters for inference""" """Load model with specific adapters for inference"""
if len(adapter_names) == 1: if len(adapter_names) == 1: