Skip to content

Commit

Permalink
Flash attention and mixed precision training, works quite a bit faster
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Sep 23, 2024
1 parent a778225 commit 5967a52
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
1 change: 1 addition & 0 deletions pdelfin/train/config/qwen2vl-2b.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
model:
name_or_path: Qwen/Qwen2-VL-2B-Instruct
arch: causal
use_flash_attn: true

wandb:
project: pdelfin
Expand Down
11 changes: 5 additions & 6 deletions pdelfin/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ def run_train(config: TrainConfig):

run_name = RunName.get(config)

accelerator = accelerate.Accelerator()

setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)

dataset = make_dataset(
Expand All @@ -133,7 +131,8 @@ def run_train(config: TrainConfig):
)

model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto",
_attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

Expand Down Expand Up @@ -187,8 +186,7 @@ def run_train(config: TrainConfig):
save_steps=config.save.save_every_steps,
warmup_steps=config.hparams.warmup_steps,
warmup_ratio=config.hparams.warmup_ratio,
bf16=accelerator.mixed_precision == "bf16",
fp16=accelerator.mixed_precision == "fp16",
bf16=True,
label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885
max_grad_norm=config.hparams.clip_grad_norm,
remove_unused_columns=False,
Expand Down Expand Up @@ -223,9 +221,10 @@ def run_train(config: TrainConfig):
logger.info("Saved best model to %s", best_dir)

# Uncomment to test speed of data loader
# train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=False)
# train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
# for entry in tqdm(train_dataloader):
# print("Step!")
# model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})


def main():
Expand Down

0 comments on commit 5967a52

Please sign in to comment.