wusize committed on
Commit
9bcd027
·
verified ·
1 Parent(s): 1a2a9f7

Upload folder using huggingface_hub

Browse files
.idea/misc.xml CHANGED
@@ -1,4 +1,7 @@
1
  <?xml version="1.0" encoding="UTF-8"?>
2
  <project version="4">
 
 
 
3
  <component name="ProjectRootManager" version="2" project-jdk-name="$USER_HOME$/envs/pt2.7" project-jdk-type="Python SDK" />
4
  </project>
 
1
  <?xml version="1.0" encoding="UTF-8"?>
2
  <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="$USER_HOME$/envs/pt2.7" />
5
+ </component>
6
  <component name="ProjectRootManager" version="2" project-jdk-name="$USER_HOME$/envs/pt2.7" project-jdk-type="Python SDK" />
7
  </project>
.idea/workspace.xml CHANGED
@@ -1,5 +1,8 @@
1
  <?xml version="1.0" encoding="UTF-8"?>
2
  <project version="4">
 
 
 
3
  <component name="ChangeListManager">
4
  <list default="true" id="9dd87dac-8a5e-4178-a1d7-afa664ac2f6a" name="Changes" comment="" />
5
  <option name="SHOW_DIALOG" value="false" />
@@ -7,6 +10,13 @@
7
  <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
8
  <option name="LAST_RESOLUTION" value="IGNORE" />
9
  </component>
 
 
 
 
 
 
 
10
  <component name="ProjectColorInfo"><![CDATA[{
11
  "associatedIndex": 6
12
  }]]></component>
@@ -20,6 +30,10 @@
20
  "ModuleVcsDetector.initialDetectionPerformed": "true",
21
  "RunOnceActivity.ShowReadmeOnStart": "true",
22
  "last_opened_file_path": "/Users/wusize/projects/Puffin",
 
 
 
 
23
  "nodejs_package_manager_path": "npm",
24
  "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
25
  "vue.rearranger.settings.migration": "true"
@@ -40,11 +54,22 @@
40
  <option name="number" value="Default" />
41
  <option name="presentableId" value="Default" />
42
  <updated>1760056680813</updated>
43
- <workItem from="1760056681869" duration="11000" />
44
  </task>
45
  <servers />
46
  </component>
47
  <component name="TypeScriptGeneratedFilesManager">
48
  <option name="version" value="3" />
49
  </component>
 
 
 
 
 
 
 
 
 
 
 
50
  </project>
 
1
  <?xml version="1.0" encoding="UTF-8"?>
2
  <project version="4">
3
+ <component name="AutoImportSettings">
4
+ <option name="autoReloadType" value="SELECTIVE" />
5
+ </component>
6
  <component name="ChangeListManager">
7
  <list default="true" id="9dd87dac-8a5e-4178-a1d7-afa664ac2f6a" name="Changes" comment="" />
8
  <option name="SHOW_DIALOG" value="false" />
 
10
  <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
11
  <option name="LAST_RESOLUTION" value="IGNORE" />
12
  </component>
13
+ <component name="FileTemplateManagerImpl">
14
+ <option name="RECENT_TEMPLATES">
15
+ <list>
16
+ <option value="Python Script" />
17
+ </list>
18
+ </option>
19
+ </component>
20
  <component name="ProjectColorInfo"><![CDATA[{
21
  "associatedIndex": 6
22
  }]]></component>
 
30
  "ModuleVcsDetector.initialDetectionPerformed": "true",
31
  "RunOnceActivity.ShowReadmeOnStart": "true",
32
  "last_opened_file_path": "/Users/wusize/projects/Puffin",
33
+ "node.js.detected.package.eslint": "true",
34
+ "node.js.detected.package.tslint": "true",
35
+ "node.js.selected.package.eslint": "(autodetect)",
36
+ "node.js.selected.package.tslint": "(autodetect)",
37
  "nodejs_package_manager_path": "npm",
38
  "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
39
  "vue.rearranger.settings.migration": "true"
 
54
  <option name="number" value="Default" />
55
  <option name="presentableId" value="Default" />
56
  <updated>1760056680813</updated>
57
+ <workItem from="1760056681869" duration="1047000" />
58
  </task>
59
  <servers />
60
  </component>
61
  <component name="TypeScriptGeneratedFilesManager">
62
  <option name="version" value="3" />
63
  </component>
64
+ <component name="XDebuggerManager">
65
+ <breakpoint-manager>
66
+ <breakpoints>
67
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
68
+ <url>file://$USER_HOME$/envs/pt2.7/lib/python3.10/site-packages/xtuner/dataset/map_fns/template_map_fn.py</url>
69
+ <line>1</line>
70
+ <option name="timeStamp" value="1" />
71
+ </line-breakpoint>
72
+ </breakpoints>
73
+ </breakpoint-manager>
74
+ </component>
75
  </project>
app.py CHANGED
@@ -8,7 +8,7 @@ import math
8
  import re
9
  from einops import rearrange
10
  from mmengine.config import Config
11
- from xtuner.registry import BUILDER
12
 
13
  import matplotlib
14
  matplotlib.use("Agg")
 
8
  import re
9
  from einops import rearrange
10
  from mmengine.config import Config
11
+ from src.builder import BUILDER
12
 
13
  import matplotlib
14
  matplotlib.use("Agg")
requirements.txt CHANGED
@@ -12,6 +12,3 @@ pillow==11.2.1
12
  scipy==1.15.2
13
  timm==0.9.12
14
  transformers==4.49.0
15
- xtuner==0.1.23
16
- deepspeed
17
-
 
12
  scipy==1.15.2
13
  timm==0.9.12
14
  transformers==4.49.0
 
 
 
src/builder.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from mmengine.registry import Registry
2
+ __all__ = ['BUILDER']
3
+
4
+ BUILDER = Registry('builder')
src/datasets/template_map_fn.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from functools import partial
3
+
4
+ from mmengine.utils.misc import get_object_from_string
5
+
6
+
7
def template_map_fn(example, template):
    """Apply a chat *template* to every turn of ``example['conversation']``.

    Each turn is mutated in place: the user input is wrapped with
    ``template.INSTRUCTION`` (prefixed by ``template.SYSTEM`` on turns that
    carry a non-empty system message), ``template.SUFFIX`` is appended to the
    output when present, and the ``need_eos_token``/``sep`` bookkeeping
    fields are recorded for downstream tokenization.
    """
    turns = example.get('conversation', [])
    for turn_idx, turn in enumerate(turns):
        raw_input = turn.get('input', '')
        if raw_input is None:
            raw_input = ''
        text = template.INSTRUCTION.format(input=raw_input, round=turn_idx + 1)

        system_msg = turn.get('system', '')
        if system_msg is not None and system_msg != '':
            text = template.SYSTEM.format(system=system_msg) + text
        turn['input'] = text

        suffix = template.get('SUFFIX', None)
        if suffix:
            turn['output'] = turn.get('output', '') + suffix

        # When the suffix itself acts as EOS, no extra EOS token is needed.
        turn['need_eos_token'] = not template.get('SUFFIX_AS_EOS', False)
        turn['sep'] = template.get('SEP', '')

    return {'conversation': turns}
31
+
32
+
33
def template_map_fn_factory(template):
    """Bind *template* into ``template_map_fn`` and return the callable.

    A string template (as serialized when resuming training) is first
    resolved back to the actual object via ``get_object_from_string``.
    """
    resolved = get_object_from_string(template) if isinstance(template, str) else template
    return partial(template_map_fn, template=resolved)
src/datasets/utils.py CHANGED
@@ -1,14 +1,42 @@
1
  import copy
2
  import random
3
- from xtuner.dataset.utils import get_bos_eos_token_ids
4
- from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
5
  import json
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
8
  OUTPUT_IMAGE_TOKEN_INDEX = -300
9
  QUERY_TOKEN_INDEX = -400
10
  QUERY_TOKEN = '<query>'
11
 
 
 
 
12
  def crop2square(pil_img):
13
  width, height = pil_img.width, pil_img.height
14
 
 
1
  import copy
2
  import random
 
 
3
  import json
4
 
5
+
6
+ def get_bos_eos_token_ids(tokenizer):
7
+ if tokenizer.__class__.__name__ in [
8
+ 'QWenTokenizer', 'QWen2Tokenizer', 'Qwen2TokenizerFast'
9
+ ]:
10
+ bos_token_id = []
11
+ eos_token_id = tokenizer.eos_token_id
12
+ assert eos_token_id is not None, \
13
+ 'Please set eos_token for Qwen tokenizer!'
14
+ elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
15
+ bos_token_id = [64790, 64792]
16
+ eos_token_id = tokenizer.eos_token_id
17
+ else:
18
+ bos_token_id = tokenizer.bos_token_id
19
+ eos_token_id = tokenizer.eos_token_id
20
+ if isinstance(bos_token_id, int):
21
+ bos_token_id = [bos_token_id]
22
+ if isinstance(eos_token_id, int):
23
+ eos_token_id = [eos_token_id]
24
+ return bos_token_id, eos_token_id
25
+
26
+
27
+ IGNORE_INDEX = -100
28
+ DEFAULT_PAD_TOKEN_INDEX = 0
29
+ IMAGE_TOKEN_INDEX = -200
30
+ DEFAULT_IMAGE_TOKEN = '<image>'
31
+
32
  INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
33
  OUTPUT_IMAGE_TOKEN_INDEX = -300
34
  QUERY_TOKEN_INDEX = -400
35
  QUERY_TOKEN = '<query>'
36
 
37
+
38
+
39
+
40
  def crop2square(pil_img):
41
  width, height = pil_img.width, pil_img.height
42
 
src/models/puffin/model.py CHANGED
@@ -11,16 +11,17 @@ from torch.autograd.function import Function
11
  from torch.nn.utils.rnn import pad_sequence
12
  from mmengine.logging import print_log
13
  from mmengine.model import BaseModel
14
- from xtuner.utils import IGNORE_INDEX
15
- from xtuner.registry import BUILDER
16
- from xtuner.model.utils import guess_load_checkpoint
17
- from xtuner.dataset.map_fns.template_map_fn import template_map_fn
18
  from transformers.cache_utils import DynamicCache
19
  from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
20
 
21
  from src.models.connector import ConnectorConfig, ConnectorEncoder
22
  from src.models.stable_diffusion3.pipeline_stable_diffusion_3_dynamic import StableDiffusion3Pipeline
23
- from src.datasets.utils import encode_fn, QUERY_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, INPUT_IMAGE_TOKEN_INDEX
 
24
 
25
  class _ScaleGradient(Function):
26
  @staticmethod
@@ -74,7 +75,7 @@ class Qwen2p5RadioStableDiffusion3HFDynamic(BaseModel):
74
  fold_size=2,
75
  unconditional=0.1,
76
  unconditional_cross_view=0.1,
77
- pretrained_pth=None,
78
  use_activation_checkpointing=False,
79
  *args, **kwargs):
80
  super().__init__()
@@ -136,11 +137,6 @@ class Qwen2p5RadioStableDiffusion3HFDynamic(BaseModel):
136
  if use_activation_checkpointing:
137
  self.llm.enable_input_require_grads()
138
  self.gradient_checkpointing_enable()
139
-
140
- if pretrained_pth is not None:
141
- pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
142
- info = self.load_state_dict(pretrained_state_dict, strict=False)
143
- print_log(f'Load pretrained weight from {pretrained_pth}')
144
 
145
  @property
146
  def device(self):
 
11
  from torch.nn.utils.rnn import pad_sequence
12
  from mmengine.logging import print_log
13
  from mmengine.model import BaseModel
14
+
15
+
16
+ from src.builder import BUILDER
17
+ from src.datasets.template_map_fn import template_map_fn
18
  from transformers.cache_utils import DynamicCache
19
  from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
20
 
21
  from src.models.connector import ConnectorConfig, ConnectorEncoder
22
  from src.models.stable_diffusion3.pipeline_stable_diffusion_3_dynamic import StableDiffusion3Pipeline
23
+ from src.datasets.utils import (encode_fn, QUERY_TOKEN_INDEX, IGNORE_INDEX,
24
+ DEFAULT_IMAGE_TOKEN, INPUT_IMAGE_TOKEN_INDEX)
25
 
26
  class _ScaleGradient(Function):
27
  @staticmethod
 
75
  fold_size=2,
76
  unconditional=0.1,
77
  unconditional_cross_view=0.1,
78
+ # pretrained_pth=None,
79
  use_activation_checkpointing=False,
80
  *args, **kwargs):
81
  super().__init__()
 
137
  if use_activation_checkpointing:
138
  self.llm.enable_input_require_grads()
139
  self.gradient_checkpointing_enable()
 
 
 
 
 
140
 
141
  @property
142
  def device(self):