Spaces:
Runtime error
Runtime error
Update cog_sdxl_dataset_and_utils.py
Browse files
cog_sdxl_dataset_and_utils.py
CHANGED
|
@@ -33,6 +33,11 @@ def prepare_mask(mask: PIL.Image.Image, width: int = 512, height: int = 512) ->
|
|
| 33 |
return torch.from_numpy(np.expand_dims(arr, 0)).unsqueeze(0)
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
class PreprocessedDataset(Dataset):
|
| 37 |
def __init__(
|
| 38 |
self,
|
|
@@ -175,3 +180,4 @@ def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
|
|
| 175 |
unet.to(device, dtype=weight_dtype)
|
| 176 |
|
| 177 |
return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
|
|
|
|
|
|
| 33 |
return torch.from_numpy(np.expand_dims(arr, 0)).unsqueeze(0)
|
| 34 |
|
| 35 |
|
| 36 |
+
class TokenEmbeddingsHandler:
|
| 37 |
+
def __init__(self, text_encoders, tokenizers):
|
| 38 |
+
self.text_encoders = text_encoders
|
| 39 |
+
self.tokenizers = tokenizers
|
| 40 |
+
|
| 41 |
class PreprocessedDataset(Dataset):
|
| 42 |
def __init__(
|
| 43 |
self,
|
|
|
|
| 180 |
unet.to(device, dtype=weight_dtype)
|
| 181 |
|
| 182 |
return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
|
| 183 |
+
|