grad-repair
commit 5a784102b9
parent 2d01c6577f
2 changed files with 26 additions and 27 deletions
@@ -6,11 +6,11 @@ experiment:
   wandb_project: "matsuo-llm-comp-2025"

 model:
-  load_in_4bit: false # Can use FP16/BF16 with multiple GPUs
+  load_in_4bit: true # Enable quantization for memory savings
   bnb_4bit_compute_dtype: "bfloat16"
   bnb_4bit_use_double_quant: true
   device_map: "balanced" # Distribute across all GPUs
-  gradient_checkpointing: true
+  gradient_checkpointing: true # Enable gradient checkpointing
   use_flash_attention_2: false
   use_eager_attention: true

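Note on the model block above: in Hugging Face transformers, keys like these are typically forwarded to a BitsAndBytesConfig. A minimal sketch, assuming that is what the loader does here (the base model name is a placeholder, not taken from this config):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Assumed mapping of the YAML keys above onto bitsandbytes 4-bit loading
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # load_in_4bit: true
        bnb_4bit_compute_dtype=torch.bfloat16,  # bnb_4bit_compute_dtype: "bfloat16"
        bnb_4bit_use_double_quant=True,         # bnb_4bit_use_double_quant: true
    )
    model = AutoModelForCausalLM.from_pretrained(
        "your-base-model",              # placeholder: the hunk does not show the model name
        quantization_config=bnb_config,
        device_map="balanced",          # device_map: "balanced" spreads layers across GPUs
    )

4-bit weights with bfloat16 compute is the usual trade here: roughly a 4x cut in weight memory in exchange for dequantization overhead at each forward pass.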
@@ -25,24 +25,24 @@ progressive_stages:
     description: "Basic Chain-of-Thought reasoning"
     dataset_path: "./data/basic_cot/"
     adapter_config:
-      r: 16 # Moderate rank for DDP
-      lora_alpha: 32
+      r: 8 # Minimal rank for memory
+      lora_alpha: 16
       lora_dropout: 0.1
       target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
       init_lora_weights: true
     training:
       num_epochs: 2
-      per_device_batch_size: 8 # 8 * 8 = 64 total batch size (reduced for memory)
-      gradient_accumulation_steps: 1
+      per_device_batch_size: 1 # 1 * 8 = 8 total batch size (minimal)
+      gradient_accumulation_steps: 8 # Maintain effective batch size
       learning_rate: 5e-4
       warmup_steps: 100
-      max_length: 1024
+      max_length: 512 # Reduced sequence length
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
       save_steps: 50
       logging_steps: 10
-      dataloader_num_workers: 4
+      dataloader_num_workers: 2
       dataloader_pin_memory: false

   - name: "math_reasoning"
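Reviewer note on this stage's batch change: the per-device batch drops 8x while gradient_accumulation_steps rises 8x, so the effective batch size the optimizer steps on is unchanged; only peak activation memory shrinks. A quick check, assuming the "* 8" in the comments means 8 GPUs:

    # effective = per_device_batch_size * num_gpus * gradient_accumulation_steps
    old_effective = 8 * 8 * 1  # 64
    new_effective = 1 * 8 * 8  # 64 -- same optimizer behavior, much smaller peak memory
    assert old_effective == new_effective

The same does not hold for the later stages below, where accumulation rises less than the batch shrinks (4*8*2=64 -> 1*8*4=32, and 2*8*4=64 -> 1*8*2=16), so their effective batch sizes genuinely drop.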
@@ -50,24 +50,24 @@ progressive_stages:
     dataset_path: "open-r1/OpenR1-Math-220k"
     inherit_from: "basic_cot"
     adapter_config:
-      r: 32
-      lora_alpha: 64
+      r: 8 # Minimal rank for memory
+      lora_alpha: 16
       lora_dropout: 0.1
       target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
       init_lora_weights: true
     training:
       num_epochs: 1
-      per_device_batch_size: 4 # 4 * 8 = 32 total batch size (reduced for memory)
-      gradient_accumulation_steps: 2
+      per_device_batch_size: 1 # 1 * 8 = 8 total batch size (minimal)
+      gradient_accumulation_steps: 4
       learning_rate: 3e-4
       warmup_steps: 200
-      max_length: 2048
+      max_length: 1024 # Reduced sequence length
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
       save_steps: 100
       logging_steps: 20
-      dataloader_num_workers: 4
+      dataloader_num_workers: 2
       dataset_config:
         streaming: true
         max_samples: 400000 # Process substantial data
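All three stages now converge on the same minimal adapter (r=8, alpha=16). Since LoRA scales its update by alpha/r, the ratio stays at 2 everywhere (32/16, 64/32, 128/64 before; 16/8 after), so the change cuts adapter capacity without altering the effective update scale. A sketch of the equivalent peft configuration, assuming that library is what consumes adapter_config (the task type is an assumption, not in the YAML):

    from peft import LoraConfig, get_peft_model

    lora_config = LoraConfig(
        r=8,            # r: 8
        lora_alpha=16,  # lora_alpha: 16 -> scaling alpha/r = 2, same as before
        lora_dropout=0.1,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",  # assumed; not specified in the YAML
    )
    model = get_peft_model(model, lora_config)  # `model` from the quantized load above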
@@ -78,24 +78,24 @@ progressive_stages:
     dataset_path: "open-r1/Mixture-of-Thoughts"
     inherit_from: "math_reasoning"
     adapter_config:
-      r: 64
-      lora_alpha: 128
+      r: 8 # Minimal rank for memory
+      lora_alpha: 16
       lora_dropout: 0.1
       target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
       init_lora_weights: true
     training:
       num_epochs: 1
-      per_device_batch_size: 2 # 2 * 8 = 16 total batch size (reduced for memory)
-      gradient_accumulation_steps: 4
+      per_device_batch_size: 1 # 1 * 8 = 8 total batch size (minimal)
+      gradient_accumulation_steps: 2
       learning_rate: 2e-4
       warmup_steps: 300
-      max_length: 4096
+      max_length: 1024 # Reduced sequence length
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
       save_steps: 200
       logging_steps: 50
-      dataloader_num_workers: 4
+      dataloader_num_workers: 2
       dataset_config:
         streaming: true
         max_samples: 600000
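With streaming: true the 600k-sample cap matters: the dataset is iterated rather than downloaded whole. A minimal sketch using Hugging Face datasets, assuming that is what backs dataset_config (the repository may apply max_samples differently):

    from datasets import load_dataset

    # Streaming returns an IterableDataset; nothing is materialized up front.
    stream = load_dataset("open-r1/Mixture-of-Thoughts",
                          split="train", streaming=True)  # may also need a config name
    for example in stream.take(600000):  # max_samples: 600000
        pass  # tokenize / feed the trainer here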
@@ -198,16 +198,15 @@ class ProgressiveReasoningModel:
         if quantization_config:
             self.model = prepare_model_for_kbit_training(self.model)

-        # Disable gradient checkpointing for now to avoid conflicts
-        # Enable gradient checkpointing if requested (but disable use_cache)
-        # if self.config["model"].get("gradient_checkpointing", False):
-        #     self.model.gradient_checkpointing_enable()
-        #     self.model.config.use_cache = False
-        #     print("Gradient checkpointing enabled, use_cache disabled")
+        # Enable gradient checkpointing if requested
+        if self.config["model"].get("gradient_checkpointing", False):
+            self.model.gradient_checkpointing_enable()
+            print("Gradient checkpointing enabled")

-        # Explicitly disable use_cache to avoid conflicts
+        # Explicitly disable use_cache to avoid conflicts and save memory
         if hasattr(self.model, 'config'):
             self.model.config.use_cache = False
+            print("use_cache disabled for memory efficiency")

         # Load tokenizer
         tokenizer_kwargs = {"trust_remote_code": True}
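This Python hunk is the substance of the "grad-repair": the previously commented-out checkpointing path is restored, and the KV cache is switched off unconditionally. That pairing is standard practice: gradient checkpointing recomputes forward activations during backward, and a populated KV cache is useless during training and a known source of conflicts with checkpointed recomputation. If reentrant-checkpointing warnings still appear, newer transformers releases accept an explicit kwarg; a sketch, assuming a recent version (roughly 4.35+):

    # Non-reentrant checkpointing tends to interact better with PEFT/quantized models.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
    model.config.use_cache = False  # the cache is only useful at generation time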