grad-repair

Soma Nakamura 2025-07-10 23:10:56 +09:00
parent 59c3bcfc7d
commit 2d01c6577f
2 changed files with 27 additions and 15 deletions


@@ -102,10 +102,6 @@ def setup_distributed_training(local_rank=-1):
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist

-        # Initialize distributed training
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-
         # Use local_rank from args or environment
         if local_rank >= 0:
             torch.cuda.set_device(local_rank)
@@ -113,6 +109,10 @@ def setup_distributed_training(local_rank=-1):
             local_rank = int(os.environ.get("LOCAL_RANK", 0))
             torch.cuda.set_device(local_rank)

+        # Initialize distributed training
+        if not dist.is_initialized():
+            dist.init_process_group(backend="nccl")
+
         print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
         return True
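
Note on the hunks above: the change moves dist.init_process_group() so that it runs after torch.cuda.set_device(), the usual ordering for NCCL so each rank is pinned to its own GPU before the process group is created. A minimal standalone sketch of the resulting setup, assuming a torchrun-style launch that sets RANK, WORLD_SIZE and LOCAL_RANK (an illustration, not an exact copy of this file):

import os
import torch
import torch.distributed as dist

def setup_distributed_training(local_rank=-1):
    # Only set up distributed training when launched by a distributed launcher (e.g. torchrun)
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        return False
    # Pin this process to its GPU first ...
    if local_rank < 0:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    # ... then create the NCCL process group on that device
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
    return True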


@@ -433,21 +433,33 @@ class ProgressiveTrainer:
         trainable_params = sum(p.numel() for p in self.model_wrapper.model.parameters() if p.requires_grad)
         print(f"Final check - Trainable parameters: {trainable_params:,}")

-        # Create trainer with minimal configuration
+        # Create trainer with compatible configuration
         try:
+            # Try modern SFTTrainer arguments
+            trainer = SFTTrainer(
+                model=self.model_wrapper.model,
+                tokenizer=self.model_wrapper.tokenizer,
+                train_dataset=dataset,
+                args=training_args,
+                max_seq_length=stage_config["training"]["max_length"],
+                dataset_text_field="text"
+            )
+        except Exception as e:
+            print(f"Error creating SFTTrainer with modern args: {e}")
+            try:
+                # Try with processing_class (newer versions)
                 trainer = SFTTrainer(
                     model=self.model_wrapper.model,
                     processing_class=self.model_wrapper.tokenizer,
                     train_dataset=dataset,
                     args=training_args,
-                packing=False,  # Disable packing for better gradient flow
                 )
-        except Exception as e:
-            print(f"Error creating SFTTrainer: {e}")
-            print("Trying with basic configuration...")
+            except Exception as e2:
+                print(f"Error with processing_class: {e2}")
+                # Fallback to basic configuration
                 trainer = SFTTrainer(
                     model=self.model_wrapper.model,
-                processing_class=self.model_wrapper.tokenizer,
+                    tokenizer=self.model_wrapper.tokenizer,
                     train_dataset=dataset,
                     args=training_args,
                 )
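
Note on the hunk above: SFTTrainer's constructor keywords have shifted across trl releases (older versions accept tokenizer=, max_seq_length= and dataset_text_field= directly, while newer ones expect processing_class= and move the dataset options into SFTConfig), so the nested try/except simply retries with whichever keyword set the installed version accepts. As an illustration only, not part of this commit, the same compatibility choice can be made once up front by inspecting the constructor signature; build_sft_trainer below is a hypothetical helper:

import inspect
from trl import SFTTrainer

def build_sft_trainer(model, tokenizer, dataset, training_args):
    # Choose between tokenizer= (older trl) and processing_class= (newer trl)
    accepted = inspect.signature(SFTTrainer.__init__).parameters
    kwargs = {"model": model, "train_dataset": dataset, "args": training_args}
    if "processing_class" in accepted:
        kwargs["processing_class"] = tokenizer
    else:
        kwargs["tokenizer"] = tokenizer
    return SFTTrainer(**kwargs)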