# progressive-llm/config/training_config_8gpu_fsdp.yaml
experiment:
  name: "progressive_reasoning_8gpu_fsdp"
  base_model: "google/gemma-2-2b-it"  # Can scale to much larger models with FSDP
  output_dir: "./outputs"
  use_wandb: true
  wandb_project: "matsuo-llm-comp-2025"

model:
  load_in_4bit: false
  device_map: null  # Let FSDP handle device placement
  gradient_checkpointing: true
  use_flash_attention_2: true
  use_eager_attention: false
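  # Note: gradient checkpointing trades extra recompute for lower activation
  # memory, and use_flash_attention_2 assumes the flash-attn package is
  # installed; load_in_4bit is off, presumably because FSDP sharding already
  # spreads the full-precision weights across the GPUs.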

# FSDP Configuration
fsdp:
  fsdp_transformer_layer_cls_to_wrap: "Gemma2DecoderLayer"  # Wrap at the decoder-layer level
  fsdp_sharding_strategy: "FULL_SHARD"  # Shard parameters, gradients, and optimizer states
  fsdp_cpu_offload: false  # Keep shards on GPU for speed
  fsdp_mixed_precision: true  # Use BF16 mixed precision
  fsdp_auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
  fsdp_min_num_params: 1000000  # Size threshold for size-based wrapping (typically ignored under TRANSFORMER_BASED_WRAP)
  fsdp_sync_module_states: true
  fsdp_forward_prefetch: true
  fsdp_use_orig_params: true  # Required for LoRA compatibility
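  # Note: FULL_SHARD shards parameters, gradients, and optimizer state across
  # all ranks (comparable to DeepSpeed ZeRO-3). fsdp_use_orig_params must stay
  # true because LoRA mixes frozen base weights with trainable adapter weights
  # inside the same wrapped modules, which FSDP's flattened parameters would
  # otherwise not allow.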

progressive_stages:
  - name: "basic_cot"
    description: "Basic Chain-of-Thought reasoning"
    dataset_path: "./data/basic_cot/"
    adapter_config:
      r: 64  # Can use larger ranks with FSDP
      lora_alpha: 128
      lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
      init_lora_weights: true
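      # Note: every stage keeps lora_alpha = 2 * r, so the LoRA scaling factor
      # alpha / r stays at 2 even as the rank grows from 64 to 256.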
    training:
      num_epochs: 2
      per_device_batch_size: 32  # Very large batch size with FSDP
      gradient_accumulation_steps: 1
      learning_rate: 5e-4
      warmup_steps: 100
      max_length: 2048
      bf16: true
      max_grad_norm: 1.0
      weight_decay: 0.001
      save_steps: 50
      logging_steps: 10
      dataloader_num_workers: 8
      dataloader_pin_memory: true
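      # Note: with the 8 GPUs this config targets, the effective global batch
      # is 32 per device x 8 ranks x 1 accumulation step = 256 sequences.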

  - name: "math_reasoning"
    description: "Mathematical reasoning with OpenR1-Math-220k dataset"
    dataset_path: "open-r1/OpenR1-Math-220k"
    inherit_from: "basic_cot"
    adapter_config:
      r: 128
      lora_alpha: 256
      lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 16
      gradient_accumulation_steps: 2
      learning_rate: 3e-4
      warmup_steps: 200
      max_length: 4096
      bf16: true
      max_grad_norm: 1.0
      weight_decay: 0.001
      save_steps: 100
      logging_steps: 20
      dataloader_num_workers: 8
    dataset_config:
      streaming: true
      max_samples: 200000  # Cap the streamed training data at 200k samples
      split: "train"
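      # Note: streaming reads examples on the fly instead of materializing the
      # full 220k-example dataset on disk first; the effective global batch on
      # 8 GPUs is 16 x 8 x 2 = 256 sequences, matching the previous stage.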

  - name: "complex_reasoning"
    description: "Complex multi-step reasoning with Mixture-of-Thoughts"
    dataset_path: "open-r1/Mixture-of-Thoughts"
    inherit_from: "math_reasoning"
    adapter_config:
      r: 256  # Very large rank possible with FSDP
      lora_alpha: 512
      lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 8
      gradient_accumulation_steps: 4
      learning_rate: 2e-4
      warmup_steps: 300
      max_length: 8192
      bf16: true
      max_grad_norm: 1.0
      weight_decay: 0.001
      save_steps: 200
      logging_steps: 50
      dataloader_num_workers: 8
    dataset_config:
      streaming: true
      max_samples: 100000
      split: "train"
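      # Note: across the three stages the per-device batch halves (32 -> 16 -> 8)
      # as max_length doubles (2048 -> 4096 -> 8192) and gradient accumulation
      # doubles, so the effective global batch on 8 GPUs stays at
      # 8 x 8 x 4 = 256 sequences throughout.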

evaluation:
  benchmarks:
    - "HLE"
    - "Do-Not-Answer"
  save_results: true
  results_dir: "./outputs/evaluation_results"