grad-repair

Soma Nakamura 2025-07-10 23:10:56 +09:00
parent 59c3bcfc7d
commit 2d01c6577f
2 changed files with 27 additions and 15 deletions

@@ -102,16 +102,16 @@ def setup_distributed_training(local_rank=-1):
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist
-        # Initialize distributed training
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
         # Use local_rank from args or environment
         if local_rank >= 0:
             torch.cuda.set_device(local_rank)
         else:
             local_rank = int(os.environ.get("LOCAL_RANK", 0))
             torch.cuda.set_device(local_rank)
+        # Initialize distributed training
+        if not dist.is_initialized():
+            dist.init_process_group(backend="nccl")
         print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
         return True
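
The reorder above puts torch.cuda.set_device() before dist.init_process_group(): with the NCCL backend each worker should be pinned to its own GPU before the process group is created, so the communicators are built on the right device. A minimal standalone sketch of the resulting order, assuming CUDA GPUs and a torchrun launch; the init_distributed helper below is illustrative and not part of this repository:

# Sketch of the init order this hunk establishes: select the GPU first,
# then create the NCCL process group. Names here are illustrative.
import os
import torch
import torch.distributed as dist

def init_distributed():
    # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every worker
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Bind this process to its GPU before the process group exists
    torch.cuda.set_device(local_rank)

    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    return local_rank

if __name__ == "__main__":
    local_rank = init_distributed()
    print(f"rank {dist.get_rank()}/{dist.get_world_size()} on cuda:{local_rank}")

Launched with, for example, torchrun --nproc_per_node=2 sketch.py, each worker prints its global rank and the GPU it was pinned to.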

@@ -433,24 +433,36 @@ class ProgressiveTrainer:
         trainable_params = sum(p.numel() for p in self.model_wrapper.model.parameters() if p.requires_grad)
         print(f"Final check - Trainable parameters: {trainable_params:,}")
-        # Create trainer with minimal configuration
+        # Create trainer with compatible configuration
         try:
+            # Try modern SFTTrainer arguments
             trainer = SFTTrainer(
                 model=self.model_wrapper.model,
-                processing_class=self.model_wrapper.tokenizer,
+                tokenizer=self.model_wrapper.tokenizer,
                 train_dataset=dataset,
                 args=training_args,
-                packing=False,  # Disable packing for better gradient flow
-                dataset_text_field="text"
+                max_seq_length=stage_config["training"]["max_length"],
             )
         except Exception as e:
-            print(f"Error creating SFTTrainer: {e}")
-            print("Trying with basic configuration...")
-            trainer = SFTTrainer(
-                model=self.model_wrapper.model,
-                processing_class=self.model_wrapper.tokenizer,
-                train_dataset=dataset,
-                args=training_args,
-            )
+            print(f"Error creating SFTTrainer with modern args: {e}")
+            try:
+                # Try with processing_class (newer versions)
+                trainer = SFTTrainer(
+                    model=self.model_wrapper.model,
+                    processing_class=self.model_wrapper.tokenizer,
+                    train_dataset=dataset,
+                    args=training_args,
+                )
+            except Exception as e2:
+                print(f"Error with processing_class: {e2}")
+                # Fallback to basic configuration
+                trainer = SFTTrainer(
+                    model=self.model_wrapper.model,
+                    tokenizer=self.model_wrapper.tokenizer,
+                    train_dataset=dataset,
+                    args=training_args,
+                )
         # Train
         trainer.train()
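
The nested try/except above probes which constructor keywords the installed TRL release accepts (tokenizer and max_seq_length versus processing_class). A hedged alternative is to inspect the signature once and build the keyword set up front; build_trainer() and its parameters are illustrative names, and the sketch assumes SFTTrainer.__init__ declares its parameters explicitly rather than hiding them behind **kwargs:

# Sketch: choose SFTTrainer keywords by checking its signature instead of
# relying on broad except clauses. Illustrative only; not from this repository.
import inspect
from trl import SFTTrainer

def build_trainer(model, tokenizer, dataset, training_args, max_length):
    accepted = inspect.signature(SFTTrainer.__init__).parameters
    kwargs = {"model": model, "train_dataset": dataset, "args": training_args}

    # Newer TRL releases take `processing_class`; older ones take `tokenizer`
    if "processing_class" in accepted:
        kwargs["processing_class"] = tokenizer
    else:
        kwargs["tokenizer"] = tokenizer

    # Only pass `max_seq_length` when this release still accepts it
    if "max_seq_length" in accepted:
        kwargs["max_seq_length"] = max_length

    return SFTTrainer(**kwargs)

Checking the signature keeps the version handling in one place and avoids the broad except clauses swallowing unrelated construction errors.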