Skip to content

Commit

Permalink
add wandb initialization
Browse files (browse the repository at this point in the history)
  • Loading branch information
AdrienneDeganutti committed Dec 7, 2023
1 parent 07e57b9 commit 13a1da6
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/tasks/run_caption_VidSwinBert.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,6 +35,8 @@
from src.modeling.load_bert import get_bert_model
from src.solver import AdamW, WarmupLinearLR

import wandb

from azureml.core.run import Run
aml_run = Run.get_context()

Expand Down Expand Up @@ -198,6 +200,10 @@ def train(args, train_dataloader, val_dataloader, model, tokenizer, training_sav
scaled_loss.backward()
if backward_now:
global_step += 1

if is_main_process():
wandb.log({"loss": loss.item(), "accuracy": batch_acc.item(), "step": global_step})

TB_LOGGER.add_scalar('train/loss', running_loss.val, global_step)

lr_VisBone = optimizer.param_groups[0]["lr"]
Expand Down Expand Up @@ -263,7 +269,6 @@ def train(args, train_dataloader, val_dataloader, model, tokenizer, training_sav

if (args.save_steps > 0 and global_step % args.save_steps == 0) or global_step == max_global_step or global_step == 1:
epoch = global_step // global_iters_per_epoch

checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
epoch, global_step))
if get_world_size() > 1:
Expand Down Expand Up @@ -434,7 +439,7 @@ def gen_rows():
img_key = img_key.item()
yield img_key, json.dumps(res)

logger.info(f"Inference model computing time: {(time_meter / (step+1))} seconds per batch")
#logger.info(f"Inference model computing time: {(time_meter / (step+1))} seconds per batch")

tsv_writer(gen_rows(), cache_file)
if world_size > 1:
Expand Down Expand Up @@ -653,6 +658,10 @@ def main(args):
vl_transformer.to(args.device)

if args.do_train:

if is_main_process():
wandb.init(project="SwinBERT", name="training_process", config=args)

args = restore_training_settings(args)
train_dataloader = make_data_loader(args, args.train_yaml, tokenizer, args.distributed, is_train=True)
val_dataloader = make_data_loader(args, args.val_yaml, tokenizer, args.distributed, is_train=False)
Expand All @@ -676,4 +685,5 @@ def main(args):
if __name__ == "__main__":
shared_configs.shared_video_captioning_config(cbs=True, scst=True)
args = get_custom_args(shared_configs)
main(args)
torch.cuda.set_device(args.local_rank)
main(args)

0 comments on commit 13a1da6

Please sign in to comment.