Upload folder using huggingface_hub
Browse files- .idea/misc.xml +3 -0
- .idea/workspace.xml +26 -1
- app.py +1 -1
- requirements.txt +0 -3
- src/builder.py +4 -0
- src/datasets/template_map_fn.py +36 -0
- src/datasets/utils.py +30 -2
- src/models/puffin/model.py +7 -11
.idea/misc.xml
CHANGED
|
@@ -1,4 +1,7 @@
|
|
| 1 |
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
<project version="4">
|
|
|
|
|
|
|
|
|
|
| 3 |
<component name="ProjectRootManager" version="2" project-jdk-name="$USER_HOME$/envs/pt2.7" project-jdk-type="Python SDK" />
|
| 4 |
</project>
|
|
|
|
| 1 |
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
<project version="4">
|
| 3 |
+
<component name="Black">
|
| 4 |
+
<option name="sdkName" value="$USER_HOME$/envs/pt2.7" />
|
| 5 |
+
</component>
|
| 6 |
<component name="ProjectRootManager" version="2" project-jdk-name="$USER_HOME$/envs/pt2.7" project-jdk-type="Python SDK" />
|
| 7 |
</project>
|
.idea/workspace.xml
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
<project version="4">
|
|
|
|
|
|
|
|
|
|
| 3 |
<component name="ChangeListManager">
|
| 4 |
<list default="true" id="9dd87dac-8a5e-4178-a1d7-afa664ac2f6a" name="Changes" comment="" />
|
| 5 |
<option name="SHOW_DIALOG" value="false" />
|
|
@@ -7,6 +10,13 @@
|
|
| 7 |
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
| 8 |
<option name="LAST_RESOLUTION" value="IGNORE" />
|
| 9 |
</component>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
<component name="ProjectColorInfo"><![CDATA[{
|
| 11 |
"associatedIndex": 6
|
| 12 |
}]]></component>
|
|
@@ -20,6 +30,10 @@
|
|
| 20 |
"ModuleVcsDetector.initialDetectionPerformed": "true",
|
| 21 |
"RunOnceActivity.ShowReadmeOnStart": "true",
|
| 22 |
"last_opened_file_path": "/Users/wusize/projects/Puffin",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
"nodejs_package_manager_path": "npm",
|
| 24 |
"settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
|
| 25 |
"vue.rearranger.settings.migration": "true"
|
|
@@ -40,11 +54,22 @@
|
|
| 40 |
<option name="number" value="Default" />
|
| 41 |
<option name="presentableId" value="Default" />
|
| 42 |
<updated>1760056680813</updated>
|
| 43 |
-
<workItem from="1760056681869" duration="
|
| 44 |
</task>
|
| 45 |
<servers />
|
| 46 |
</component>
|
| 47 |
<component name="TypeScriptGeneratedFilesManager">
|
| 48 |
<option name="version" value="3" />
|
| 49 |
</component>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
</project>
|
|
|
|
| 1 |
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
<project version="4">
|
| 3 |
+
<component name="AutoImportSettings">
|
| 4 |
+
<option name="autoReloadType" value="SELECTIVE" />
|
| 5 |
+
</component>
|
| 6 |
<component name="ChangeListManager">
|
| 7 |
<list default="true" id="9dd87dac-8a5e-4178-a1d7-afa664ac2f6a" name="Changes" comment="" />
|
| 8 |
<option name="SHOW_DIALOG" value="false" />
|
|
|
|
| 10 |
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
| 11 |
<option name="LAST_RESOLUTION" value="IGNORE" />
|
| 12 |
</component>
|
| 13 |
+
<component name="FileTemplateManagerImpl">
|
| 14 |
+
<option name="RECENT_TEMPLATES">
|
| 15 |
+
<list>
|
| 16 |
+
<option value="Python Script" />
|
| 17 |
+
</list>
|
| 18 |
+
</option>
|
| 19 |
+
</component>
|
| 20 |
<component name="ProjectColorInfo"><![CDATA[{
|
| 21 |
"associatedIndex": 6
|
| 22 |
}]]></component>
|
|
|
|
| 30 |
"ModuleVcsDetector.initialDetectionPerformed": "true",
|
| 31 |
"RunOnceActivity.ShowReadmeOnStart": "true",
|
| 32 |
"last_opened_file_path": "/Users/wusize/projects/Puffin",
|
| 33 |
+
"node.js.detected.package.eslint": "true",
|
| 34 |
+
"node.js.detected.package.tslint": "true",
|
| 35 |
+
"node.js.selected.package.eslint": "(autodetect)",
|
| 36 |
+
"node.js.selected.package.tslint": "(autodetect)",
|
| 37 |
"nodejs_package_manager_path": "npm",
|
| 38 |
"settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
|
| 39 |
"vue.rearranger.settings.migration": "true"
|
|
|
|
| 54 |
<option name="number" value="Default" />
|
| 55 |
<option name="presentableId" value="Default" />
|
| 56 |
<updated>1760056680813</updated>
|
| 57 |
+
<workItem from="1760056681869" duration="1047000" />
|
| 58 |
</task>
|
| 59 |
<servers />
|
| 60 |
</component>
|
| 61 |
<component name="TypeScriptGeneratedFilesManager">
|
| 62 |
<option name="version" value="3" />
|
| 63 |
</component>
|
| 64 |
+
<component name="XDebuggerManager">
|
| 65 |
+
<breakpoint-manager>
|
| 66 |
+
<breakpoints>
|
| 67 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 68 |
+
<url>file://$USER_HOME$/envs/pt2.7/lib/python3.10/site-packages/xtuner/dataset/map_fns/template_map_fn.py</url>
|
| 69 |
+
<line>1</line>
|
| 70 |
+
<option name="timeStamp" value="1" />
|
| 71 |
+
</line-breakpoint>
|
| 72 |
+
</breakpoints>
|
| 73 |
+
</breakpoint-manager>
|
| 74 |
+
</component>
|
| 75 |
</project>
|
app.py
CHANGED
|
@@ -8,7 +8,7 @@ import math
|
|
| 8 |
import re
|
| 9 |
from einops import rearrange
|
| 10 |
from mmengine.config import Config
|
| 11 |
-
from
|
| 12 |
|
| 13 |
import matplotlib
|
| 14 |
matplotlib.use("Agg")
|
|
|
|
| 8 |
import re
|
| 9 |
from einops import rearrange
|
| 10 |
from mmengine.config import Config
|
| 11 |
+
from src.builder import BUILDER
|
| 12 |
|
| 13 |
import matplotlib
|
| 14 |
matplotlib.use("Agg")
|
requirements.txt
CHANGED
|
@@ -12,6 +12,3 @@ pillow==11.2.1
|
|
| 12 |
scipy==1.15.2
|
| 13 |
timm==0.9.12
|
| 14 |
transformers==4.49.0
|
| 15 |
-
xtuner==0.1.23
|
| 16 |
-
deepspeed
|
| 17 |
-
|
|
|
|
| 12 |
scipy==1.15.2
|
| 13 |
timm==0.9.12
|
| 14 |
transformers==4.49.0
|
|
|
|
|
|
|
|
|
src/builder.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mmengine.registry import Registry
|
| 2 |
+
__all__ = ['BUILDER']
|
| 3 |
+
|
| 4 |
+
BUILDER = Registry('builder')
|
src/datasets/template_map_fn.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
from mmengine.utils.misc import get_object_from_string
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def template_map_fn(example, template):
|
| 8 |
+
conversation = example.get('conversation', [])
|
| 9 |
+
for i, single_turn_conversation in enumerate(conversation):
|
| 10 |
+
input = single_turn_conversation.get('input', '')
|
| 11 |
+
if input is None:
|
| 12 |
+
input = ''
|
| 13 |
+
input_text = template.INSTRUCTION.format(input=input, round=i + 1)
|
| 14 |
+
system = single_turn_conversation.get('system', '')
|
| 15 |
+
if system != '' and system is not None:
|
| 16 |
+
system = template.SYSTEM.format(system=system)
|
| 17 |
+
input_text = system + input_text
|
| 18 |
+
single_turn_conversation['input'] = input_text
|
| 19 |
+
|
| 20 |
+
if template.get('SUFFIX', None):
|
| 21 |
+
output_text = single_turn_conversation.get('output', '')
|
| 22 |
+
output_text += template.SUFFIX
|
| 23 |
+
single_turn_conversation['output'] = output_text
|
| 24 |
+
|
| 25 |
+
# SUFFIX_AS_EOS is False ==> need_eos_token is True
|
| 26 |
+
single_turn_conversation['need_eos_token'] = \
|
| 27 |
+
not template.get('SUFFIX_AS_EOS', False)
|
| 28 |
+
single_turn_conversation['sep'] = template.get('SEP', '')
|
| 29 |
+
|
| 30 |
+
return {'conversation': conversation}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def template_map_fn_factory(template):
|
| 34 |
+
if isinstance(template, str): # for resume
|
| 35 |
+
template = get_object_from_string(template)
|
| 36 |
+
return partial(template_map_fn, template=template)
|
src/datasets/utils.py
CHANGED
|
@@ -1,14 +1,42 @@
|
|
| 1 |
import copy
|
| 2 |
import random
|
| 3 |
-
from xtuner.dataset.utils import get_bos_eos_token_ids
|
| 4 |
-
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
|
| 5 |
import json
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
|
| 8 |
OUTPUT_IMAGE_TOKEN_INDEX = -300
|
| 9 |
QUERY_TOKEN_INDEX = -400
|
| 10 |
QUERY_TOKEN = '<query>'
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
def crop2square(pil_img):
|
| 13 |
width, height = pil_img.width, pil_img.height
|
| 14 |
|
|
|
|
| 1 |
import copy
|
| 2 |
import random
|
|
|
|
|
|
|
| 3 |
import json
|
| 4 |
|
| 5 |
+
|
| 6 |
+
def get_bos_eos_token_ids(tokenizer):
|
| 7 |
+
if tokenizer.__class__.__name__ in [
|
| 8 |
+
'QWenTokenizer', 'QWen2Tokenizer', 'Qwen2TokenizerFast'
|
| 9 |
+
]:
|
| 10 |
+
bos_token_id = []
|
| 11 |
+
eos_token_id = tokenizer.eos_token_id
|
| 12 |
+
assert eos_token_id is not None, \
|
| 13 |
+
'Please set eos_token for Qwen tokenizer!'
|
| 14 |
+
elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
|
| 15 |
+
bos_token_id = [64790, 64792]
|
| 16 |
+
eos_token_id = tokenizer.eos_token_id
|
| 17 |
+
else:
|
| 18 |
+
bos_token_id = tokenizer.bos_token_id
|
| 19 |
+
eos_token_id = tokenizer.eos_token_id
|
| 20 |
+
if isinstance(bos_token_id, int):
|
| 21 |
+
bos_token_id = [bos_token_id]
|
| 22 |
+
if isinstance(eos_token_id, int):
|
| 23 |
+
eos_token_id = [eos_token_id]
|
| 24 |
+
return bos_token_id, eos_token_id
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
IGNORE_INDEX = -100
|
| 28 |
+
DEFAULT_PAD_TOKEN_INDEX = 0
|
| 29 |
+
IMAGE_TOKEN_INDEX = -200
|
| 30 |
+
DEFAULT_IMAGE_TOKEN = '<image>'
|
| 31 |
+
|
| 32 |
INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
|
| 33 |
OUTPUT_IMAGE_TOKEN_INDEX = -300
|
| 34 |
QUERY_TOKEN_INDEX = -400
|
| 35 |
QUERY_TOKEN = '<query>'
|
| 36 |
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
def crop2square(pil_img):
|
| 41 |
width, height = pil_img.width, pil_img.height
|
| 42 |
|
src/models/puffin/model.py
CHANGED
|
@@ -11,16 +11,17 @@ from torch.autograd.function import Function
|
|
| 11 |
from torch.nn.utils.rnn import pad_sequence
|
| 12 |
from mmengine.logging import print_log
|
| 13 |
from mmengine.model import BaseModel
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
from
|
| 17 |
-
from
|
| 18 |
from transformers.cache_utils import DynamicCache
|
| 19 |
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
|
| 20 |
|
| 21 |
from src.models.connector import ConnectorConfig, ConnectorEncoder
|
| 22 |
from src.models.stable_diffusion3.pipeline_stable_diffusion_3_dynamic import StableDiffusion3Pipeline
|
| 23 |
-
from src.datasets.utils import encode_fn, QUERY_TOKEN_INDEX,
|
|
|
|
| 24 |
|
| 25 |
class _ScaleGradient(Function):
|
| 26 |
@staticmethod
|
|
@@ -74,7 +75,7 @@ class Qwen2p5RadioStableDiffusion3HFDynamic(BaseModel):
|
|
| 74 |
fold_size=2,
|
| 75 |
unconditional=0.1,
|
| 76 |
unconditional_cross_view=0.1,
|
| 77 |
-
pretrained_pth=None,
|
| 78 |
use_activation_checkpointing=False,
|
| 79 |
*args, **kwargs):
|
| 80 |
super().__init__()
|
|
@@ -136,11 +137,6 @@ class Qwen2p5RadioStableDiffusion3HFDynamic(BaseModel):
|
|
| 136 |
if use_activation_checkpointing:
|
| 137 |
self.llm.enable_input_require_grads()
|
| 138 |
self.gradient_checkpointing_enable()
|
| 139 |
-
|
| 140 |
-
if pretrained_pth is not None:
|
| 141 |
-
pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
|
| 142 |
-
info = self.load_state_dict(pretrained_state_dict, strict=False)
|
| 143 |
-
print_log(f'Load pretrained weight from {pretrained_pth}')
|
| 144 |
|
| 145 |
@property
|
| 146 |
def device(self):
|
|
|
|
| 11 |
from torch.nn.utils.rnn import pad_sequence
|
| 12 |
from mmengine.logging import print_log
|
| 13 |
from mmengine.model import BaseModel
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from src.builder import BUILDER
|
| 17 |
+
from src.datasets.template_map_fn import template_map_fn
|
| 18 |
from transformers.cache_utils import DynamicCache
|
| 19 |
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
|
| 20 |
|
| 21 |
from src.models.connector import ConnectorConfig, ConnectorEncoder
|
| 22 |
from src.models.stable_diffusion3.pipeline_stable_diffusion_3_dynamic import StableDiffusion3Pipeline
|
| 23 |
+
from src.datasets.utils import (encode_fn, QUERY_TOKEN_INDEX, IGNORE_INDEX,
|
| 24 |
+
DEFAULT_IMAGE_TOKEN, INPUT_IMAGE_TOKEN_INDEX)
|
| 25 |
|
| 26 |
class _ScaleGradient(Function):
|
| 27 |
@staticmethod
|
|
|
|
| 75 |
fold_size=2,
|
| 76 |
unconditional=0.1,
|
| 77 |
unconditional_cross_view=0.1,
|
| 78 |
+
# pretrained_pth=None,
|
| 79 |
use_activation_checkpointing=False,
|
| 80 |
*args, **kwargs):
|
| 81 |
super().__init__()
|
|
|
|
| 137 |
if use_activation_checkpointing:
|
| 138 |
self.llm.enable_input_require_grads()
|
| 139 |
self.gradient_checkpointing_enable()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
@property
|
| 142 |
def device(self):
|