From 821c26bf69154bca1799150217b8c875d00d6505 Mon Sep 17 00:00:00 2001
From: Soma Nakamura
Date: Thu, 10 Jul 2025 23:25:51 +0900
Subject: [PATCH] grad-repair

Disable 4-bit quantization in the 8-GPU DDP config (CUDA kernel issues),
shorten the per-stage max sequence lengths, and restrict the LoRA target
modules to q_proj/v_proj to cut memory use. Also add a minimal
single-stage Gemma 3 1B config for very low-memory runs.
---
 .../training_config_gemma3_1b_8gpu_ddp.yaml   | 14 +++---
 config/training_config_gemma3_1b_minimal.yaml | 45 +++++++++++++++++++
 2 files changed, 51 insertions(+), 8 deletions(-)
 create mode 100644 config/training_config_gemma3_1b_minimal.yaml

diff --git a/config/training_config_gemma3_1b_8gpu_ddp.yaml b/config/training_config_gemma3_1b_8gpu_ddp.yaml
index ee6c2b7..5c1ed70 100644
--- a/config/training_config_gemma3_1b_8gpu_ddp.yaml
+++ b/config/training_config_gemma3_1b_8gpu_ddp.yaml
@@ -6,9 +6,7 @@ experiment:
   wandb_project: "matsuo-llm-comp-2025"
 
 model:
-  load_in_4bit: true # Enable quantization for memory savings
-  bnb_4bit_compute_dtype: "bfloat16"
-  bnb_4bit_use_double_quant: true
+  load_in_4bit: false # Disable quantization due to CUDA kernel issues
   device_map: "balanced" # Distribute across all GPUs
   gradient_checkpointing: true # Enable gradient checkpointing
   use_flash_attention_2: false
@@ -36,7 +34,7 @@ progressive_stages:
       gradient_accumulation_steps: 8 # Maintain effective batch size
       learning_rate: 5e-4
       warmup_steps: 100
-      max_length: 512 # Reduced sequence length
+      max_length: 256 # Very short sequences for memory
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
@@ -53,7 +51,7 @@ progressive_stages:
       r: 8 # Minimal rank for memory
       lora_alpha: 16
       lora_dropout: 0.1
-      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+      target_modules: ["q_proj", "v_proj"] # Minimal modules for memory
       init_lora_weights: true
     training:
       num_epochs: 1
@@ -61,7 +59,7 @@ progressive_stages:
       gradient_accumulation_steps: 4
       learning_rate: 3e-4
       warmup_steps: 200
-      max_length: 1024 # Reduced sequence length
+      max_length: 512 # Short sequences for memory
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
@@ -81,7 +79,7 @@ progressive_stages:
       r: 8 # Minimal rank for memory
       lora_alpha: 16
       lora_dropout: 0.1
-      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+      target_modules: ["q_proj", "v_proj"] # Minimal modules for memory
       init_lora_weights: true
     training:
       num_epochs: 1
@@ -89,7 +87,7 @@ progressive_stages:
       gradient_accumulation_steps: 2
       learning_rate: 2e-4
       warmup_steps: 300
-      max_length: 1024 # Reduced sequence length
+      max_length: 512 # Short sequences for memory
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
diff --git a/config/training_config_gemma3_1b_minimal.yaml b/config/training_config_gemma3_1b_minimal.yaml
new file mode 100644
index 0000000..deed65b
--- /dev/null
+++ b/config/training_config_gemma3_1b_minimal.yaml
@@ -0,0 +1,45 @@
+experiment:
+  name: "progressive_reasoning_gemma3_1b_minimal"
+  base_model: "google/gemma-3-1b-pt"
+  output_dir: "./outputs"
+  use_wandb: true
+  wandb_project: "matsuo-llm-comp-2025"
+
+model:
+  load_in_4bit: false
+  device_map: "auto"
+  gradient_checkpointing: true
+  use_flash_attention_2: false
+  use_eager_attention: true
+
+progressive_stages:
+  - name: "basic_cot"
+    description: "Basic Chain-of-Thought reasoning"
+    dataset_path: "./data/basic_cot/"
+    adapter_config:
+      r: 4 # Extremely minimal rank
+      lora_alpha: 8
+      lora_dropout: 0.1
+      target_modules: ["q_proj"] # Only one module
+      init_lora_weights: true
+    training:
+      num_epochs: 1 # Reduced epochs
+      per_device_batch_size: 1
+      gradient_accumulation_steps: 4
+      learning_rate: 5e-4
+      warmup_steps: 10
+      max_length: 128 # Very short sequences
+      bf16: true
+      max_grad_norm: 1.0
+      weight_decay: 0.001
+      save_steps: 100
+      logging_steps: 10
+      dataloader_num_workers: 1
+      dataloader_pin_memory: false
+
+evaluation:
+  benchmarks:
+    - "HLE"
+    - "Do-Not-Answer"
+  save_results: true
+  results_dir: "./outputs/evaluation_results"
\ No newline at end of file
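
For reference, below is a minimal sketch (assuming a transformers + peft training stack) of how the "model" and "adapter_config" sections above could be consumed. The training script itself is not part of this diff, so the load_stage_model() helper, the attn_implementation mapping for use_eager_attention, and the bf16 dtype choice are illustrative assumptions, not the repository's actual code.

# Hypothetical sketch: maps the YAML "model" and "adapter_config" sections
# onto transformers/peft objects. load_stage_model() is not in the patch.
import torch
import yaml
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


def load_stage_model(config_path: str, stage_index: int = 0):
    with open(config_path) as f:
        cfg = yaml.safe_load(f)

    model_cfg = cfg["model"]
    stage = cfg["progressive_stages"][stage_index]

    # Only build a bitsandbytes config when 4-bit loading is enabled;
    # this patch turns it off because of CUDA kernel issues.
    quant_cfg = None
    if model_cfg.get("load_in_4bit"):
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    model = AutoModelForCausalLM.from_pretrained(
        cfg["experiment"]["base_model"],
        device_map=model_cfg.get("device_map", "auto"),
        quantization_config=quant_cfg,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager" if model_cfg.get("use_eager_attention") else None,
    )
    if model_cfg.get("gradient_checkpointing"):
        model.gradient_checkpointing_enable()

    # adapter_config keys map one-to-one onto peft.LoraConfig.
    lora = stage["adapter_config"]
    peft_cfg = LoraConfig(
        r=lora["r"],
        lora_alpha=lora["lora_alpha"],
        lora_dropout=lora["lora_dropout"],
        target_modules=lora["target_modules"],
        init_lora_weights=lora["init_lora_weights"],
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, peft_cfg)

With load_in_4bit set to false, the full bf16 weights of google/gemma-3-1b-pt stay resident on the GPUs, which is why the configs also trim max_length and the LoRA target_modules to keep memory within budget.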