Spaces:
Running
on
Zero
Running
on
Zero
Sorting modalities in generate_output() backend.py for consistent generations
Browse files- src/backend.py +20 -29
src/backend.py
CHANGED
|
@@ -242,25 +242,27 @@ def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_infere
|
|
| 242 |
gr.Warning("You need to remove some of the inputs that you would like to generate. If all modalities are known, there is nothing to generate.")
|
| 243 |
return s2l1c_input, s2l2a_input, s1rtc_input, dem_input
|
| 244 |
|
| 245 |
-
|
| 246 |
-
|
| 247 |
if s2l1c_active:
|
| 248 |
-
|
| 249 |
-
condition_modalities.append('s2_l1c')
|
| 250 |
if s2l2a_active:
|
| 251 |
-
|
| 252 |
-
condition_modalities.append('s2_l2a')
|
| 253 |
if s1rtc_active:
|
| 254 |
-
|
| 255 |
-
condition_modalities.append('s1_rtc')
|
| 256 |
if dem_active:
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
imgs_out = custom_inference(
|
| 261 |
-
images=
|
| 262 |
generate_modalities=[el for el in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem'] if el not in condition_modalities],
|
| 263 |
-
condition_modalities=
|
| 264 |
num_inference_steps=num_inference_steps_slider,
|
| 265 |
seed=seed
|
| 266 |
)
|
|
@@ -268,22 +270,11 @@ def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_infere
|
|
| 268 |
output = []
|
| 269 |
|
| 270 |
# Collect outputs
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
output.append(s2l2a_input)
|
| 277 |
-
else:
|
| 278 |
-
output.append(to_PIL(imgs_out['s2_l2a'][0]))
|
| 279 |
-
if s1rtc_active:
|
| 280 |
-
output.append(s1rtc_input)
|
| 281 |
-
else:
|
| 282 |
-
output.append(to_PIL(imgs_out['s1_rtc'][0]))
|
| 283 |
-
if dem_active:
|
| 284 |
-
output.append(dem_input)
|
| 285 |
-
else:
|
| 286 |
-
output.append(to_PIL(imgs_out['dem'][0]))
|
| 287 |
|
| 288 |
return output
|
| 289 |
|
|
|
|
| 242 |
gr.Warning("You need to remove some of the inputs that you would like to generate. If all modalities are known, there is nothing to generate.")
|
| 243 |
return s2l1c_input, s2l2a_input, s1rtc_input, dem_input
|
| 244 |
|
| 245 |
+
# Instead of collecting in UI order, create ordered dictionaries
|
| 246 |
+
input_images = {}
|
| 247 |
if s2l1c_active:
|
| 248 |
+
input_images['s2_l1c'] = s2l1c_input
|
|
|
|
| 249 |
if s2l2a_active:
|
| 250 |
+
input_images['s2_l2a'] = s2l2a_input
|
|
|
|
| 251 |
if s1rtc_active:
|
| 252 |
+
input_images['s1_rtc'] = s1rtc_input
|
|
|
|
| 253 |
if dem_active:
|
| 254 |
+
input_images['dem'] = dem_input
|
| 255 |
+
|
| 256 |
+
condition_modalities = list(input_images.keys())
|
| 257 |
+
|
| 258 |
+
# Sort modalities and collect images in the same order
|
| 259 |
+
sorted_modalities = sorted(condition_modalities, key=lambda x: ['dem', 's1_rtc', 's2_l1c', 's2_l2a'].index(x))
|
| 260 |
+
sorted_images = [input_images[mod] for mod in sorted_modalities]
|
| 261 |
+
|
| 262 |
imgs_out = custom_inference(
|
| 263 |
+
images=sorted_images,
|
| 264 |
generate_modalities=[el for el in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem'] if el not in condition_modalities],
|
| 265 |
+
condition_modalities=sorted_modalities,
|
| 266 |
num_inference_steps=num_inference_steps_slider,
|
| 267 |
seed=seed
|
| 268 |
)
|
|
|
|
| 270 |
output = []
|
| 271 |
|
| 272 |
# Collect outputs
|
| 273 |
+
for modality in sorted_modalities:
|
| 274 |
+
if modality in input_images:
|
| 275 |
+
output.append(input_images[modality])
|
| 276 |
+
else:
|
| 277 |
+
output.append(to_PIL(imgs_out[modality][0]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
return output
|
| 280 |
|