Saving weights and logs of step 300
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:457c5948252576d9d5252b28b79a754223d3dea5a24a77f5b2b7cb5189129499
 size 891548548

run_t5.sh
CHANGED
@@ -16,9 +16,8 @@ mkdir -p "${MODEL_DIR}/runs"
 --preprocessing_num_workers="96" \
 --do_train --do_eval \
 --adafactor \
---dtype="bfloat16" \
 --max_seq_length="512" \
---gradient_accumulation_steps="
+--gradient_accumulation_steps="16" \
 --per_device_train_batch_size="32" \
 --per_device_eval_batch_size="32" \
 --learning_rate="5e-3" \
@@ -32,3 +31,7 @@ mkdir -p "${MODEL_DIR}/runs"
 #git add pytorch_model.bin
 #git commit -m "Update pytorch model after training"
 #git push origin main
+
+
+# --dtype="bfloat16" \
+# --resume_from_checkpoint="${MODEL_DIR}/ckpt-3300" \
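
For orientation, the new --gradient_accumulation_steps="16" combines with the per-device batch size as follows. A minimal sketch of the arithmetic, assuming an 8-device host such as a TPU v3-8; the device count is not stated anywhere in this commit:

    # Effective batch size implied by the flags above; a sketch, not taken from the commit.
    per_device_train_batch_size = 32   # --per_device_train_batch_size
    gradient_accumulation_steps = 16   # --gradient_accumulation_steps (new in this commit)
    num_devices = 8                    # assumption: one TPU v3-8 host; not stated in the diff

    train_batch_size = per_device_train_batch_size * num_devices          # 256 examples per step
    examples_per_update = train_batch_size * gradient_accumulation_steps  # 4096 examples per parameter update
    print(train_batch_size, examples_per_update)

Accumulating 16 micro-batches per update trades step latency for a larger effective batch without increasing per-device memory.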

run_t5_mlm_flax_custom_dataset.py
CHANGED
@@ -722,6 +722,9 @@ if __name__ == "__main__":
 
     num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
 
+    steps_per_epoch = len(tokenized_datasets['train']) // train_batch_size
+    total_train_steps = steps_per_epoch * num_epochs
+
     # Create learning rate schedule
 
     # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at 6% of training steps
@@ -775,6 +778,11 @@ if __name__ == "__main__":
     # Setup train state
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
 
+    if training_args.resume_from_checkpoint:
+        state, resume_step = restore_checkpoint(training_args.resume_from_checkpoint, state)
+    else:
+        resume_step = 0
+
     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
         dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
@@ -828,8 +836,7 @@ if __name__ == "__main__":
     # Replicate the train state on each device
     state = jax_utils.replicate(state)
 
-
-    total_train_steps = steps_per_epoch * num_epochs
+
 
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len(datasets['train'])}")
@@ -855,6 +862,11 @@ if __name__ == "__main__":
 
         # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+            cur_step = epoch * (num_train_samples // train_batch_size) + step
+            # skip to the step from which we are resuming
+            if cur_step < resume_step:
+                continue
+
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples)
 
@@ -863,7 +875,6 @@ if __name__ == "__main__":
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
 
-            cur_step = epoch * (num_train_samples // train_batch_size) + step
 
             if cur_step % training_args.logging_steps * grad_accum_steps == 0 and cur_step > 0:
                 # Save metrics
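
The resume path added above calls restore_checkpoint(), but that helper's definition is not part of this diff. Below is a minimal sketch of what it could look like, assuming checkpoints are written with flax.training.checkpoints.save_checkpoint and that the step counter stored on the TrainState is the step to resume from; only the call signature is taken from the commit, the body is an assumption:

    # Hypothetical helper: the call site above expects (state, resume_step),
    # but the actual implementation is not included in this commit.
    from flax.training import checkpoints, train_state


    def restore_checkpoint(ckpt_dir: str, state: train_state.TrainState):
        """Sketch: load the latest checkpoint from ckpt_dir into state.

        Assumes checkpoints were saved with flax.training.checkpoints.save_checkpoint,
        so the optimizer step stored on the TrainState doubles as the resume step.
        """
        restored = checkpoints.restore_checkpoint(ckpt_dir, target=state)
        return restored, int(restored.step)

With such a helper in place, resuming the run is a matter of re-enabling the commented --resume_from_checkpoint="${MODEL_DIR}/ckpt-3300" line at the bottom of run_t5.sh.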

runs/Jul10_07-37-20_t1v-n-0e7426e8-w-0/events.out.tfevents.1625902752.t1v-n-0e7426e8-w-0.18397.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1aa4fd14ba6d0007ac2b4c7ad5f7b03ab486b3899ece3eba1fefe852923f2366
+size 40

runs/Jul10_07-45-49_t1v-n-0e7426e8-w-0/events.out.tfevents.1625903173.t1v-n-0e7426e8-w-0.20563.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9086b97ea9ba59e96e4c66b26c205fe1207d0a94ab355127a1e4f8078d84a269
+size 45399