Create README.md
Browse files
README.md
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- generated_from_trainer
|
| 4 |
+
- retnet
|
| 5 |
+
model-index:
|
| 6 |
+
- name: sdprompt-retnet-300m
|
| 7 |
+
results: []
|
| 8 |
+
license: mit
|
| 9 |
+
datasets:
|
| 10 |
+
- Gustavosta/Stable-Diffusion-Prompts
|
| 11 |
+
- FredZhang7/anime-prompts-180K
|
| 12 |
+
language:
|
| 13 |
+
- en
|
| 14 |
+
library_name: transformers
|
| 15 |
+
pipeline_tag: text-generation
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
| 19 |
+
should probably proofread and complete it, then remove this comment. -->
|
| 20 |
+
|
| 21 |
+
# SDPrompt-RetNet-300M
|
| 22 |
+
|
| 23 |
+
This model is a RetNet model trained from scratch using https://github.com/syncdoth/RetNet.
|
| 24 |
+
It achieves the following results on the evaluation set:
|
| 25 |
+
- Loss: 0.3616
|
| 26 |
+
|
| 27 |
+
## Usage
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
pip install transformers safetensors timm
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
```py
|
| 34 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
| 35 |
+
|
| 36 |
+
MODEL_NAME = ""
|
| 37 |
+
|
| 38 |
+
DEVICE = "cuda"
|
| 39 |
+
|
| 40 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 41 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 42 |
+
MODEL_NAME,
|
| 43 |
+
trust_remote_code=True,
|
| 44 |
+
).to(DEVICE)
|
| 45 |
+
|
| 46 |
+
streamer = TextStreamer(tokenizer)
|
| 47 |
+
|
| 48 |
+
prompt = "<s>1girl"
|
| 49 |
+
|
| 50 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 51 |
+
|
| 52 |
+
_ = model.generate(
|
| 53 |
+
inputs["input_ids"],
|
| 54 |
+
max_new_tokens=256,
|
| 55 |
+
do_sample=True,
|
| 56 |
+
top_p=0.9,
|
| 57 |
+
top_k=20,
|
| 58 |
+
temperature=0.9,
|
| 59 |
+
streamer=streamer,
|
| 60 |
+
)
|
| 61 |
+
# <s> 1girl, absurdres, animal ear fluff, animal ears, bangs, bare shoulders, black hair, blue archive, blunt bangs, blush, closed mouth, collarbone, commentary request, eyes visible through hair, green eyes, hair between eyes, halo, hand on own face, hand up, highres, jacket, kisaki blue archive, long hair, long sleeves, looking at viewer, open clothes, open jacket, shinonome asu, simple background, solo, track jacket, upper body, white background, white jacket</s>
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
## Model description
|
| 65 |
+
|
| 66 |
+
This model is trained with Stable Diffusion prompts and Danbooru tags to generate prompts for image generation models.
|
| 67 |
+
|
| 68 |
+
## Training data
|
| 69 |
+
|
| 70 |
+
- [Gustavosta/Stable-Diffusion-Prompts](https://huggingface.co/datasets/Gustavosta/Stable-Diffusion-Prompts)
|
| 71 |
+
- [FredZhang7/anime-prompts-180K](https://huggingface.co/datasets/FredZhang7/anime-prompts-180K)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
## Training procedure
|
| 75 |
+
|
| 76 |
+
### Training hyperparameters
|
| 77 |
+
|
| 78 |
+
The following hyperparameters were used during training:
|
| 79 |
+
- learning_rate: 0.0006
|
| 80 |
+
- train_batch_size: 8
|
| 81 |
+
- eval_batch_size: 8
|
| 82 |
+
- seed: 42
|
| 83 |
+
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
|
| 84 |
+
- lr_scheduler_type: cosine
|
| 85 |
+
- lr_scheduler_warmup_steps: 500
|
| 86 |
+
- num_epochs: 5
|
| 87 |
+
|
| 88 |
+
### Training results
|
| 89 |
+
|
| 90 |
+
| Training Loss | Epoch | Step | Validation Loss |
|
| 91 |
+
|:-------------:|:-----:|:------:|:---------------:|
|
| 92 |
+
| 2.6714 | 0.03 | 1000 | 2.5787 |
|
| 93 |
+
| 2.1551 | 0.07 | 2000 | 2.3981 |
|
| 94 |
+
| 2.1439 | 0.1 | 3000 | 2.1160 |
|
| 95 |
+
| 1.8406 | 0.14 | 4000 | 1.9138 |
|
| 96 |
+
| 1.7485 | 0.17 | 5000 | 1.7847 |
|
| 97 |
+
| 1.6417 | 0.21 | 6000 | 1.7120 |
|
| 98 |
+
| 1.6084 | 0.24 | 7000 | 1.6055 |
|
| 99 |
+
| 1.4805 | 0.28 | 8000 | 1.5946 |
|
| 100 |
+
| 1.5524 | 0.31 | 9000 | 1.5027 |
|
| 101 |
+
| 1.4425 | 0.35 | 10000 | 1.4876 |
|
| 102 |
+
| 1.4007 | 0.38 | 11000 | 1.4364 |
|
| 103 |
+
| 1.4637 | 0.42 | 12000 | 1.3896 |
|
| 104 |
+
| 1.3211 | 0.45 | 13000 | 1.3968 |
|
| 105 |
+
| 1.3246 | 0.49 | 14000 | 1.3403 |
|
| 106 |
+
| 1.3461 | 0.52 | 15000 | 1.3156 |
|
| 107 |
+
| 1.2897 | 0.56 | 16000 | 1.2977 |
|
| 108 |
+
| 1.2748 | 0.59 | 17000 | 1.2823 |
|
| 109 |
+
| 1.2424 | 0.62 | 18000 | 1.2649 |
|
| 110 |
+
| 1.348 | 0.66 | 19000 | 1.2134 |
|
| 111 |
+
| 1.1797 | 0.69 | 20000 | 1.2030 |
|
| 112 |
+
| 1.2116 | 0.73 | 21000 | 1.2033 |
|
| 113 |
+
| 1.1702 | 0.76 | 22000 | 1.1453 |
|
| 114 |
+
| 1.1027 | 0.8 | 23000 | 1.1597 |
|
| 115 |
+
| 1.1932 | 0.83 | 24000 | 1.1506 |
|
| 116 |
+
| 1.3669 | 0.87 | 25000 | 1.1428 |
|
| 117 |
+
| 1.0705 | 0.9 | 26000 | 1.1239 |
|
| 118 |
+
| 1.1474 | 0.94 | 27000 | 1.1239 |
|
| 119 |
+
| 1.0879 | 0.97 | 28000 | 1.1168 |
|
| 120 |
+
| 0.9879 | 1.01 | 29000 | 1.0848 |
|
| 121 |
+
| 0.9928 | 1.04 | 30000 | 1.0953 |
|
| 122 |
+
| 0.9095 | 1.08 | 31000 | 1.1043 |
|
| 123 |
+
| 1.0423 | 1.11 | 32000 | 1.0823 |
|
| 124 |
+
| 0.9478 | 1.15 | 33000 | 1.0840 |
|
| 125 |
+
| 0.9979 | 1.18 | 34000 | 1.0387 |
|
| 126 |
+
| 1.0316 | 1.22 | 35000 | 1.0282 |
|
| 127 |
+
| 1.0531 | 1.25 | 36000 | 1.0369 |
|
| 128 |
+
| 0.919 | 1.28 | 37000 | 1.0398 |
|
| 129 |
+
| 1.0596 | 1.32 | 38000 | 1.0410 |
|
| 130 |
+
| 0.9076 | 1.35 | 39000 | 0.9889 |
|
| 131 |
+
| 0.9698 | 1.39 | 40000 | 1.0004 |
|
| 132 |
+
| 0.9633 | 1.42 | 41000 | 1.0038 |
|
| 133 |
+
| 0.9622 | 1.46 | 42000 | 0.9933 |
|
| 134 |
+
| 0.9809 | 1.49 | 43000 | 0.9805 |
|
| 135 |
+
| 0.9496 | 1.53 | 44000 | 0.9755 |
|
| 136 |
+
| 0.9435 | 1.56 | 45000 | 0.9759 |
|
| 137 |
+
| 0.9337 | 1.6 | 46000 | 0.9615 |
|
| 138 |
+
| 0.8844 | 1.63 | 47000 | 0.9524 |
|
| 139 |
+
| 0.9039 | 1.67 | 48000 | 0.9567 |
|
| 140 |
+
| 0.905 | 1.7 | 49000 | 0.9430 |
|
| 141 |
+
| 0.9491 | 1.74 | 50000 | 0.9205 |
|
| 142 |
+
| 0.8464 | 1.77 | 51000 | 0.9109 |
|
| 143 |
+
| 0.9384 | 1.81 | 52000 | 0.9056 |
|
| 144 |
+
| 0.8121 | 1.84 | 53000 | 0.8969 |
|
| 145 |
+
| 0.8381 | 1.88 | 54000 | 0.8869 |
|
| 146 |
+
| 0.8171 | 1.91 | 55000 | 0.8946 |
|
| 147 |
+
| 0.9024 | 1.94 | 56000 | 0.8993 |
|
| 148 |
+
| 0.84 | 1.98 | 57000 | 0.9011 |
|
| 149 |
+
| 0.6702 | 2.01 | 58000 | 0.8876 |
|
| 150 |
+
| 0.6278 | 2.05 | 59000 | 0.8716 |
|
| 151 |
+
| 0.6876 | 2.08 | 60000 | 0.8546 |
|
| 152 |
+
| 0.6754 | 2.12 | 61000 | 0.8639 |
|
| 153 |
+
| 0.6479 | 2.15 | 62000 | 0.8425 |
|
| 154 |
+
| 0.698 | 2.19 | 63000 | 0.8533 |
|
| 155 |
+
| 0.708 | 2.22 | 64000 | 0.8407 |
|
| 156 |
+
| 0.7021 | 2.26 | 65000 | 0.8160 |
|
| 157 |
+
| 0.5881 | 2.29 | 66000 | 0.8251 |
|
| 158 |
+
| 0.6181 | 2.33 | 67000 | 0.8205 |
|
| 159 |
+
| 0.6789 | 2.36 | 68000 | 0.8066 |
|
| 160 |
+
| 0.6452 | 2.4 | 69000 | 0.8037 |
|
| 161 |
+
| 0.6483 | 2.43 | 70000 | 0.7915 |
|
| 162 |
+
| 0.5868 | 2.47 | 71000 | 0.7864 |
|
| 163 |
+
| 0.6257 | 2.5 | 72000 | 0.7895 |
|
| 164 |
+
| 0.6593 | 2.53 | 73000 | 0.7718 |
|
| 165 |
+
| 0.5957 | 2.57 | 74000 | 0.7490 |
|
| 166 |
+
| 0.6351 | 2.6 | 75000 | 0.7481 |
|
| 167 |
+
| 0.699 | 2.64 | 76000 | 0.7628 |
|
| 168 |
+
| 0.566 | 2.67 | 77000 | 0.7590 |
|
| 169 |
+
| 0.5892 | 2.71 | 78000 | 0.7628 |
|
| 170 |
+
| 0.6052 | 2.74 | 79000 | 0.7633 |
|
| 171 |
+
| 0.6494 | 2.78 | 80000 | 0.7588 |
|
| 172 |
+
| 0.5917 | 2.81 | 81000 | 0.7118 |
|
| 173 |
+
| 0.508 | 2.85 | 82000 | 0.6857 |
|
| 174 |
+
| 0.523 | 2.88 | 83000 | 0.6738 |
|
| 175 |
+
| 0.4894 | 2.92 | 84000 | 0.6713 |
|
| 176 |
+
| 0.5096 | 2.95 | 85000 | 0.6625 |
|
| 177 |
+
| 0.352 | 2.99 | 86000 | 0.6802 |
|
| 178 |
+
| 0.3927 | 3.02 | 87000 | 0.6606 |
|
| 179 |
+
| 0.3468 | 3.06 | 88000 | 0.6546 |
|
| 180 |
+
| 0.3368 | 3.09 | 89000 | 0.6520 |
|
| 181 |
+
| 0.352 | 3.12 | 90000 | 0.6495 |
|
| 182 |
+
| 0.3613 | 3.16 | 91000 | 0.6324 |
|
| 183 |
+
| 0.3501 | 3.19 | 92000 | 0.6227 |
|
| 184 |
+
| 0.3269 | 3.23 | 93000 | 0.6091 |
|
| 185 |
+
| 0.3583 | 3.26 | 94000 | 0.6153 |
|
| 186 |
+
| 0.3278 | 3.3 | 95000 | 0.6178 |
|
| 187 |
+
| 0.3216 | 3.33 | 96000 | 0.6208 |
|
| 188 |
+
| 0.3383 | 3.37 | 97000 | 0.6195 |
|
| 189 |
+
| 0.3326 | 3.4 | 98000 | 0.6088 |
|
| 190 |
+
| 0.3081 | 3.44 | 99000 | 0.5956 |
|
| 191 |
+
| 0.3459 | 3.47 | 100000 | 0.5840 |
|
| 192 |
+
| 0.3139 | 3.51 | 101000 | 0.5712 |
|
| 193 |
+
| 0.3087 | 3.54 | 102000 | 0.5677 |
|
| 194 |
+
| 0.2798 | 3.58 | 103000 | 0.5566 |
|
| 195 |
+
| 0.3166 | 3.61 | 104000 | 0.5332 |
|
| 196 |
+
| 0.2981 | 3.65 | 105000 | 0.5333 |
|
| 197 |
+
| 0.3027 | 3.68 | 106000 | 0.5276 |
|
| 198 |
+
| 0.2815 | 3.72 | 107000 | 0.5024 |
|
| 199 |
+
| 0.2294 | 3.75 | 108000 | 0.5081 |
|
| 200 |
+
| 0.2452 | 3.78 | 109000 | 0.4824 |
|
| 201 |
+
| 0.2733 | 3.82 | 110000 | 0.4695 |
|
| 202 |
+
| 0.3001 | 3.85 | 111000 | 0.4627 |
|
| 203 |
+
| 0.2322 | 3.89 | 112000 | 0.4580 |
|
| 204 |
+
| 0.2362 | 3.92 | 113000 | 0.4402 |
|
| 205 |
+
| 0.2488 | 3.96 | 114000 | 0.4263 |
|
| 206 |
+
| 0.2449 | 3.99 | 115000 | 0.3999 |
|
| 207 |
+
| 0.1798 | 4.03 | 116000 | 0.4038 |
|
| 208 |
+
| 0.1956 | 4.06 | 117000 | 0.4037 |
|
| 209 |
+
| 0.1831 | 4.1 | 118000 | 0.4040 |
|
| 210 |
+
| 0.1802 | 4.13 | 119000 | 0.4039 |
|
| 211 |
+
| 0.1641 | 4.17 | 120000 | 0.4029 |
|
| 212 |
+
| 0.1769 | 4.2 | 121000 | 0.4016 |
|
| 213 |
+
| 0.1564 | 4.24 | 122000 | 0.4026 |
|
| 214 |
+
| 0.1552 | 4.27 | 123000 | 0.3988 |
|
| 215 |
+
| 0.1806 | 4.31 | 124000 | 0.3995 |
|
| 216 |
+
| 0.1783 | 4.34 | 125000 | 0.3995 |
|
| 217 |
+
| 0.1736 | 4.38 | 126000 | 0.3940 |
|
| 218 |
+
| 0.1657 | 4.41 | 127000 | 0.3913 |
|
| 219 |
+
| 0.1598 | 4.44 | 128000 | 0.3871 |
|
| 220 |
+
| 0.1599 | 4.48 | 129000 | 0.3831 |
|
| 221 |
+
| 0.1606 | 4.51 | 130000 | 0.3776 |
|
| 222 |
+
| 0.1639 | 4.55 | 131000 | 0.3754 |
|
| 223 |
+
| 0.1736 | 4.58 | 132000 | 0.3742 |
|
| 224 |
+
| 0.1653 | 4.62 | 133000 | 0.3703 |
|
| 225 |
+
| 0.1708 | 4.65 | 134000 | 0.3681 |
|
| 226 |
+
| 0.1729 | 4.69 | 135000 | 0.3674 |
|
| 227 |
+
| 0.1564 | 4.72 | 136000 | 0.3660 |
|
| 228 |
+
| 0.1734 | 4.76 | 137000 | 0.3641 |
|
| 229 |
+
| 0.163 | 4.79 | 138000 | 0.3632 |
|
| 230 |
+
| 0.1585 | 4.83 | 139000 | 0.3626 |
|
| 231 |
+
| 0.1603 | 4.86 | 140000 | 0.3619 |
|
| 232 |
+
| 0.1751 | 4.9 | 141000 | 0.3617 |
|
| 233 |
+
| 0.1622 | 4.93 | 142000 | 0.3617 |
|
| 234 |
+
| 0.161 | 4.97 | 143000 | 0.3617 |
|
| 235 |
+
| 0.1541 | 5.0 | 144000 | 0.3616 |
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
### Framework versions
|
| 239 |
+
|
| 240 |
+
- Transformers 4.34.1
|
| 241 |
+
- Pytorch 2.0.0+cu118
|
| 242 |
+
- Datasets 2.14.5
|
| 243 |
+
- Tokenizers 0.14.0
|