grad-repair

This commit is contained in:
parent 59c3bcfc7d
commit 2d01c6577f

2 changed files with 27 additions and 15 deletions
@@ -102,10 +102,6 @@ def setup_distributed_training(local_rank=-1):
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist
 
-        # Initialize distributed training
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-
         # Use local_rank from args or environment
         if local_rank >= 0:
             torch.cuda.set_device(local_rank)
@@ -113,6 +109,10 @@ def setup_distributed_training(local_rank=-1):
             local_rank = int(os.environ.get("LOCAL_RANK", 0))
             torch.cuda.set_device(local_rank)
 
+        # Initialize distributed training
+        if not dist.is_initialized():
+            dist.init_process_group(backend="nccl")
+
         print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
         return True
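Taken together, these two hunks move the process-group setup so that torch.cuda.set_device() runs before dist.init_process_group(backend="nccl"), binding each worker to its own GPU before NCCL builds its communicators. A minimal sketch of the repaired helper as it reads after this commit; the control flow outside the shown hunks (the early return and the branch between the hunks) is an assumption:

import os

import torch
import torch.distributed as dist


def setup_distributed_training(local_rank=-1):
    """Initialize DDP when launched by a distributed launcher (e.g. torchrun)."""
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        return False

    # Use local_rank from args or environment
    if local_rank < 0:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Bind this process to its GPU *before* creating the NCCL process group,
    # so the communicator is built against the right device.
    torch.cuda.set_device(local_rank)

    # Initialize distributed training
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    print(f"Distributed training initialized: "
          f"rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
    return True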
@@ -433,21 +433,33 @@ class ProgressiveTrainer:
         trainable_params = sum(p.numel() for p in self.model_wrapper.model.parameters() if p.requires_grad)
         print(f"Final check - Trainable parameters: {trainable_params:,}")
-        # Create trainer with minimal configuration
+        # Create trainer with compatible configuration
         try:
+            # Try modern SFTTrainer arguments
             trainer = SFTTrainer(
                 model=self.model_wrapper.model,
                 tokenizer=self.model_wrapper.tokenizer,
                 train_dataset=dataset,
                 args=training_args,
                 max_seq_length=stage_config["training"]["max_length"],
                 dataset_text_field="text"
             )
         except Exception as e:
-            print(f"Error creating SFTTrainer: {e}")
-            print("Trying with basic configuration...")
-            trainer = SFTTrainer(
-                model=self.model_wrapper.model,
-                processing_class=self.model_wrapper.tokenizer,
-                train_dataset=dataset,
-                args=training_args,
-            )
+            print(f"Error creating SFTTrainer with modern args: {e}")
+            try:
+                # Try with processing_class (newer versions)
+                trainer = SFTTrainer(
+                    model=self.model_wrapper.model,
+                    processing_class=self.model_wrapper.tokenizer,
+                    train_dataset=dataset,
+                    args=training_args,
+                    packing=False,  # Disable packing for better gradient flow
+                )
+            except Exception as e2:
+                print(f"Error with processing_class: {e2}")
+                # Fallback to basic configuration
+                trainer = SFTTrainer(
+                    model=self.model_wrapper.model,
+                    tokenizer=self.model_wrapper.tokenizer,
+                    train_dataset=dataset,
+                    args=training_args,
+                )
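The nested try/except chain above exists because the SFTTrainer constructor has changed across trl releases: older versions accept tokenizer, max_seq_length, and dataset_text_field directly, while newer ones take processing_class and expect the dataset options in SFTConfig. A sketch of the same compatibility idea done with signature inspection instead of exception probing; build_sft_trainer is a hypothetical helper, and the arguments stand in for the model_wrapper objects used in the diff:

import inspect

from trl import SFTTrainer


def build_sft_trainer(model, tokenizer, dataset, training_args, max_length):
    """Build an SFTTrainer with whichever keyword names this trl release accepts."""
    accepted = inspect.signature(SFTTrainer.__init__).parameters
    kwargs = {"model": model, "train_dataset": dataset, "args": training_args}

    # Newer trl releases renamed `tokenizer` to `processing_class`.
    if "processing_class" in accepted:
        kwargs["processing_class"] = tokenizer
    elif "tokenizer" in accepted:
        kwargs["tokenizer"] = tokenizer

    # Older releases took these on the trainer; newer ones expect them in SFTConfig.
    if "max_seq_length" in accepted:
        kwargs["max_seq_length"] = max_length
    if "dataset_text_field" in accepted:
        kwargs["dataset_text_field"] = "text"

    return SFTTrainer(**kwargs)

Probing the signature keeps the version logic in one place and avoids the broad except Exception branches, which can otherwise swallow unrelated construction errors.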