Saving weights and logs of step 1200
Browse files
flax_model.msgpack
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 891548548
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1370699db9ee8980b9d18ba78ab3c7bacbf64455af8b4d767abadaf4e1c6a466
|
| 3 |
size 891548548
|
run_t5_mlm_flax_custom_dataset.py
CHANGED
|
@@ -703,6 +703,13 @@ if __name__ == "__main__":
|
|
| 703 |
else:
|
| 704 |
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
| 705 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 706 |
# Data collator
|
| 707 |
# This one will take care of randomly masking the tokens.
|
| 708 |
data_collator = FlaxDataCollatorForT5MLM(
|
|
|
|
| 703 |
else:
|
| 704 |
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
| 705 |
|
| 706 |
+
|
| 707 |
+
# def to_bf16(t):
|
| 708 |
+
# return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
|
| 709 |
+
#
|
| 710 |
+
#
|
| 711 |
+
# model.params = to_bf16(model.params)
|
| 712 |
+
|
| 713 |
# Data collator
|
| 714 |
# This one will take care of randomly masking the tokens.
|
| 715 |
data_collator = FlaxDataCollatorForT5MLM(
|
runs/Jul10_08-38-10_t1v-n-0e7426e8-w-0/events.out.tfevents.1625906314.t1v-n-0e7426e8-w-0.25839.3.v2
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b87fa89d0ac5eeabdea48a6a8250033be187061f5d0f1635b1d3f57ce6c7daaf
|
| 3 |
+
size 181839
|
streaming_dataset_filter_test.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from clean import clean_text
|
| 2 |
+
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
|
| 5 |
+
dataset_v0 = load_dataset('oscar', "unshuffled_deduplicated_nl", split='train', streaming=True)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def f(obj):
|
| 9 |
+
obj["text"] = clean_text(obj["text"])
|
| 10 |
+
return obj
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
dataset_v1 = dataset_v0.map(f)
|
| 14 |
+
it = iter(dataset_v0)
|
| 15 |
+
|
| 16 |
+
print(next(it))
|
| 17 |
+
print(next(it))
|
| 18 |
+
print(next(it))
|
| 19 |
+
|
| 20 |
+
it = iter(dataset_v1)
|
| 21 |
+
|
| 22 |
+
print(next(it))
|
| 23 |
+
print(next(it))
|
| 24 |
+
print(next(it))
|