grad-repair
parent 5a784102b9, commit 821c26bf69
2 changed files with 51 additions and 8 deletions
@@ -6,9 +6,7 @@ experiment:
   wandb_project: "matsuo-llm-comp-2025"
 
 model:
-  load_in_4bit: true # Enable quantization for memory savings
-  bnb_4bit_compute_dtype: "bfloat16"
-  bnb_4bit_use_double_quant: true
+  load_in_4bit: false # Disable quantization due to CUDA kernel issues
   device_map: "balanced" # Distribute across all GPUs
   gradient_checkpointing: true # Enable gradient checkpointing
   use_flash_attention_2: false
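Since this hunk drops 4-bit quantization, here is a minimal sketch of how the model block might be consumed after the change, assuming Hugging Face transformers; the loader function and config dict are illustrative, not taken from this repo:

import torch
from transformers import AutoModelForCausalLM

def load_base_model(model_name, model_cfg):
    # With load_in_4bit now false, no bitsandbytes quantization config is built;
    # the model is loaded in bf16 and spread across the GPUs instead.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=model_cfg.get("device_map", "balanced"),
        attn_implementation="eager",  # use_flash_attention_2: false
    )
    if model_cfg.get("gradient_checkpointing", False):
        model.gradient_checkpointing_enable()  # gradient_checkpointing: true
    return model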
@@ -36,7 +34,7 @@ progressive_stages:
       gradient_accumulation_steps: 8 # Maintain effective batch size
       learning_rate: 5e-4
       warmup_steps: 100
-      max_length: 512 # Reduced sequence length
+      max_length: 256 # Very short sequences for memory
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
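The "maintain effective batch size" comment refers to per_device_batch_size x gradient_accumulation_steps x number of GPUs; a quick back-of-envelope check, where only gradient_accumulation_steps comes from this hunk and the other two values are assumptions for illustration:

gradient_accumulation_steps = 8  # from the config above
per_device_batch_size = 1        # assumed, not shown in this hunk
num_gpus = 8                     # assumed
effective_batch = per_device_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch)           # 64 sequences per optimizer step, now at max_length 256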
@@ -53,7 +51,7 @@ progressive_stages:
       r: 8 # Minimal rank for memory
       lora_alpha: 16
       lora_dropout: 0.1
-      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+      target_modules: ["q_proj", "v_proj"] # Minimal modules for memory
       init_lora_weights: true
     training:
       num_epochs: 1
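A sketch of the adapter these fields would map to, assuming the Hugging Face peft library is what builds it; the task type is an assumption:

from peft import LoraConfig, TaskType

adapter_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],  # reduced from all 7 projections
    init_lora_weights=True,
    task_type=TaskType.CAUSAL_LM,         # assumed
)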
@@ -61,7 +59,7 @@ progressive_stages:
       gradient_accumulation_steps: 4
       learning_rate: 3e-4
       warmup_steps: 200
-      max_length: 1024 # Reduced sequence length
+      max_length: 512 # Short sequences for memory
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
@@ -81,7 +79,7 @@ progressive_stages:
       r: 8 # Minimal rank for memory
       lora_alpha: 16
       lora_dropout: 0.1
-      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+      target_modules: ["q_proj", "v_proj"] # Minimal modules for memory
       init_lora_weights: true
     training:
       num_epochs: 1
@@ -89,7 +87,7 @@ progressive_stages:
       gradient_accumulation_steps: 2
       learning_rate: 2e-4
       warmup_steps: 300
-      max_length: 1024 # Reduced sequence length
+      max_length: 512 # Short sequences for memory
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
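Cutting target_modules from seven projections to two shrinks the adapter roughly in proportion, since LoRA adds about r * (d_in + d_out) parameters per targeted linear layer. A ballpark comparison with placeholder dimensions (the hidden size below is an assumption, and the MLP projections are wider in practice):

def lora_params_per_layer(r, d_in, d_out):
    # A and B matrices of a LoRA adapter on one linear layer
    return r * (d_in + d_out)

d_model = 2048  # placeholder hidden size, not taken from the repo
per_layer = lora_params_per_layer(8, d_model, d_model)
print(7 * per_layer)  # all projections targeted (before this commit), ~229k per block
print(2 * per_layer)  # q_proj and v_proj only (after), ~66k per block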
config/training_config_gemma3_1b_minimal.yaml (new file, 45 lines)
@@ -0,0 +1,45 @@
+experiment:
+  name: "progressive_reasoning_gemma3_1b_minimal"
+  base_model: "google/gemma-3-1b-pt"
+  output_dir: "./outputs"
+  use_wandb: true
+  wandb_project: "matsuo-llm-comp-2025"
+
+model:
+  load_in_4bit: false
+  device_map: "auto"
+  gradient_checkpointing: true
+  use_flash_attention_2: false
+  use_eager_attention: true
+
+progressive_stages:
+  - name: "basic_cot"
+    description: "Basic Chain-of-Thought reasoning"
+    dataset_path: "./data/basic_cot/"
+    adapter_config:
+      r: 4 # Extremely minimal rank
+      lora_alpha: 8
+      lora_dropout: 0.1
+      target_modules: ["q_proj"] # Only one module
+      init_lora_weights: true
+    training:
+      num_epochs: 1 # Reduced epochs
+      per_device_batch_size: 1
+      gradient_accumulation_steps: 4
+      learning_rate: 5e-4
+      warmup_steps: 10
+      max_length: 128 # Very short sequences
+      bf16: true
+      max_grad_norm: 1.0
+      weight_decay: 0.001
+      save_steps: 100
+      logging_steps: 10
+      dataloader_num_workers: 1
+      dataloader_pin_memory: false
+
+evaluation:
+  benchmarks:
+    - "HLE"
+    - "Do-Not-Answer"
+  save_results: true
+  results_dir: "./outputs/evaluation_results"
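A minimal sketch of reading the new file, assuming PyYAML and the path as committed; the field accesses simply mirror the keys above:

import yaml

with open("config/training_config_gemma3_1b_minimal.yaml") as f:
    cfg = yaml.safe_load(f)

stage = cfg["progressive_stages"][0]
print(cfg["experiment"]["base_model"])            # google/gemma-3-1b-pt
print(stage["adapter_config"]["target_modules"])  # ['q_proj']
print(stage["training"]["max_length"])            # 128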