From 01ef446cc9767abf656b5b2ca87a427a4347fd33 Mon Sep 17 00:00:00 2001
From: Soma Nakamura
Date: Thu, 10 Jul 2025 21:02:39 +0900
Subject: [PATCH] Add transformers version check and fallback model loading

---
 src/progressive_model.py | 109 ++++++++++++++++++++++++++++++++++-----
 1 file changed, 95 insertions(+), 14 deletions(-)

diff --git a/src/progressive_model.py b/src/progressive_model.py
index 6c5e669..b39bd50 100644
--- a/src/progressive_model.py
+++ b/src/progressive_model.py
@@ -82,11 +82,28 @@ class ProgressiveReasoningModel:
         if quantization_config:
             model_kwargs["quantization_config"] = quantization_config
 
-        # Add attention implementation
-        if self.config["model"].get("use_flash_attention_2", False):
-            model_kwargs["attn_implementation"] = "flash_attention_2"
-        elif self.config["model"].get("use_eager_attention", False):
-            model_kwargs["attn_implementation"] = "eager"
+        # Add attention implementation only if transformers version supports it
+        try:
+            from transformers import __version__ as tf_version
+
+            # Simple version check without packaging dependency
+            version_parts = tf_version.split('.')
+            major = int(version_parts[0])
+            minor = int(version_parts[1]) if len(version_parts) > 1 else 0
+
+            # Check if attn_implementation is supported (transformers >= 4.36)
+            if major > 4 or (major == 4 and minor >= 36):
+                if self.config["model"].get("use_flash_attention_2", False):
+                    model_kwargs["attn_implementation"] = "flash_attention_2"
+                    print("Using flash_attention_2")
+                elif self.config["model"].get("use_eager_attention", False):
+                    model_kwargs["attn_implementation"] = "eager"
+                    print("Using eager attention")
+            else:
+                print(f"Transformers version {tf_version} does not support attn_implementation parameter")
+        except Exception as e:
+            print(f"Warning: Could not check transformers version: {e}")
+            # Skip attention implementation to be safe
 
         # Load model
         print("Loading model with the following kwargs:")
@@ -96,22 +113,86 @@ class ProgressiveReasoningModel:
             else:
                 print(f"  {k}: ")
 
+        # Try loading with progressively simpler configurations
+        model_loaded = False
+        error_messages = []
+
+        # First attempt: with all kwargs
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.base_model_name,
                 **model_kwargs
             )
+            model_loaded = True
         except Exception as e:
+            error_messages.append(f"Full config failed: {e}")
             print(f"Error loading model: {e}")
-            # Try without some problematic kwargs
-            if "offload_folder" in model_kwargs:
-                print("Retrying without offload_folder...")
-                del model_kwargs["offload_folder"]
-                del model_kwargs["offload_state_dict"]
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.base_model_name,
-                    **model_kwargs
-                )
+
+        # Second attempt: remove offload options
+        if not model_loaded and "offload_folder" in model_kwargs:
+            try:
+                print("Retrying without offload options...")
+                model_kwargs_clean = {k: v for k, v in model_kwargs.items()
+                                      if k not in ["offload_folder", "offload_state_dict"]}
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
+                    **model_kwargs_clean
+                )
+                model_loaded = True
+                model_kwargs = model_kwargs_clean
+            except Exception as e:
+                error_messages.append(f"Without offload failed: {e}")
+                print(f"Still failed: {e}")
+
+        # Third attempt: remove attention implementation
+        if not model_loaded and "attn_implementation" in model_kwargs:
+            try:
+                print("Retrying without attention implementation...")
+                model_kwargs_clean = {k: v for k, v in model_kwargs.items()
+                                      if k != "attn_implementation"}
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
+                    **model_kwargs_clean
+                )
+                model_loaded = True
+                model_kwargs = model_kwargs_clean
+            except Exception as e:
+                error_messages.append(f"Without attn_implementation failed: {e}")
+                print(f"Still failed: {e}")
+
+        # Fourth attempt: minimal configuration (only essentials)
+        if not model_loaded:
+            try:
+                print("Retrying with minimal configuration...")
+                minimal_kwargs = {
+                    "device_map": model_kwargs.get("device_map", "auto"),
+                    "trust_remote_code": True,
+                    "torch_dtype": model_kwargs.get("torch_dtype", torch.float32)
+                }
+
+                # Keep quantization if it was specified
+                if "quantization_config" in model_kwargs:
+                    minimal_kwargs["quantization_config"] = model_kwargs["quantization_config"]
+
+                # Keep max_memory if it was specified
+                if "max_memory" in model_kwargs:
+                    minimal_kwargs["max_memory"] = model_kwargs["max_memory"]
+
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
+                    **minimal_kwargs
+                )
+                model_loaded = True
+                model_kwargs = minimal_kwargs
+            except Exception as e:
+                error_messages.append(f"Minimal config failed: {e}")
+                print(f"Minimal config also failed: {e}")
+
+        if not model_loaded:
+            print("All loading attempts failed:")
+            for i, msg in enumerate(error_messages, 1):
+                print(f"  {i}. {msg}")
+            raise RuntimeError("Could not load model with any configuration")
 
         # Prepare for k-bit training if using quantization
         if quantization_config:
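
For reference, the retry ladder introduced above can be read as one generic pattern: try the richest kwargs first, fall back to progressively simpler ones, and surface every collected error if nothing loads. The sketch below is a standalone illustration of that pattern, not code from the patch; the helper load_with_fallbacks, the candidate dicts, and the example model id are assumptions made here, while AutoModelForCausalLM.from_pretrained and its keyword arguments are the regular transformers API (device_map="auto" additionally requires accelerate to be installed).

import torch
from transformers import AutoModelForCausalLM

def load_with_fallbacks(model_name, candidates):
    """Try each (label, kwargs) pair in order; raise with all errors if none loads."""
    errors = []
    for label, kwargs in candidates:
        try:
            return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
        except Exception as e:  # fall through to the next, simpler configuration
            errors.append(f"{label}: {e}")
    raise RuntimeError("Could not load model with any configuration:\n"
                       + "\n".join(f"  {i}. {m}" for i, m in enumerate(errors, 1)))

# Example ladder mirroring the patch: full kwargs, then without
# attn_implementation, then a minimal configuration.
full = {"device_map": "auto", "torch_dtype": torch.float16,
        "attn_implementation": "flash_attention_2", "trust_remote_code": True}
no_attn = {k: v for k, v in full.items() if k != "attn_implementation"}
minimal = {"device_map": "auto", "torch_dtype": torch.float32, "trust_remote_code": True}

model = load_with_fallbacks("gpt2",  # model id is only a placeholder example
                            [("full", full), ("no attn", no_attn), ("minimal", minimal)])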