Compare commits
No commits in common. "3c513fee17491f6cb87fda36e93b7871a210fb97" and "37f1ad94086c52c23ffe00433912b2d567d9f28a" have entirely different histories.
3c513fee17...37f1ad9408

1 changed file with 15 additions and 95 deletions
@@ -82,28 +82,11 @@ class ProgressiveReasoningModel:
         if quantization_config:
             model_kwargs["quantization_config"] = quantization_config

-        # Add attention implementation only if transformers version supports it
-        try:
-            from transformers import __version__ as tf_version
-
-            # Simple version check without packaging dependency
-            version_parts = tf_version.split('.')
-            major = int(version_parts[0])
-            minor = int(version_parts[1]) if len(version_parts) > 1 else 0
-
-            # Check if attn_implementation is supported (transformers >= 4.36)
-            if major > 4 or (major == 4 and minor >= 36):
-                if self.config["model"].get("use_flash_attention_2", False):
-                    model_kwargs["attn_implementation"] = "flash_attention_2"
-                    print("Using flash_attention_2")
-                elif self.config["model"].get("use_eager_attention", False):
-                    model_kwargs["attn_implementation"] = "eager"
-                    print("Using eager attention")
-            else:
-                print(f"Transformers version {tf_version} does not support attn_implementation parameter")
-        except Exception as e:
-            print(f"Warning: Could not check transformers version: {e}")
-            # Skip attention implementation to be safe
+        # Add attention implementation
+        if self.config["model"].get("use_flash_attention_2", False):
+            model_kwargs["attn_implementation"] = "flash_attention_2"
+        elif self.config["model"].get("use_eager_attention", False):
+            model_kwargs["attn_implementation"] = "eager"

         # Load model
         print("Loading model with the following kwargs:")
@@ -113,86 +96,22 @@ class ProgressiveReasoningModel:
             else:
                 print(f" {k}: <BitsAndBytesConfig>")

-        # Try loading with progressively simpler configurations
-        model_loaded = False
-        error_messages = []
-
-        # First attempt: with all kwargs
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.base_model_name,
                 **model_kwargs
             )
-            model_loaded = True
         except Exception as e:
-            error_messages.append(f"Full config failed: {e}")
             print(f"Error loading model: {e}")
-
-        # Second attempt: remove offload options
-        if not model_loaded and "offload_folder" in model_kwargs:
-            try:
-                print("Retrying without offload options...")
-                model_kwargs_clean = {k: v for k, v in model_kwargs.items()
-                                      if k not in ["offload_folder", "offload_state_dict"]}
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.base_model_name,
-                    **model_kwargs_clean
-                )
-                model_loaded = True
-                model_kwargs = model_kwargs_clean
-            except Exception as e:
-                error_messages.append(f"Without offload failed: {e}")
-                print(f"Still failed: {e}")
-
-        # Third attempt: remove attention implementation
-        if not model_loaded and "attn_implementation" in model_kwargs:
-            try:
-                print("Retrying without attention implementation...")
-                model_kwargs_clean = {k: v for k, v in model_kwargs.items()
-                                      if k != "attn_implementation"}
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.base_model_name,
-                    **model_kwargs_clean
-                )
-                model_loaded = True
-                model_kwargs = model_kwargs_clean
-            except Exception as e:
-                error_messages.append(f"Without attn_implementation failed: {e}")
-                print(f"Still failed: {e}")
-
-        # Fourth attempt: minimal configuration (only essentials)
-        if not model_loaded:
-            try:
-                print("Retrying with minimal configuration...")
-                minimal_kwargs = {
-                    "device_map": model_kwargs.get("device_map", "auto"),
-                    "trust_remote_code": True,
-                    "torch_dtype": model_kwargs.get("torch_dtype", torch.float32)
-                }
-
-                # Keep quantization if it was specified
-                if "quantization_config" in model_kwargs:
-                    minimal_kwargs["quantization_config"] = model_kwargs["quantization_config"]
-
-                # Keep max_memory if it was specified
-                if "max_memory" in model_kwargs:
-                    minimal_kwargs["max_memory"] = model_kwargs["max_memory"]
-
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.base_model_name,
-                    **minimal_kwargs
-                )
-                model_loaded = True
-                model_kwargs = minimal_kwargs
-            except Exception as e:
-                error_messages.append(f"Minimal config failed: {e}")
-                print(f"Minimal config also failed: {e}")
-
-        if not model_loaded:
-            print("All loading attempts failed:")
-            for i, msg in enumerate(error_messages, 1):
-                print(f" {i}. {msg}")
-            raise RuntimeError("Could not load model with any configuration")
+            # Try without some problematic kwargs
+            if "offload_folder" in model_kwargs:
+                print("Retrying without offload_folder...")
+                del model_kwargs["offload_folder"]
+                del model_kwargs["offload_state_dict"]
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
+                    **model_kwargs
+                )

         # Prepare for k-bit training if using quantization
         if quantization_config:
@@ -381,6 +300,7 @@ class ProgressiveReasoningModel:
         self.model.save_pretrained(self.adapters[stage_name])
         # Also save tokenizer for convenience
         self.tokenizer.save_pretrained(self.adapters[stage_name])
+
     def load_for_inference(self, adapter_names: List[str], weights: Optional[Dict[str, float]] = None):
         """Load model with specific adapters for inference"""
        if len(adapter_names) == 1:
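Taken together, the head commit drops the transformers version probe and the four-stage retry ladder in favor of a single loading attempt plus one offload retry. The sketch below is only an illustration of how that simplified path reads when pulled out of the class into a standalone helper; the function name and signature, the use of pop() instead of del, and the final re-raise are assumptions for the sketch, while the config keys and kwargs come from the diff.

# Illustrative sketch only (not code from either commit): the simplified load
# path of the head commit, lifted out of ProgressiveReasoningModel into a
# hypothetical standalone helper.
from transformers import AutoModelForCausalLM


def load_base_model(base_model_name, config, model_kwargs, quantization_config=None):
    if quantization_config:
        model_kwargs["quantization_config"] = quantization_config

    # Attention implementation is now taken straight from the config,
    # with no check of the installed transformers version.
    if config["model"].get("use_flash_attention_2", False):
        model_kwargs["attn_implementation"] = "flash_attention_2"
    elif config["model"].get("use_eager_attention", False):
        model_kwargs["attn_implementation"] = "eager"

    try:
        return AutoModelForCausalLM.from_pretrained(base_model_name, **model_kwargs)
    except Exception as e:
        print(f"Error loading model: {e}")
        # Single remaining fallback: drop the offload kwargs and retry once.
        # (The diff uses `del`; pop() is used here so the sketch cannot KeyError.)
        if "offload_folder" in model_kwargs:
            print("Retrying without offload_folder...")
            model_kwargs.pop("offload_folder", None)
            model_kwargs.pop("offload_state_dict", None)
            return AutoModelForCausalLM.from_pretrained(base_model_name, **model_kwargs)
        raise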