Compare commits

2 commits

3c513fee17  Merge remote changes with local modifications  (2025-07-10 21:07:48 +09:00)

    - Updated training config for Gemma3 1B with CPU offload support
    - Enhanced progressive_model.py with better error handling
    - Added support for Mixture-of-Thoughts dataset
    - Improved compatibility across different server environments

    🤖 Generated with [Claude Code](https://claude.ai/code)

    Co-Authored-By: Claude <noreply@anthropic.com>

01ef446cc9  ok  (2025-07-10 21:02:39 +09:00)
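The first bullet above refers to CPU offload when loading the base model. For orientation, here is a minimal, hypothetical sketch of the kind of loader kwargs the code in the diff below assembles and then progressively strips on failure; the model id, memory budget, and folder name are assumptions, not values taken from this repository.

```python
# Hypothetical illustration, not this repository's config: the kind of
# model_kwargs the loader in the diff below builds for a small model
# with CPU offload enabled.
import torch
from transformers import AutoModelForCausalLM

model_kwargs = {
    "device_map": "auto",                       # let accelerate place layers on GPU/CPU
    "max_memory": {0: "6GiB", "cpu": "24GiB"},  # assumed budget; adjust per machine
    "offload_folder": "offload",                # spill weights that fit nowhere else to disk
    "offload_state_dict": True,                 # offload the state dict while loading
    "torch_dtype": torch.bfloat16,
    "trust_remote_code": True,
}

# The model id is an assumption for illustration only.
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it", **model_kwargs)
```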


@@ -82,11 +82,28 @@ class ProgressiveReasoningModel:
         if quantization_config:
             model_kwargs["quantization_config"] = quantization_config

-        # Add attention implementation
-        if self.config["model"].get("use_flash_attention_2", False):
-            model_kwargs["attn_implementation"] = "flash_attention_2"
-        elif self.config["model"].get("use_eager_attention", False):
-            model_kwargs["attn_implementation"] = "eager"
+        # Add attention implementation only if transformers version supports it
+        try:
+            from transformers import __version__ as tf_version
+            # Simple version check without packaging dependency
+            version_parts = tf_version.split('.')
+            major = int(version_parts[0])
+            minor = int(version_parts[1]) if len(version_parts) > 1 else 0
+
+            # Check if attn_implementation is supported (transformers >= 4.36)
+            if major > 4 or (major == 4 and minor >= 36):
+                if self.config["model"].get("use_flash_attention_2", False):
+                    model_kwargs["attn_implementation"] = "flash_attention_2"
+                    print("Using flash_attention_2")
+                elif self.config["model"].get("use_eager_attention", False):
+                    model_kwargs["attn_implementation"] = "eager"
+                    print("Using eager attention")
+            else:
+                print(f"Transformers version {tf_version} does not support attn_implementation parameter")
+        except Exception as e:
+            print(f"Warning: Could not check transformers version: {e}")
+            # Skip attention implementation to be safe

         # Load model
         print("Loading model with the following kwargs:")
@@ -96,22 +113,86 @@ class ProgressiveReasoningModel:
             else:
                 print(f" {k}: <BitsAndBytesConfig>")

+        # Try loading with progressively simpler configurations
+        model_loaded = False
+        error_messages = []
+
+        # First attempt: with all kwargs
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.base_model_name,
                 **model_kwargs
             )
+            model_loaded = True
         except Exception as e:
+            error_messages.append(f"Full config failed: {e}")
             print(f"Error loading model: {e}")
-            # Try without some problematic kwargs
-            if "offload_folder" in model_kwargs:
-                print("Retrying without offload_folder...")
-                del model_kwargs["offload_folder"]
-                del model_kwargs["offload_state_dict"]
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.base_model_name,
-                    **model_kwargs
-                )
+
+        # Second attempt: remove offload options
+        if not model_loaded and "offload_folder" in model_kwargs:
+            try:
+                print("Retrying without offload options...")
+                model_kwargs_clean = {k: v for k, v in model_kwargs.items()
+                                      if k not in ["offload_folder", "offload_state_dict"]}
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
+                    **model_kwargs_clean
+                )
+                model_loaded = True
+                model_kwargs = model_kwargs_clean
+            except Exception as e:
+                error_messages.append(f"Without offload failed: {e}")
+                print(f"Still failed: {e}")
+
+        # Third attempt: remove attention implementation
+        if not model_loaded and "attn_implementation" in model_kwargs:
+            try:
+                print("Retrying without attention implementation...")
+                model_kwargs_clean = {k: v for k, v in model_kwargs.items()
+                                      if k != "attn_implementation"}
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
+                    **model_kwargs_clean
+                )
+                model_loaded = True
+                model_kwargs = model_kwargs_clean
+            except Exception as e:
+                error_messages.append(f"Without attn_implementation failed: {e}")
+                print(f"Still failed: {e}")
+
+        # Fourth attempt: minimal configuration (only essentials)
+        if not model_loaded:
+            try:
+                print("Retrying with minimal configuration...")
+                minimal_kwargs = {
+                    "device_map": model_kwargs.get("device_map", "auto"),
+                    "trust_remote_code": True,
+                    "torch_dtype": model_kwargs.get("torch_dtype", torch.float32)
+                }
+                # Keep quantization if it was specified
+                if "quantization_config" in model_kwargs:
+                    minimal_kwargs["quantization_config"] = model_kwargs["quantization_config"]
+                # Keep max_memory if it was specified
+                if "max_memory" in model_kwargs:
+                    minimal_kwargs["max_memory"] = model_kwargs["max_memory"]
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.base_model_name,
+                    **minimal_kwargs
+                )
+                model_loaded = True
+                model_kwargs = minimal_kwargs
+            except Exception as e:
+                error_messages.append(f"Minimal config failed: {e}")
+                print(f"Minimal config also failed: {e}")
+
+        if not model_loaded:
+            print("All loading attempts failed:")
+            for i, msg in enumerate(error_messages, 1):
+                print(f" {i}. {msg}")
+            raise RuntimeError("Could not load model with any configuration")

         # Prepare for k-bit training if using quantization
         if quantization_config:
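The four explicit attempts above amount to a retry cascade over progressively smaller kwargs sets. A compact sketch of the same pattern written as a data-driven loop (not the repository's code, just an illustration of the technique) looks like this:

```python
# Sketch of the same fallback idea as a loop over candidate kwargs sets:
# try each configuration in order until one loads, collecting errors.
from transformers import AutoModelForCausalLM

def load_with_fallbacks(model_name: str, model_kwargs: dict):
    attempts = [
        ("full config", model_kwargs),
        ("without offload", {k: v for k, v in model_kwargs.items()
                             if k not in ("offload_folder", "offload_state_dict")}),
        ("without attn_implementation", {k: v for k, v in model_kwargs.items()
                                         if k != "attn_implementation"}),
        ("minimal", {k: v for k, v in model_kwargs.items()
                     if k in ("device_map", "torch_dtype", "trust_remote_code",
                              "quantization_config", "max_memory")}),
    ]
    errors = []
    for label, kwargs in attempts:
        try:
            return AutoModelForCausalLM.from_pretrained(model_name, **kwargs), kwargs
        except Exception as e:  # keep going; record why this attempt failed
            errors.append(f"{label}: {e}")
    raise RuntimeError("Could not load model with any configuration:\n" + "\n".join(errors))
```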
@@ -300,7 +381,6 @@ class ProgressiveReasoningModel:
         self.model.save_pretrained(self.adapters[stage_name])
         # Also save tokenizer for convenience
         self.tokenizer.save_pretrained(self.adapters[stage_name])
-
     def load_for_inference(self, adapter_names: List[str], weights: Optional[Dict[str, float]] = None):
        """Load model with specific adapters for inference"""
        if len(adapter_names) == 1:
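The hunk above saves each stage's adapter and tokenizer with save_pretrained and then enters load_for_inference, starting with the single-adapter case. Assuming the per-stage adapters are PEFT/LoRA adapters (suggested by the k-bit training preparation earlier in the file), a minimal single-adapter inference sketch could look like the following; the base model id and adapter path are placeholders, not paths from this repository.

```python
# Hedged sketch of the single-adapter inference path, assuming PEFT adapters;
# the model id and "adapters/stage1" directory are made-up placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("adapters/stage1")  # tokenizer saved alongside the adapter
model = PeftModel.from_pretrained(base, "adapters/stage1")    # attach one saved adapter
model.eval()
```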