From 01ef446cc9767abf656b5b2ca87a427a4347fd33 Mon Sep 17 00:00:00 2001
From: Soma Nakamura
Date: Thu, 10 Jul 2025 21:02:39 +0900
Subject: [PATCH] Gate attn_implementation on transformers version and add fallback model loading
---
src/progressive_model.py | 109 ++++++++++++++++++++++++++++++++++-----
1 file changed, 95 insertions(+), 14 deletions(-)
diff --git a/src/progressive_model.py b/src/progressive_model.py
index 6c5e669..b39bd50 100644
--- a/src/progressive_model.py
+++ b/src/progressive_model.py
@@ -82,11 +82,28 @@ class ProgressiveReasoningModel:
if quantization_config:
model_kwargs["quantization_config"] = quantization_config
- # Add attention implementation
- if self.config["model"].get("use_flash_attention_2", False):
- model_kwargs["attn_implementation"] = "flash_attention_2"
- elif self.config["model"].get("use_eager_attention", False):
- model_kwargs["attn_implementation"] = "eager"
+ # Add attention implementation only if transformers version supports it
+ try:
+ from transformers import __version__ as tf_version
+
+ # Simple version check without packaging dependency
+ version_parts = tf_version.split('.')
+ major = int(version_parts[0])
+ minor = int(version_parts[1]) if len(version_parts) > 1 else 0
+
+ # Check if attn_implementation is supported (transformers >= 4.36)
+ if major > 4 or (major == 4 and minor >= 36):
+ if self.config["model"].get("use_flash_attention_2", False):
+ model_kwargs["attn_implementation"] = "flash_attention_2"
+ print("Using flash_attention_2")
+ elif self.config["model"].get("use_eager_attention", False):
+ model_kwargs["attn_implementation"] = "eager"
+ print("Using eager attention")
+ else:
+ print(f"Transformers version {tf_version} does not support attn_implementation parameter")
+ except Exception as e:
+ print(f"Warning: Could not check transformers version: {e}")
+ # Skip attention implementation to be safe
# Load model
print("Loading model with the following kwargs:")
@@ -96,22 +113,86 @@ class ProgressiveReasoningModel:
else:
print(f" {k}: ")
+ # Try loading with progressively simpler configurations
+ model_loaded = False
+ error_messages = []
+
+ # First attempt: with all kwargs
try:
self.model = AutoModelForCausalLM.from_pretrained(
self.base_model_name,
**model_kwargs
)
+ model_loaded = True
except Exception as e:
+ error_messages.append(f"Full config failed: {e}")
print(f"Error loading model: {e}")
- # Try without some problematic kwargs
- if "offload_folder" in model_kwargs:
- print("Retrying without offload_folder...")
- del model_kwargs["offload_folder"]
- del model_kwargs["offload_state_dict"]
- self.model = AutoModelForCausalLM.from_pretrained(
- self.base_model_name,
- **model_kwargs
- )
+
+ # Second attempt: remove offload options
+ if not model_loaded and "offload_folder" in model_kwargs:
+ try:
+ print("Retrying without offload options...")
+ model_kwargs_clean = {k: v for k, v in model_kwargs.items()
+ if k not in ["offload_folder", "offload_state_dict"]}
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.base_model_name,
+ **model_kwargs_clean
+ )
+ model_loaded = True
+ model_kwargs = model_kwargs_clean
+ except Exception as e:
+ error_messages.append(f"Without offload failed: {e}")
+ print(f"Still failed: {e}")
+
+ # Third attempt: remove attention implementation
+ if not model_loaded and "attn_implementation" in model_kwargs:
+ try:
+ print("Retrying without attention implementation...")
+ model_kwargs_clean = {k: v for k, v in model_kwargs.items()
+ if k != "attn_implementation"}
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.base_model_name,
+ **model_kwargs_clean
+ )
+ model_loaded = True
+ model_kwargs = model_kwargs_clean
+ except Exception as e:
+ error_messages.append(f"Without attn_implementation failed: {e}")
+ print(f"Still failed: {e}")
+
+ # Fourth attempt: minimal configuration (only essentials)
+ if not model_loaded:
+ try:
+ print("Retrying with minimal configuration...")
+ minimal_kwargs = {
+ "device_map": model_kwargs.get("device_map", "auto"),
+ "trust_remote_code": True,
+ "torch_dtype": model_kwargs.get("torch_dtype", torch.float32)
+ }
+
+ # Keep quantization if it was specified
+ if "quantization_config" in model_kwargs:
+ minimal_kwargs["quantization_config"] = model_kwargs["quantization_config"]
+
+ # Keep max_memory if it was specified
+ if "max_memory" in model_kwargs:
+ minimal_kwargs["max_memory"] = model_kwargs["max_memory"]
+
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.base_model_name,
+ **minimal_kwargs
+ )
+ model_loaded = True
+ model_kwargs = minimal_kwargs
+ except Exception as e:
+ error_messages.append(f"Minimal config failed: {e}")
+ print(f"Minimal config also failed: {e}")
+
+ if not model_loaded:
+ print("All loading attempts failed:")
+ for i, msg in enumerate(error_messages, 1):
+ print(f" {i}. {msg}")
+ raise RuntimeError("Could not load model with any configuration")
# Prepare for k-bit training if using quantization
if quantization_config:
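
Note (not part of the diff): the inline major/minor split in the first hunk deliberately avoids a dependency on `packaging`, as its comment says. If `packaging` is acceptable, the same ">= 4.36" gate can use a proper version comparison. A minimal sketch, assuming `packaging` is installed; the helper name `attention_kwargs` is invented here for illustration:

    # Illustrative sketch, not part of the patch: gate attn_implementation on
    # transformers >= 4.36 using packaging instead of manual string splitting.
    def attention_kwargs(config: dict) -> dict:
        kwargs = {}
        try:
            from packaging import version
            from transformers import __version__ as tf_version

            if version.parse(tf_version) >= version.parse("4.36"):
                if config["model"].get("use_flash_attention_2", False):
                    kwargs["attn_implementation"] = "flash_attention_2"
                elif config["model"].get("use_eager_attention", False):
                    kwargs["attn_implementation"] = "eager"
        except Exception as exc:
            # Stay conservative on any failure, mirroring the patch's behaviour.
            print(f"Warning: could not check transformers version: {exc}")
        return kwargs

Used as model_kwargs.update(attention_kwargs(self.config)), this keeps the same behaviour while delegating the comparison to packaging.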
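
Note (not part of the diff): the four from_pretrained attempts in the second hunk all follow the same shape: try, record the error, then drop the kwargs most likely to be at fault. Purely as an illustration of that shared structure (the helper name load_with_fallbacks and the candidate ordering are invented here, not taken from the patch), the ladder can also be written as a loop over candidate kwarg sets:

    from transformers import AutoModelForCausalLM

    def load_with_fallbacks(model_name: str, model_kwargs: dict):
        """Try progressively simpler kwarg sets until one loads, else raise."""
        # From most to least specific; the last set keeps only the essentials,
        # roughly matching the patch's "minimal configuration" attempt.
        candidates = [
            dict(model_kwargs),
            {k: v for k, v in model_kwargs.items()
             if k not in ("offload_folder", "offload_state_dict")},
            {k: v for k, v in model_kwargs.items()
             if k not in ("offload_folder", "offload_state_dict",
                          "attn_implementation")},
            {k: v for k, v in model_kwargs.items()
             if k in ("device_map", "trust_remote_code", "torch_dtype",
                      "quantization_config", "max_memory")},
        ]
        errors = []
        for kwargs in candidates:
            try:
                return AutoModelForCausalLM.from_pretrained(model_name, **kwargs), kwargs
            except Exception as exc:
                errors.append(str(exc))
                print(f"Loading attempt failed, trying a simpler config: {exc}")
        raise RuntimeError(
            "Could not load model with any configuration:\n"
            + "\n".join(f"  {i}. {msg}" for i, msg in enumerate(errors, 1))
        )

The explicit four-block version in the patch is easier to follow in a log, so this is only a note on the shared pattern, not a requested change.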