remove_weights_from_python_wheel

#9
by jdye64 - opened
MANIFEST.in CHANGED
@@ -1,3 +1,4 @@
1
  include README.md
2
  include THIRD_PARTY_NOTICES.md
3
- recursive-include nemotron_graphic_elements_v1
 
 
1
  include README.md
2
  include THIRD_PARTY_NOTICES.md
3
+ recursive-include nemotron_graphic_elements_v1 *.py *.json *.png
4
+ recursive-exclude nemotron_graphic_elements_v1 *.pth
nemotron_graphic_elements_v1/__init__.py CHANGED
@@ -6,6 +6,8 @@ Nemotron Graphic Elements v1
6
 
7
  A specialized object detection system designed to identify and extract key elements
8
  from charts and graphs. Based on YOLOX architecture.
 
 
9
  """
10
 
11
  __version__ = "1.0.0"
@@ -19,6 +21,7 @@ from .utils import (
19
  COLORS,
20
  )
21
  from .graphic_element_v1 import Exp
 
22
 
23
  __all__ = [
24
  "define_model",
@@ -28,5 +31,7 @@ __all__ = [
28
  "reformat_for_plotting",
29
  "reorder_boxes",
30
  "COLORS",
 
 
31
  ]
32
 
 
6
 
7
  A specialized object detection system designed to identify and extract key elements
8
  from charts and graphs. Based on YOLOX architecture.
9
+
10
+ Model weights are automatically downloaded from Hugging Face Hub on first use.
11
  """
12
 
13
  __version__ = "1.0.0"
 
21
  COLORS,
22
  )
23
  from .graphic_element_v1 import Exp
24
+ from .weights import get_weights_path, clear_cache
25
 
26
  __all__ = [
27
  "define_model",
 
31
  "reformat_for_plotting",
32
  "reorder_boxes",
33
  "COLORS",
34
+ "get_weights_path",
35
+ "clear_cache",
36
  ]
37
 
nemotron_graphic_elements_v1/graphic_element_v1.py CHANGED
@@ -4,7 +4,9 @@
4
  import os
5
  import torch
6
  import torch.nn as nn
7
- from typing import List, Tuple
 
 
8
 
9
 
10
  class Exp:
@@ -16,12 +18,28 @@ class Exp:
16
  parameters, and class-specific thresholds.
17
  """
18
 
19
- def __init__(self) -> None:
20
- """Initialize the configuration with default parameters."""
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  self.name: str = "graphic-element-v1"
22
- # Use package directory for weights path
23
- package_dir = os.path.dirname(os.path.abspath(__file__))
24
- self.ckpt: str = os.path.join(package_dir, "weights.pth")
 
 
 
25
  self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
26
 
27
  # YOLOX architecture parameters
 
4
  import os
5
  import torch
6
  import torch.nn as nn
7
+ from typing import List, Tuple, Optional
8
+
9
+ from .weights import get_weights_path
10
 
11
 
12
  class Exp:
 
18
  parameters, and class-specific thresholds.
19
  """
20
 
21
+ def __init__(
22
+ self,
23
+ weights_cache_dir: Optional[str] = None,
24
+ force_download: bool = False,
25
+ hf_token: Optional[str] = None,
26
+ ) -> None:
27
+ """
28
+ Initialize the configuration with default parameters.
29
+
30
+ Args:
31
+ weights_cache_dir: Directory to cache downloaded weights.
32
+ Defaults to ~/.cache/nemotron_graphic_elements_v1
33
+ force_download: If True, re-download weights even if cached.
34
+ hf_token: Hugging Face token for accessing gated models (if needed).
35
+ """
36
  self.name: str = "graphic-element-v1"
37
+ # Get weights path (downloads from HuggingFace if needed)
38
+ self.ckpt: str = get_weights_path(
39
+ cache_dir=weights_cache_dir,
40
+ force_download=force_download,
41
+ token=hf_token,
42
+ )
43
  self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
44
 
45
  # YOLOX architecture parameters
nemotron_graphic_elements_v1/model.py CHANGED
@@ -13,13 +13,23 @@ from typing import Dict, List, Tuple, Union
13
  from .yolox.boxes import postprocess
14
 
15
 
16
- def define_model(config_name: str = "graphic_element_v1", verbose: bool = True) -> nn.Module:
 
 
 
 
 
 
17
  """
18
  Defines and initializes the model based on the configuration.
19
 
20
  Args:
21
  config_name (str): Configuration name. Defaults to "graphic_element_v1".
22
  verbose (bool): Whether to print verbose output. Defaults to True.
 
 
 
 
23
 
24
  Returns:
25
  torch.nn.Module: The initialized YOLOX model.
@@ -27,7 +37,11 @@ def define_model(config_name: str = "graphic_element_v1", verbose: bool = True)
27
  # Import the config class
28
  from .graphic_element_v1 import Exp
29
 
30
- config = Exp()
 
 
 
 
31
  model = config.get_model()
32
 
33
  # Load weights
 
13
  from .yolox.boxes import postprocess
14
 
15
 
16
+ def define_model(
17
+ config_name: str = "graphic_element_v1",
18
+ verbose: bool = True,
19
+ weights_cache_dir: str = None,
20
+ force_download: bool = False,
21
+ hf_token: str = None,
22
+ ) -> nn.Module:
23
  """
24
  Defines and initializes the model based on the configuration.
25
 
26
  Args:
27
  config_name (str): Configuration name. Defaults to "graphic_element_v1".
28
  verbose (bool): Whether to print verbose output. Defaults to True.
29
+ weights_cache_dir (str): Directory to cache downloaded weights.
30
+ Defaults to ~/.cache/nemotron_graphic_elements_v1
31
+ force_download (bool): If True, re-download weights even if cached.
32
+ hf_token (str): Hugging Face token for accessing gated models (if needed).
33
 
34
  Returns:
35
  torch.nn.Module: The initialized YOLOX model.
 
37
  # Import the config class
38
  from .graphic_element_v1 import Exp
39
 
40
+ config = Exp(
41
+ weights_cache_dir=weights_cache_dir,
42
+ force_download=force_download,
43
+ hf_token=hf_token,
44
+ )
45
  model = config.get_model()
46
 
47
  # Load weights
nemotron_graphic_elements_v1/weights.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ Weights management for Nemotron Graphic Elements v1.
6
+
7
+ This module handles downloading model weights from Hugging Face Hub
8
+ when they are not bundled with the package.
9
+ """
10
+
11
+ import os
12
+ from pathlib import Path
13
+ from typing import Optional
14
+
15
+ from huggingface_hub import hf_hub_download
16
+
17
+
18
+ # Hugging Face repository information
19
+ HF_REPO_ID = "nvidia/nemotron-graphic-elements-v1"
20
+ WEIGHTS_FILENAME = "nemotron_graphic_elements_v1/weights.pth"
21
+
22
+ # Default cache directory for weights
23
+ DEFAULT_CACHE_DIR = Path.home() / ".cache" / "nemotron_graphic_elements_v1"
24
+
25
+
26
def get_weights_path(
    cache_dir: Optional[str] = None,
    force_download: bool = False,
    token: Optional[str] = None,
) -> str:
    """
    Get the path to the model weights, downloading if necessary.

    Unless ``force_download`` is set, existing weights are reused in this
    order:

    1. ``weights.pth`` next to this module (development or manual install).
    2. ``weights.pth`` inside the cache directory.

    Otherwise the checkpoint is fetched from Hugging Face Hub and placed at
    ``<cache_dir>/weights.pth``.

    Args:
        cache_dir: Directory to cache downloaded weights. A leading ``~`` is
            expanded to the user's home directory. Defaults to
            ~/.cache/nemotron_graphic_elements_v1
        force_download: If True, skip both lookups above and re-download.
        token: Hugging Face token for accessing gated models (if needed).

    Returns:
        str: Path to the weights file.

    Raises:
        RuntimeError: If the weights cannot be downloaded or moved into place.
    """
    import shutil

    # First, check if weights exist in the package directory (dev mode)
    local_weights = Path(__file__).parent / "weights.pth"
    if local_weights.exists() and not force_download:
        return str(local_weights)

    # Resolve the cache directory. expanduser() keeps a caller-supplied
    # "~/..." from silently creating a literal "~" directory under the CWD.
    if cache_dir is None:
        cache_root = DEFAULT_CACHE_DIR
    else:
        cache_root = Path(cache_dir).expanduser()

    cache_root.mkdir(parents=True, exist_ok=True)
    cached_weights = cache_root / "weights.pth"

    # Check if weights are already cached
    if cached_weights.exists() and not force_download:
        return str(cached_weights)

    # Download from Hugging Face Hub
    print(f" -> Downloading weights from Hugging Face Hub ({HF_REPO_ID})...")

    try:
        # NOTE(review): `local_dir_use_symlinks` is deprecated in recent
        # huggingface_hub releases -- confirm it is still accepted by every
        # version this package supports.
        downloaded_path = Path(
            hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=WEIGHTS_FILENAME,
                cache_dir=str(cache_root),
                force_download=force_download,
                token=token,
                local_dir=str(cache_root),
                local_dir_use_symlinks=False,
            )
        )

        # WEIGHTS_FILENAME includes a sub-directory, so the download lands
        # below the cache root rather than at `cached_weights`. Move (not
        # copy) it so the large checkpoint is not stored twice.
        if downloaded_path != cached_weights:
            shutil.move(str(downloaded_path), str(cached_weights))

        print(f" -> Weights downloaded to {cached_weights}")
        return str(cached_weights)

    except Exception as e:
        # Boundary handler: surface any download/filesystem failure as one
        # actionable error while keeping the original cause chained.
        raise RuntimeError(
            f"Failed to download weights from Hugging Face Hub.\n"
            f"Repository: {HF_REPO_ID}\n"
            f"Error: {e}\n\n"
            f"Please ensure you have internet access and the huggingface_hub "
            f"package is installed. You can also manually download the weights "
            f"from https://huggingface.co/{HF_REPO_ID} and place them at:\n"
            f" {cached_weights}"
        ) from e
104
+
105
+
106
def clear_cache(cache_dir: Optional[str] = None) -> None:
    """
    Clear the cached weights.

    Removes ``weights.pth`` from the cache directory if it is present and
    prints which of the two outcomes occurred.

    Args:
        cache_dir: Directory where weights are cached. A leading ``~`` is
            expanded to the user's home directory. Defaults to
            ~/.cache/nemotron_graphic_elements_v1
    """
    # expanduser() so that "~/..." refers to the home directory instead of a
    # literal "~" directory relative to the current working directory.
    if cache_dir is None:
        cache_root = DEFAULT_CACHE_DIR
    else:
        cache_root = Path(cache_dir).expanduser()

    cached_weights = cache_root / "weights.pth"

    # Attempt the removal directly rather than exists()-then-unlink(), which
    # can race with another process deleting the file in between.
    try:
        cached_weights.unlink()
    except (FileNotFoundError, NotADirectoryError):
        print(f" -> No cached weights found at {cached_weights}")
    else:
        print(f" -> Removed cached weights from {cached_weights}")
126
+
pyproject.toml CHANGED
@@ -32,6 +32,7 @@ dependencies = [
32
  "matplotlib>=3.5.0",
33
  "pandas>=1.3.0",
34
  "Pillow>=9.0.0",
 
35
  ]
36
 
37
  [project.optional-dependencies]
@@ -50,5 +51,5 @@ Repository = "https://huggingface.co/nvidia/nemotron-graphic-elements-v1"
50
  packages = ["nemotron_graphic_elements_v1", "nemotron_graphic_elements_v1.yolox", "nemotron_graphic_elements_v1.post_processing"]
51
 
52
  [tool.setuptools.package-data]
53
- "nemotron_graphic_elements_v1" = ["*.pth", "*.json", "*.png"]
54
 
 
32
  "matplotlib>=3.5.0",
33
  "pandas>=1.3.0",
34
  "Pillow>=9.0.0",
35
+ "huggingface_hub>=0.20.0",
36
  ]
37
 
38
  [project.optional-dependencies]
 
51
  packages = ["nemotron_graphic_elements_v1", "nemotron_graphic_elements_v1.yolox", "nemotron_graphic_elements_v1.post_processing"]
52
 
53
  [tool.setuptools.package-data]
54
+ "nemotron_graphic_elements_v1" = ["*.json", "*.png"]
55