From 5a784102b9694e1888ed6e8bf0b57da26e4adc18 Mon Sep 17 00:00:00 2001
From: Soma Nakamura
Date: Thu, 10 Jul 2025 23:21:23 +0900
Subject: [PATCH] grad-repair
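Cut memory use across the 8-GPU DDP recipe and repair the gradient
checkpointing path:

config/training_config_gemma3_1b_8gpu_ddp.yaml:
* load the base model in 4-bit instead of full BF16 weights
* shrink every stage's LoRA adapter to r=8 / lora_alpha=16
* per-device batch size 1, with gradient accumulation compensating
* shorter max_length per stage, fewer dataloader workers

src/progressive_model.py:
* actually enable gradient checkpointing when the config requests it;
  the previous code path was commented out, so the YAML flag was a
  no-op
* keep use_cache disabled: with checkpointing, activations are
  recomputed during backward, so cached KV states only waste memory

Effective batch size per optimizer step (per_device * 8 GPUs * accum):
stage 1: 1 * 8 * 8 = 64, unchanged from 8 * 8 * 1 = 64
stage 2: 1 * 8 * 4 = 32, down from 4 * 8 * 2 = 64
stage 3: 1 * 8 * 2 = 16, down from 2 * 8 * 4 = 64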
---
 .../training_config_gemma3_1b_8gpu_ddp.yaml | 40 +++++++++----------
 src/progressive_model.py                    | 13 +++---
 2 files changed, 26 insertions(+), 27 deletions(-)

diff --git a/config/training_config_gemma3_1b_8gpu_ddp.yaml b/config/training_config_gemma3_1b_8gpu_ddp.yaml
index 527c9a8..ee6c2b7 100644
--- a/config/training_config_gemma3_1b_8gpu_ddp.yaml
+++ b/config/training_config_gemma3_1b_8gpu_ddp.yaml
@@ -6,11 +6,11 @@ experiment:
   wandb_project: "matsuo-llm-comp-2025"
 
 model:
-  load_in_4bit: false  # Can use FP16/BF16 with multiple GPUs
+  load_in_4bit: true  # Enable quantization for memory savings
   bnb_4bit_compute_dtype: "bfloat16"
   bnb_4bit_use_double_quant: true
   device_map: "balanced"  # Distribute across all GPUs
-  gradient_checkpointing: true
+  gradient_checkpointing: true  # Enable gradient checkpointing
   use_flash_attention_2: false
   use_eager_attention: true
 
@@ -25,24 +25,24 @@ progressive_stages:
     description: "Basic Chain-of-Thought reasoning"
     dataset_path: "./data/basic_cot/"
     adapter_config:
-      r: 16  # Moderate rank for DDP
-      lora_alpha: 32
+      r: 8  # Minimal rank for memory
+      lora_alpha: 16
       lora_dropout: 0.1
       target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
       init_lora_weights: true
     training:
       num_epochs: 2
-      per_device_batch_size: 8  # 8 * 8 = 64 total batch size (reduced for memory)
-      gradient_accumulation_steps: 1
+      per_device_batch_size: 1  # 1 * 8 = 8 total batch size (minimal)
+      gradient_accumulation_steps: 8  # Maintain effective batch size
       learning_rate: 5e-4
       warmup_steps: 100
-      max_length: 1024
+      max_length: 512  # Reduced sequence length
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
       save_steps: 50
       logging_steps: 10
-      dataloader_num_workers: 4
+      dataloader_num_workers: 2
       dataloader_pin_memory: false
 
   - name: "math_reasoning"
@@ -50,24 +50,24 @@ progressive_stages:
     dataset_path: "open-r1/OpenR1-Math-220k"
     inherit_from: "basic_cot"
     adapter_config:
-      r: 32
-      lora_alpha: 64
+      r: 8  # Minimal rank for memory
+      lora_alpha: 16
       lora_dropout: 0.1
       target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
       init_lora_weights: true
     training:
       num_epochs: 1
-      per_device_batch_size: 4  # 4 * 8 = 32 total batch size (reduced for memory)
-      gradient_accumulation_steps: 2
+      per_device_batch_size: 1  # 1 * 8 = 8 total batch size (minimal)
+      gradient_accumulation_steps: 4
       learning_rate: 3e-4
       warmup_steps: 200
-      max_length: 2048
+      max_length: 1024  # Reduced sequence length
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
       save_steps: 100
       logging_steps: 20
-      dataloader_num_workers: 4
+      dataloader_num_workers: 2
       dataset_config:
         streaming: true
         max_samples: 400000  # Process substantial data
@@ -78,24 +78,24 @@ progressive_stages:
     dataset_path: "open-r1/Mixture-of-Thoughts"
     inherit_from: "math_reasoning"
     adapter_config:
-      r: 64
-      lora_alpha: 128
+      r: 8  # Minimal rank for memory
+      lora_alpha: 16
       lora_dropout: 0.1
       target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
       init_lora_weights: true
     training:
       num_epochs: 1
-      per_device_batch_size: 2  # 2 * 8 = 16 total batch size (reduced for memory)
-      gradient_accumulation_steps: 4
+      per_device_batch_size: 1  # 1 * 8 = 8 total batch size (minimal)
+      gradient_accumulation_steps: 2
       learning_rate: 2e-4
       warmup_steps: 300
-      max_length: 4096
+      max_length: 1024  # Reduced sequence length
       bf16: true
       max_grad_norm: 1.0
       weight_decay: 0.001
       save_steps: 200
       logging_steps: 50
-      dataloader_num_workers: 4
+      dataloader_num_workers: 2
       dataset_config:
         streaming: true
         max_samples: 600000
diff --git a/src/progressive_model.py b/src/progressive_model.py
index b6995ed..2ab5ba3 100644
--- a/src/progressive_model.py
+++ b/src/progressive_model.py
@@ -198,16 +198,15 @@ class ProgressiveReasoningModel:
         if quantization_config:
             self.model = prepare_model_for_kbit_training(self.model)
 
-        # Disable gradient checkpointing for now to avoid conflicts
-        # Enable gradient checkpointing if requested (but disable use_cache)
-        # if self.config["model"].get("gradient_checkpointing", False):
-        #     self.model.gradient_checkpointing_enable()
-        #     self.model.config.use_cache = False
-        #     print("Gradient checkpointing enabled, use_cache disabled")
+        # Enable gradient checkpointing if requested
+        if self.config["model"].get("gradient_checkpointing", False):
+            self.model.gradient_checkpointing_enable()
+            print("Gradient checkpointing enabled")
 
-        # Explicitly disable use_cache to avoid conflicts
+        # Explicitly disable use_cache to avoid conflicts and save memory
         if hasattr(self.model, 'config'):
             self.model.config.use_cache = False
+            print("use_cache disabled for memory efficiency")
 
         # Load tokenizer
         tokenizer_kwargs = {"trust_remote_code": True}
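--
A minimal sketch of the model-loading path this patch produces,
assuming recent transformers/peft with bitsandbytes and CUDA available;
the checkpoint id below is a placeholder, not taken from the repo:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import prepare_model_for_kbit_training

    # Mirrors the model: block of the YAML config
    quant = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-3-1b-it",  # placeholder checkpoint id
        quantization_config=quant,
        device_map="balanced",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)

    # The repaired path: checkpointing on, KV cache off
    model.gradient_checkpointing_enable()
    model.config.use_cache = False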