grad-repair
This commit is contained in:
parent
59c3bcfc7d
commit
2d01c6577f
2 changed files with 27 additions and 15 deletions
@@ -102,10 +102,6 @@ def setup_distributed_training(local_rank=-1):
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         import torch.distributed as dist
 
-        # Initialize distributed training
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-
         # Use local_rank from args or environment
         if local_rank >= 0:
             torch.cuda.set_device(local_rank)
@@ -113,6 +109,10 @@ def setup_distributed_training(local_rank=-1):
             local_rank = int(os.environ.get("LOCAL_RANK", 0))
             torch.cuda.set_device(local_rank)
 
+        # Initialize distributed training
+        if not dist.is_initialized():
+            dist.init_process_group(backend="nccl")
+
         print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
         return True
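
Net effect of the two hunks above: dist.init_process_group(backend="nccl") now runs after torch.cuda.set_device(local_rank) instead of before it. With NCCL, creating the process group before each rank has selected its own GPU is a common source of every rank allocating on GPU 0 and of collective hangs. Below is a minimal, self-contained sketch of the resulting flow, not the repository's exact code: the function body is inferred from the diff's context lines, and a torchrun-style launcher (which exports RANK, WORLD_SIZE and LOCAL_RANK) is assumed.

# Sketch of the post-commit initialization order (assumptions noted above):
# pick this rank's GPU first, then create the NCCL process group, so the
# communicator is bound to the right device.
import os

import torch
import torch.distributed as dist


def setup_distributed_training(local_rank: int = -1) -> bool:
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        return False  # not launched as a distributed job

    # Use local_rank from args, or fall back to the environment.
    if local_rank < 0:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Initialize distributed training only after the device is set.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    print(f"Distributed training initialized: rank {dist.get_rank()}/{dist.get_world_size()}, local_rank {local_rank}")
    return True

With a launcher such as torchrun --nproc_per_node=2 train.py, each of the two processes picks its own GPU before joining the NCCL process group (train.py being a hypothetical entry point).
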
@@ -433,21 +433,33 @@ class ProgressiveTrainer:
         trainable_params = sum(p.numel() for p in self.model_wrapper.model.parameters() if p.requires_grad)
         print(f"Final check - Trainable parameters: {trainable_params:,}")
 
-        # Create trainer with minimal configuration
+        # Create trainer with compatible configuration
         try:
+            # Try modern SFTTrainer arguments
+            trainer = SFTTrainer(
+                model=self.model_wrapper.model,
+                tokenizer=self.model_wrapper.tokenizer,
+                train_dataset=dataset,
+                args=training_args,
+                max_seq_length=stage_config["training"]["max_length"],
+                dataset_text_field="text"
+            )
+        except Exception as e:
+            print(f"Error creating SFTTrainer with modern args: {e}")
+            try:
+                # Try with processing_class (newer versions)
                 trainer = SFTTrainer(
                     model=self.model_wrapper.model,
                     processing_class=self.model_wrapper.tokenizer,
                     train_dataset=dataset,
                     args=training_args,
-                packing=False,  # Disable packing for better gradient flow
                 )
-        except Exception as e:
-            print(f"Error creating SFTTrainer: {e}")
-            print("Trying with basic configuration...")
+            except Exception as e2:
+                print(f"Error with processing_class: {e2}")
+                # Fallback to basic configuration
                 trainer = SFTTrainer(
                     model=self.model_wrapper.model,
-                processing_class=self.model_wrapper.tokenizer,
+                    tokenizer=self.model_wrapper.tokenizer,
                     train_dataset=dataset,
                     args=training_args,
                 )
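
The nested try/except above exists because SFTTrainer's constructor changed across trl releases: older versions accept tokenizer, max_seq_length and dataset_text_field directly, while newer ones rename tokenizer to processing_class and move the dataset options into SFTConfig. A sketch of an alternative approach, not what this commit does: probe the constructor once with inspect.signature and assemble only the supported keyword arguments. The helper name make_sft_trainer and its parameter list are hypothetical.

# Hypothetical alternative to the nested try/except: inspect SFTTrainer's
# constructor once and pass only the keyword arguments it accepts.
import inspect

from trl import SFTTrainer


def make_sft_trainer(model, tokenizer, dataset, training_args, max_length):
    accepted = inspect.signature(SFTTrainer.__init__).parameters

    kwargs = {"model": model, "train_dataset": dataset, "args": training_args}

    # Newer trl renamed `tokenizer` to `processing_class`.
    if "processing_class" in accepted:
        kwargs["processing_class"] = tokenizer
    elif "tokenizer" in accepted:
        kwargs["tokenizer"] = tokenizer

    # Older trl took these directly; newer versions expect them on SFTConfig.
    if "max_seq_length" in accepted:
        kwargs["max_seq_length"] = max_length
    if "dataset_text_field" in accepted:
        kwargs["dataset_text_field"] = "text"

    return SFTTrainer(**kwargs)

This builds the trainer through a single code path, so a genuine construction error surfaces directly instead of being caught and retried under a different signature.
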