seq length
- app.py +25 -0
- generator.py +4 -0
app.py
CHANGED
@@ -112,6 +112,29 @@ def infer(
     audio_prompt_speaker_a,
     audio_prompt_speaker_b,
     gen_conversation_input,
+) -> tuple[np.ndarray, int]:
+    # Estimate token limit, otherwise failure might happen after many utterances have been generated.
+    if len(gen_conversation_input.strip() + text_prompt_speaker_a.strip() + text_prompt_speaker_b.strip()) >= 2000:
+        raise gr.Error("Prompts and conversation too long.", duration=30)
+
+    try:
+        return _infer(
+            text_prompt_speaker_a,
+            text_prompt_speaker_b,
+            audio_prompt_speaker_a,
+            audio_prompt_speaker_b,
+            gen_conversation_input,
+        )
+    except ValueError as e:
+        raise gr.Error(f"Error generating audio: {e}", duration=120)
+
+
+def _infer(
+    text_prompt_speaker_a,
+    text_prompt_speaker_b,
+    audio_prompt_speaker_a,
+    audio_prompt_speaker_b,
+    gen_conversation_input,
 ) -> tuple[np.ndarray, int]:
     audio_prompt_a = prepare_prompt(text_prompt_speaker_a, 0, audio_prompt_speaker_a)
     audio_prompt_b = prepare_prompt(text_prompt_speaker_b, 1, audio_prompt_speaker_b)
@@ -128,6 +151,7 @@ def infer(
             text=line,
             speaker=speaker_id,
             context=prompt_segments + generated_segments,
+            max_audio_length_ms=30_000,
         )
         generated_segments.append(Segment(text=line, speaker=speaker_id, audio=audio_tensor))

@@ -215,6 +239,7 @@ with gr.Blocks() as app:

         gen_conversation_input = gr.TextArea(label="conversation", lines=20, value=DEFAULT_CONVERSATION)
         generate_btn = gr.Button("Generate conversation", variant="primary")
+        gr.Markdown("GPU time limited to 3 minutes, for longer usage duplicate the space.")
         audio_output = gr.Audio(label="Synthesized audio")

         generate_btn.click(
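Note on the new pre-check: the 2000-character cap in infer() is a cheap heuristic for the hard token budget enforced in generator.py below, rejecting oversized inputs before any GPU time is spent. A minimal sketch of the idea, assuming roughly 4 characters per text token (an illustrative rule of thumb, not a measured property of this model's tokenizer; estimated_text_tokens is a hypothetical helper, not part of this Space):

# Sketch only: mirrors the length computation added to infer().
APPROX_CHARS_PER_TOKEN = 4  # assumption, not measured for this tokenizer

def estimated_text_tokens(conversation: str, prompt_a: str, prompt_b: str) -> int:
    # Same measurement as the new check: strip whitespace, concatenate, count.
    n_chars = len(conversation.strip() + prompt_a.strip() + prompt_b.strip())
    return n_chars // APPROX_CHARS_PER_TOKEN

# Under this assumption the 2000-character cap is on the order of 500 text
# tokens, leaving most of the context window for audio-prompt and generated frames.
print(estimated_text_tokens("A: hi\nB: hello", "prompt a", "prompt b"))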
generator.py
CHANGED
@@ -137,6 +137,10 @@ class Generator:
         curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
         curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

+        max_seq_len = 2048 - max_audio_frames
+        if curr_tokens.size(1) >= max_seq_len:
+            raise ValueError(f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}")
+
         for _ in range(max_audio_frames):
             sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
             if torch.all(sample == 0):
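Note on the guard: each decoded audio frame consumes one position in the backbone's context window, so reserving max_audio_frames up front converts a late overflow deep in the decode loop into an immediate ValueError, which app.py now surfaces as a gr.Error. A worked sketch of the budget arithmetic, assuming the 2048-position window from the diff and an illustrative 80 ms frame hop (FRAMES_PER_SECOND is an assumption, not a value taken from this repository):

# Sketch of the arithmetic behind max_seq_len = 2048 - max_audio_frames.
CONTEXT_WINDOW = 2048     # backbone window, per the diff
FRAMES_PER_SECOND = 12.5  # assumed: one frame per 80 ms hop

def prompt_budget(max_audio_length_ms: int) -> int:
    # Positions left for prompt tokens after reserving room for every frame
    # the decode loop might emit.
    max_audio_frames = int(max_audio_length_ms / 1000 * FRAMES_PER_SECOND)
    return CONTEXT_WINDOW - max_audio_frames

# With the 30_000 ms cap now passed in app.py: 375 frames reserved,
# so prompts must stay below 2048 - 375 = 1673 positions.
print(prompt_budget(30_000))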
|