From 2d01c6577f7d5afe35ada5347a889d1aee45f2fb Mon Sep 17 00:00:00 2001
From: Soma Nakamura
Date: Thu, 10 Jul 2025 23:10:56 +0900
Subject: [PATCH] grad-repair

---
 scripts/train_progressive.py |  8 ++++----
 src/training.py              | 34 +++++++++++++++++++++++-----------
 2 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/scripts/train_progressive.py b/scripts/train_progressive.py
index 3d2ce54..c997eea 100755
--- a/scripts/train_progressive.py
+++ b/scripts/train_progressive.py
@@ -102,16 +102,16 @@ def setup_distributed_training(local_rank=-1):
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist
 
-        # Initialize distributed training
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-
         # Use local_rank from args or environment
         if local_rank >= 0:
             torch.cuda.set_device(local_rank)
         else:
             local_rank = int(os.environ.get("LOCAL_RANK", 0))
             torch.cuda.set_device(local_rank)
+
+        # Initialize distributed training
+        if not dist.is_initialized():
+            dist.init_process_group(backend="nccl")
 
         print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
         return True
diff --git a/src/training.py b/src/training.py
index 9d98d82..5028769 100644
--- a/src/training.py
+++ b/src/training.py
@@ -433,24 +433,36 @@ class ProgressiveTrainer:
         trainable_params = sum(p.numel() for p in self.model_wrapper.model.parameters() if p.requires_grad)
         print(f"Final check - Trainable parameters: {trainable_params:,}")
 
-        # Create trainer with minimal configuration
+        # Create trainer with compatible configuration
         try:
+            # Try modern SFTTrainer arguments
             trainer = SFTTrainer(
                 model=self.model_wrapper.model,
-                processing_class=self.model_wrapper.tokenizer,
+                tokenizer=self.model_wrapper.tokenizer,
                 train_dataset=dataset,
                 args=training_args,
-                packing=False,  # Disable packing for better gradient flow
+                max_seq_length=stage_config["training"]["max_length"],
+                dataset_text_field="text"
             )
         except Exception as e:
-            print(f"Error creating SFTTrainer: {e}")
-            print("Trying with basic configuration...")
-            trainer = SFTTrainer(
-                model=self.model_wrapper.model,
-                processing_class=self.model_wrapper.tokenizer,
-                train_dataset=dataset,
-                args=training_args,
-            )
+            print(f"Error creating SFTTrainer with modern args: {e}")
+            try:
+                # Try with processing_class (newer versions)
+                trainer = SFTTrainer(
+                    model=self.model_wrapper.model,
+                    processing_class=self.model_wrapper.tokenizer,
+                    train_dataset=dataset,
+                    args=training_args,
+                )
+            except Exception as e2:
+                print(f"Error with processing_class: {e2}")
+                # Fallback to basic configuration
+                trainer = SFTTrainer(
+                    model=self.model_wrapper.model,
+                    tokenizer=self.model_wrapper.tokenizer,
+                    train_dataset=dataset,
+                    args=training_args,
+                )
 
         # Train
         trainer.train()