zakaria-narjis commited on
Commit
af79ab5
·
1 Parent(s): 46e8495

Add apply button functionality and manage session state for image enhancement

Browse files
Files changed (3) hide show
  1. app.py +17 -4
  2. app_legacy.py +300 -0
  3. demo.py +6 -1
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  from PIL import Image
4
  import numpy as np
5
  from streamlit_image_comparison import image_comparison
6
- from src.envs.new_edit_photo import PhotoEditor
7
  from src.sac.sac_inference import InferenceAgent
8
  import yaml
9
  import os
@@ -15,12 +15,13 @@ import pandas as pd
15
  from bokeh.plotting import figure
16
  from bokeh.models import ColumnDataSource
17
  from bokeh.palettes import Spectral3
 
18
  # Set page config to wide mode
19
  st.set_page_config(layout="wide")
20
 
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  # DEVICE = torch.device("cpu")
23
- MODEL_PATH = "experiments/ResNet_10_sliders__224_128_aug__2024-07-23_21-23-35"
24
  SLIDERS = ['temp','tint','exposure', 'contrast','highlights','shadows', 'whites', 'blacks','vibrance','saturation']
25
  SLIDERS_ORD = ['contrast','exposure','temp','tint','whites','blacks','highlights','shadows','vibrance','saturation']
26
 
@@ -73,7 +74,11 @@ def enhance_image(image:np.array, params:dict):
73
  input_image = image.unsqueeze(0).to(DEVICE)
74
  parameters = [params[param_name]/100.0 for param_name in SLIDERS_ORD]
75
  parameters = torch.tensor(parameters).unsqueeze(0).to(DEVICE)
76
- enhanced_image = photo_editor(input_image,parameters)
 
 
 
 
77
  enhanced_image = enhanced_image.squeeze(0).cpu().detach().numpy()
78
  enhanced_image = np.clip(enhanced_image, 0, 1)
79
  enhanced_image = (enhanced_image*255).astype(np.uint8)
@@ -104,10 +109,13 @@ def auto_enhance(image,deterministic=True):
104
  return output_parameters
105
 
106
  def slider_callback():
 
 
107
  for name in SLIDERS:
108
  st.session_state.params[name] = st.session_state[f"slider_{name}"]
109
  image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
110
  st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params)
 
111
 
112
  def auto_random_enhance_callback():
113
  image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
@@ -134,6 +142,7 @@ def reset_sliders():
134
 
135
  def reset_on_upload():
136
  st.session_state.original_image = None
 
137
  reset_sliders()
138
 
139
  def create_smooth_histogram(image):
@@ -202,8 +211,12 @@ if 'enhanced_image' not in st.session_state:
202
  st.session_state.enhanced_image = None
203
  if 'original_image' not in st.session_state:
204
  st.session_state.original_image = None
 
 
205
  if 'params' not in st.session_state:
206
  st.session_state.params = {name: 0 for name in SLIDERS}
 
 
207
  for name in SLIDERS:
208
  if f"slider_{name}" not in st.session_state:
209
  st.session_state[f"slider_{name}"] = 0
@@ -263,7 +276,7 @@ if uploaded_file is not None:
263
  key=f"slider_{name}",
264
  on_change=slider_callback
265
  )
266
-
267
  # Create a single column to maximize width
268
  left_spacer, content_column, right_spacer = st.columns([1, 3, 1])
269
  with content_column:
 
3
  from PIL import Image
4
  import numpy as np
5
  from streamlit_image_comparison import image_comparison
6
+ # from src.envs.new_edit_photo import PhotoEditor
7
  from src.sac.sac_inference import InferenceAgent
8
  import yaml
9
  import os
 
15
  from bokeh.plotting import figure
16
  from bokeh.models import ColumnDataSource
17
  from bokeh.palettes import Spectral3
18
+ from src.envs.edit_photo_opt import PhotoEditor
19
  # Set page config to wide mode
20
  st.set_page_config(layout="wide")
21
 
22
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  # DEVICE = torch.device("cpu")
24
+ MODEL_PATH = os.path.join("experiments",'ResNet_10_sliders__224_128_aug__2024-07-23_21-23-35')
25
  SLIDERS = ['temp','tint','exposure', 'contrast','highlights','shadows', 'whites', 'blacks','vibrance','saturation']
26
  SLIDERS_ORD = ['contrast','exposure','temp','tint','whites','blacks','highlights','shadows','vibrance','saturation']
27
 
 
74
  input_image = image.unsqueeze(0).to(DEVICE)
75
  parameters = [params[param_name]/100.0 for param_name in SLIDERS_ORD]
76
  parameters = torch.tensor(parameters).unsqueeze(0).to(DEVICE)
77
+ if st.session_state.photopro_image is None:
78
+ enhanced_image,photopro_image = photo_editor(input_image,parameters,use_photopro_image=False)
79
+ st.session_state.photopro_image = photopro_image
80
+ else:
81
+ enhanced_image = photo_editor(st.session_state.photopro_image,parameters,use_photopro_image=True)
82
  enhanced_image = enhanced_image.squeeze(0).cpu().detach().numpy()
83
  enhanced_image = np.clip(enhanced_image, 0, 1)
84
  enhanced_image = (enhanced_image*255).astype(np.uint8)
 
109
  return output_parameters
110
 
111
  def slider_callback():
112
+ st.session_state.apply_button_enabled = True
113
+ def apply_button_callback():
114
  for name in SLIDERS:
115
  st.session_state.params[name] = st.session_state[f"slider_{name}"]
116
  image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
117
  st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params)
118
+ st.session_state.apply_button_enabled = False
119
 
120
  def auto_random_enhance_callback():
121
  image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
 
142
 
143
  def reset_on_upload():
144
  st.session_state.original_image = None
145
+ st.session_state.photopro_image = None
146
  reset_sliders()
147
 
148
  def create_smooth_histogram(image):
 
211
  st.session_state.enhanced_image = None
212
  if 'original_image' not in st.session_state:
213
  st.session_state.original_image = None
214
+ if 'photopro_image' not in st.session_state:
215
+ st.session_state.photopro_image = None
216
  if 'params' not in st.session_state:
217
  st.session_state.params = {name: 0 for name in SLIDERS}
218
+ if "apply_button_enabled" not in st.session_state:
219
+ st.session_state.apply_button_enabled = False
220
  for name in SLIDERS:
221
  if f"slider_{name}" not in st.session_state:
222
  st.session_state[f"slider_{name}"] = 0
 
276
  key=f"slider_{name}",
277
  on_change=slider_callback
278
  )
279
+ st.sidebar.button("Apply manual edit", on_click=apply_button_callback, key="apply_button",use_container_width=True,disabled=not st.session_state.apply_button_enabled)
280
  # Create a single column to maximize width
281
  left_spacer, content_column, right_spacer = st.columns([1, 3, 1])
282
  with content_column:
app_legacy.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ import numpy as np
5
+ from streamlit_image_comparison import image_comparison
6
+ from src.envs.new_edit_photo import PhotoEditor
7
+ from src.sac.sac_inference import InferenceAgent
8
+ import yaml
9
+ import os
10
+ from src.envs.photo_env import PhotoEnhancementEnvTest
11
+ from tensordict import TensorDict
12
+ import torchvision.transforms.v2.functional as F
13
+ from streamlit import cache_resource
14
+ import pandas as pd
15
+ from bokeh.plotting import figure
16
+ from bokeh.models import ColumnDataSource
17
+ from bokeh.palettes import Spectral3
18
+ # Set page config to wide mode
19
+ st.set_page_config(layout="wide")
20
+
21
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ # DEVICE = torch.device("cpu")
23
+ MODEL_PATH = "experiments/ResNet_10_sliders__224_128_aug__2024-07-23_21-23-35"
24
+ SLIDERS = ['temp','tint','exposure', 'contrast','highlights','shadows', 'whites', 'blacks','vibrance','saturation']
25
+ SLIDERS_ORD = ['contrast','exposure','temp','tint','whites','blacks','highlights','shadows','vibrance','saturation']
26
+
27
+ class Config(object):
28
+ def __init__(self,dictionary):
29
+ self.__dict__.update(dictionary)
30
+
31
+ @cache_resource
32
+ def load_preprocessor_agent(preprocessor_agent_path,device):
33
+ with open(os.path.join(preprocessor_agent_path,"configs/sac_config.yaml")) as f:
34
+ sac_config_dict = yaml.load(f, Loader=yaml.FullLoader)
35
+ with open(os.path.join(preprocessor_agent_path,"configs/env_config.yaml")) as f:
36
+ env_config_dict = yaml.load(f, Loader=yaml.FullLoader)
37
+ with open(os.path.join("src/configs/inference_config.yaml")) as f:
38
+ inf_config_dict = yaml.load(f, Loader=yaml.FullLoader)
39
+
40
+ inference_config = Config(inf_config_dict)
41
+ sac_config = Config(sac_config_dict)
42
+ env_config = Config(env_config_dict)
43
+
44
+ inference_env = PhotoEnhancementEnvTest(
45
+ batch_size=env_config.train_batch_size,
46
+ imsize=env_config.imsize,
47
+ training_mode=None,
48
+ done_threshold=env_config.threshold_psnr,
49
+ edit_sliders=env_config.sliders_to_use,
50
+ features_size=env_config.features_size,
51
+ discretize=env_config.discretize,
52
+ discretize_step=env_config.discretize_step,
53
+ use_txt_features=env_config.use_txt_features if hasattr(env_config,'use_txt_features') else False,
54
+ augment_data=False,
55
+ pre_encoding_device=device,
56
+ pre_load_images=False,
57
+ logger=None
58
+ )
59
+
60
+ inference_config.device = device
61
+ preprocessor_agent = InferenceAgent(inference_env, inference_config)
62
+ preprocessor_agent.device = device
63
+ preprocessor_agent.load_backbone(os.path.join(preprocessor_agent_path,'models','backbone.pth'))
64
+ preprocessor_agent.load_actor_weights(os.path.join(preprocessor_agent_path,'models','actor_head.pth'))
65
+ preprocessor_agent.load_critics_weights(os.path.join(preprocessor_agent_path,'models','qf1_head.pth'),
66
+ os.path.join(preprocessor_agent_path,'models','qf2_head.pth'))
67
+ return preprocessor_agent
68
+
69
+ enhancer_agent = load_preprocessor_agent(MODEL_PATH,DEVICE)
70
+ photo_editor = PhotoEditor(SLIDERS)
71
+
72
+ def enhance_image(image:np.array, params:dict):
73
+ input_image = image.unsqueeze(0).to(DEVICE)
74
+ parameters = [params[param_name]/100.0 for param_name in SLIDERS_ORD]
75
+ parameters = torch.tensor(parameters).unsqueeze(0).to(DEVICE)
76
+ enhanced_image = photo_editor(input_image,parameters)
77
+ enhanced_image = enhanced_image.squeeze(0).cpu().detach().numpy()
78
+ enhanced_image = np.clip(enhanced_image, 0, 1)
79
+ enhanced_image = (enhanced_image*255).astype(np.uint8)
80
+ return enhanced_image
81
+
82
+ def auto_enhance(image,deterministic=True):
83
+ input_image = image.unsqueeze(0).to(DEVICE)
84
+ input_image = input_image.permute(0,3,1,2)
85
+ IMAGE_SIZE = enhancer_agent.env.imsize
86
+ input_image = F.resize(input_image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=F.InterpolationMode.BICUBIC)
87
+ batch_observation = TensorDict(
88
+ {
89
+ "batch_images":input_image,
90
+ },
91
+ batch_size = [input_image.shape[0]],
92
+ )
93
+ parameters = enhancer_agent.act(batch_observation,deterministic=deterministic,n_samples=0)
94
+ parameters = parameters.squeeze(0)*100.0
95
+ parameters = torch.round(parameters)
96
+ output_parameters = []
97
+ index = 0
98
+ for slider in SLIDERS_ORD:
99
+ if slider in enhancer_agent.env.edit_sliders:
100
+ output_parameters.append(parameters[index].item())
101
+ index += 1
102
+ else:
103
+ output_parameters.append(0)
104
+ return output_parameters
105
+
106
+ def slider_callback():
107
+ for name in SLIDERS:
108
+ st.session_state.params[name] = st.session_state[f"slider_{name}"]
109
+ image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
110
+ st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params)
111
+
112
+ def auto_random_enhance_callback():
113
+ image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
114
+ auto_params = auto_enhance(image_tensor,deterministic=False)
115
+ for i, name in enumerate(SLIDERS_ORD):
116
+ st.session_state[f"slider_{name}"] = int(auto_params[i])
117
+ st.session_state.params[name] = int(auto_params[i])
118
+ st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params)
119
+
120
+ def auto_enhance_callback():
121
+ image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
122
+ auto_params = auto_enhance(image_tensor)
123
+ for i, name in enumerate(SLIDERS_ORD):
124
+ st.session_state[f"slider_{name}"] = int(auto_params[i])
125
+ st.session_state.params[name] = int(auto_params[i])
126
+ st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params)
127
+
128
+ def reset_sliders():
129
+ for name in SLIDERS:
130
+ st.session_state[f"slider_{name}"] = 0
131
+ st.session_state.params[name] = 0
132
+ # st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params)
133
+ st.session_state.enhanced_image = st.session_state.original_image
134
+
135
+ def reset_on_upload():
136
+ st.session_state.original_image = None
137
+ reset_sliders()
138
+
139
+ def create_smooth_histogram(image):
140
+ # Compute histograms for each channel
141
+ bins = np.linspace(0, 255, 256)
142
+ hist_r, _ = np.histogram(image[..., 0], bins=bins)
143
+ hist_g, _ = np.histogram(image[..., 1], bins=bins)
144
+ hist_b, _ = np.histogram(image[..., 2], bins=bins)
145
+
146
+ # Normalize the histograms
147
+ def normalize_histogram(hist):
148
+ hist_central = hist[1:-1]
149
+ hist_max = np.max(hist_central)
150
+ hist_min = np.min(hist_central)
151
+
152
+ hist_normalized = (hist_central - hist_min) / (hist_max - hist_min)
153
+
154
+ hist[0] = min(hist[0] / hist_max, 1)
155
+ hist[-1] = min(hist[-1] / hist_max, 1)
156
+
157
+ return np.concatenate(([hist[0]], hist_normalized, [hist[-1]]))
158
+
159
+ hist_r_norm = normalize_histogram(hist_r)
160
+ hist_g_norm = normalize_histogram(hist_g)
161
+ hist_b_norm = normalize_histogram(hist_b)
162
+
163
+ # Create Bokeh figure with transparent background
164
+ p = figure(width=300, height=150, toolbar_location=None,
165
+ x_range=(0, 255), y_range=(0, 1.1),
166
+ background_fill_color=None,
167
+ border_fill_color=None,
168
+ outline_line_color=None)
169
+
170
+ # Remove all axes, labels, and grids
171
+ p.axis.visible = False
172
+ p.xgrid.grid_line_color = None
173
+ p.ygrid.grid_line_color = None
174
+
175
+ # Create ColumnDataSource for each channel
176
+ source_r = ColumnDataSource(data=dict(left=bins[:-1], right=bins[1:], top=hist_r_norm))
177
+ source_g = ColumnDataSource(data=dict(left=bins[:-1], right=bins[1:], top=hist_g_norm))
178
+ source_b = ColumnDataSource(data=dict(left=bins[:-1], right=bins[1:], top=hist_b_norm))
179
+
180
+ # Plot the histograms
181
+ p.quad(bottom=0, top='top', left='left', right='right', source=source_r,
182
+ fill_color="red", fill_alpha=0.9, line_color=None)
183
+ p.quad(bottom=0, top='top', left='left', right='right', source=source_g,
184
+ fill_color="green", fill_alpha=0.9, line_color=None)
185
+ p.quad(bottom=0, top='top', left='left', right='right', source=source_b,
186
+ fill_color="blue", fill_alpha=0.9, line_color=None)
187
+
188
+ # Remove padding
189
+ p.min_border_left = 0
190
+ p.min_border_right = 0
191
+ p.min_border_top = 0
192
+ p.min_border_bottom = 0
193
+
194
+ return p
195
+
196
+ # In your Streamlit app
197
+ def plot_histogram_streamlit(image):
198
+ histogram = create_smooth_histogram(image)
199
+ st.sidebar.bokeh_chart(histogram, use_container_width=True)
200
+ # Initialize session state
201
+ if 'enhanced_image' not in st.session_state:
202
+ st.session_state.enhanced_image = None
203
+ if 'original_image' not in st.session_state:
204
+ st.session_state.original_image = None
205
+ if 'params' not in st.session_state:
206
+ st.session_state.params = {name: 0 for name in SLIDERS}
207
+ for name in SLIDERS:
208
+ if f"slider_{name}" not in st.session_state:
209
+ st.session_state[f"slider_{name}"] = 0
210
+
211
+ # Set up the Streamlit app
212
+ st.title("Photo Enhancement App")
213
+
214
+ # File uploader in the main area
215
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png",".tif"], on_change=reset_on_upload)
216
+
217
+ if uploaded_file is not None:
218
+ # Load the original image
219
+ st.session_state.original_image = np.array(Image.open(uploaded_file).convert('RGB'),dtype=np.uint16)
220
+
221
+ # Enhance the image initially
222
+ if st.session_state.enhanced_image is None:
223
+ st.session_state.enhanced_image = st.session_state.original_image
224
+
225
+ # Sidebar for controls
226
+ st.sidebar.title("Controls")
227
+
228
+ # Display histogram
229
+ st.sidebar.subheader("Colors Histogram")
230
+ plot_histogram_streamlit(st.session_state.enhanced_image)
231
+
232
+ # Select box to choose which image to display
233
+ display_option = st.sidebar.selectbox(
234
+ "Select view mode",
235
+ ("Comparison", "Enhanced")
236
+ )
237
+
238
+ # Create two columns for the buttons
239
+ col1, col2,col3 = st.sidebar.columns(3)
240
+
241
+ # Button for auto-enhancement
242
+ with col1:
243
+ st.button("Auto Enhance", on_click=auto_enhance_callback, key="auto_enhance_button",use_container_width=True)
244
+
245
+ with col2:
246
+ st.button("Auto Random Enhance", on_click=auto_random_enhance_callback, key="auto_random_enhance_button",use_container_width=True)
247
+ # Button for resetting sliders
248
+ with col3:
249
+ st.button("Reset", on_click=reset_sliders, key="reset_button",use_container_width=True)
250
+
251
+ st.sidebar.subheader("Adjustments")
252
+ slider_names = SLIDERS
253
+
254
+ for name in slider_names:
255
+ if f"slider_{name}" not in st.session_state:
256
+ st.session_state[f"slider_{name}"] = 0
257
+
258
+ st.sidebar.slider(
259
+ name.capitalize(),
260
+ min_value=-100,
261
+ max_value=100,
262
+ value=st.session_state[f"slider_{name}"],
263
+ key=f"slider_{name}",
264
+ on_change=slider_callback
265
+ )
266
+
267
+ # Create a single column to maximize width
268
+ left_spacer, content_column, right_spacer = st.columns([1, 3, 1])
269
+ with content_column:
270
+ if display_option == "Enhanced":
271
+ if st.session_state.enhanced_image is not None:
272
+ st.image(st.session_state.enhanced_image.astype(np.uint8), caption="Enhanced Image", use_column_width=True)
273
+ else:
274
+ st.warning("Enhanced image is not available. Try adjusting the sliders or clicking 'Auto Enhance'.")
275
+ else: # Comparison view
276
+ if st.session_state.enhanced_image is not None:
277
+ image_comparison(
278
+ img1=Image.fromarray(st.session_state.original_image.astype(np.uint8)),
279
+ img2=Image.fromarray(st.session_state.enhanced_image.astype(np.uint8)),
280
+ label1="Original",
281
+ label2="Enhanced",
282
+ width=850, # You might want to adjust this value
283
+ starting_position=50,
284
+ show_labels=True,
285
+ make_responsive=True,
286
+ )
287
+ else:
288
+ st.warning("Enhanced image is not available for comparison. Try adjusting the sliders or clicking 'Auto Enhance'.")
289
+
290
+ # Add custom CSS to make the image comparison component responsive
291
+ st.markdown("""
292
+ <style>
293
+ .stImageComparison {
294
+ width: 100% !important;
295
+ }
296
+ .stImageComparison > figure > div {
297
+ width: 100% !important;
298
+ }
299
+ </style>
300
+ """, unsafe_allow_html=True)
demo.py CHANGED
@@ -109,10 +109,13 @@ def auto_enhance(image,deterministic=True):
109
  return output_parameters
110
 
111
  def slider_callback():
 
 
112
  for name in SLIDERS:
113
  st.session_state.params[name] = st.session_state[f"slider_{name}"]
114
  image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
115
  st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params)
 
116
 
117
  def auto_random_enhance_callback():
118
  image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
@@ -212,6 +215,8 @@ if 'photopro_image' not in st.session_state:
212
  st.session_state.photopro_image = None
213
  if 'params' not in st.session_state:
214
  st.session_state.params = {name: 0 for name in SLIDERS}
 
 
215
  for name in SLIDERS:
216
  if f"slider_{name}" not in st.session_state:
217
  st.session_state[f"slider_{name}"] = 0
@@ -271,7 +276,7 @@ if uploaded_file is not None:
271
  key=f"slider_{name}",
272
  on_change=slider_callback
273
  )
274
-
275
  # Create a single column to maximize width
276
  left_spacer, content_column, right_spacer = st.columns([1, 3, 1])
277
  with content_column:
 
109
  return output_parameters
110
 
111
  def slider_callback():
112
+ st.session_state.apply_button_enabled = True
113
+ def apply_button_callback():
114
  for name in SLIDERS:
115
  st.session_state.params[name] = st.session_state[f"slider_{name}"]
116
  image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
117
  st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params)
118
+ st.session_state.apply_button_enabled = False
119
 
120
  def auto_random_enhance_callback():
121
  image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
 
215
  st.session_state.photopro_image = None
216
  if 'params' not in st.session_state:
217
  st.session_state.params = {name: 0 for name in SLIDERS}
218
+ if "apply_button_enabled" not in st.session_state:
219
+ st.session_state.apply_button_enabled = False
220
  for name in SLIDERS:
221
  if f"slider_{name}" not in st.session_state:
222
  st.session_state[f"slider_{name}"] = 0
 
276
  key=f"slider_{name}",
277
  on_change=slider_callback
278
  )
279
+ st.sidebar.button("Apply manual edit", on_click=apply_button_callback, key="apply_button",use_container_width=True,disabled=not st.session_state.apply_button_enabled)
280
  # Create a single column to maximize width
281
  left_spacer, content_column, right_spacer = st.columns([1, 3, 1])
282
  with content_column: