A natural sounding, conversational text-to-speech model
In this blog post, I describe my work with Ge Zhu and Professor Zhiyao Duan to develop an initial version of a text-to-speech (TTS) model we call Parakeet. The research presented in this post was completed early last fall and was supported by Google’s TPU Research Cloud (TRC) program. This project would not have been possible without their immense generosity.
Parakeet takes as input a text prompt, optionally containing multiple speakers or non-verbal events like “laughter,” and outputs up to 30 seconds of corresponding audio. In designing Parakeet, we had two main goals in mind for our model:
A brief overview of our methodology is as follows:
We plan to release our fine-tuned whisper models and possibly the generative model (and/or future improved versions). The generative model would have to be released under a non-commercial license due to our datasets.
This project is a work in progress, but below are samples from our model, followed by a more detailed methodology. These samples are cherry-picked (generally the best of between 4 and 16 samples). We hope to eventually improve our model so that its outputs are more consistently high-quality.
Parakeet is able to generate conversational audio (right column) given a text prompt (left column). As seen in text prompts below, [S1], [S2], etc., are used to denote different speakers. Non-verbal events such as “laughs” can be specified in parentheses.
Text | Audio |
---|---|
Text prompt: [S1] What's sort of cool, is that, uh, you can produce coughs if you have to. [S2] What do you mean? [S1] Well (coughs), there, I just coughed. | |
Text prompt: [S1] Something really funny happened to me this morning. [S2] Oh wow, what? [S1] Well, uh, I woke up as usual- [S2] Mm-hmm. [S1] … went downstairs to have uh breakfast- [S2] Yeah. [S1] … started eating. Then, uh, 10 minutes later I realized it was the middle of the night. [S2] Oh no way, (laughs) that's so funny. | |
Text prompt (different seeds): [S1] Okay, now uh, I'm going to demo this text to speech system. [S2] Mm-hmm. [S1] The speech you're hearing right now, uh, isn't real. It's- it's fake. [S2] No, that's hard to believe- [S1] I'm serious. | |
Our model is able to perform zero-shot voice cloning simply by prefilling the decoder with an audio prompt and setting the text condition to be the concatenation of text from the audio prompt and the desired text.
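As a rough illustration, a cloning call can be sketched as follows (a minimal sketch; `encode_text`, `dac_encode`, and `model.generate` are hypothetical stand-ins, not the actual Parakeet interface):

```python
def clone_voice(prompt_audio, prompt_text, target_text, model):
    # Text condition: the prompt's transcript concatenated with the text to be spoken.
    text = encode_text(prompt_text + " " + target_text)
    # Prefill the decoder with the codec tokens of the audio prompt...
    prefill_codes = dac_encode(prompt_audio)
    # ...then continue autoregressive generation; the output contains the prompt
    # followed by the newly generated continuation.
    return model.generate(text_condition=text, decoder_prefill=prefill_codes)
```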
In the table below, for each sample, the audio column contains the audio prompt and the full audio continuation (including the prompt).
Text | Audio (prompt / continuation) |
---|---|
2-second cloning example: [S1] ... Gapes, that is the first symptom. This part is generated. Uh, everything after two- two seconds was generated. (prompt is underlined) | |
Whispering example (prompt text not included): April fools. You don't need to run. I was just kidding. | |
Emotional example (prompt text not included): I'm just so angry right now! | |
Multilingual example: [S1] (speaking in foreign language) Let’s go to the beach now. | |
Our dataset comprises three sources: the Spotify Podcast Dataset, LibriVox, and Common Voice.
The Spotify Podcast Dataset consists of over 100,000 podcast episodes totaling approximately 60,000 hours of audio (the dataset is no longer maintained by Spotify as of December 2023). Though it provides transcriptions, they are machine-generated (and not of great quality) and unsuitable for training a generative model. We first split each podcast episode into segments of up to 30 seconds in length using pyannote’s voice activity detection (VAD).
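A minimal sketch of this segmentation step, assuming pyannote.audio’s pretrained voice-activity-detection pipeline and a simple greedy grouping of speech regions (not necessarily the exact procedure we used):

```python
from pyannote.audio import Pipeline

# Pretrained VAD pipeline (requires a Hugging Face access token).
vad = Pipeline.from_pretrained("pyannote/voice-activity-detection",
                               use_auth_token="YOUR_HF_TOKEN")

def split_into_segments(wav_path: str, max_len: float = 30.0):
    """Greedily group contiguous speech regions into chunks of at most max_len seconds."""
    speech_regions = vad(wav_path).get_timeline().support()
    segments, start, end = [], None, None
    for region in speech_regions:
        if start is None:
            start, end = region.start, region.end
        elif region.end - start <= max_len:
            end = region.end
        else:
            segments.append((start, end))
            start, end = region.start, region.end
    if start is not None:
        segments.append((start, end))
    return segments  # list of (start_time, end_time) pairs in seconds
```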
Ideally, our transcriptions would contain both speaker labels and non-verbal events (e.g. “[S1] Hey! [S2] (sighs) Um, how’s it going?”). We initially tried prompting Whisper to produce transcriptions in this format, but this proved insufficient, so we instead fine-tuned Whisper for the task; we call the resulting model WhisperD.
The initial version of WhisperD was created through a two-stage fine-tuning process, where we first fine-tune on lower-quality automatic transcriptions and then fine-tune again on higher-quality human data. Our reasoning was that human-annotated transcriptions are expensive to obtain, and that a larger dataset of automatic transcriptions could be inexpensively used to roughly calibrate Whisper to perform speaker annotation before a higher-quality fine-tuning pass. We used online transcription services for both the automatic and human transcriptions to transcribe random subsets of our podcast data. The automatic portion consisted of ~20 hours, while the human portion consisted of slightly under 2 hours.
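Conceptually, the two stages look like the following (purely illustrative; `finetune` and the dataset variables are hypothetical placeholders, not our actual training code):

```python
# Stage 1: roughly calibrate Whisper for speaker/event annotation on ~20 hours
# of automatic (transcription-service) labels.
whisperd = finetune(pretrained_whisper, automatic_transcription_data)

# Stage 2: refine on slightly under 2 hours of human-annotated transcriptions.
whisperd = finetune(whisperd, human_transcription_data)
```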
After the fine-tuning process, we’re able to generate transcriptions such as:
Spotify Podcast Audio | WhisperD Transcription |
---|---|
| [S1] ... this very important that, um, your spouse is, is supportive and, and willing to sacrifice because it is a sacrifice to be a, a coach's wife or husband. And so, um, you know, I heard coaches say that when I was young, you know, before I was married, that, you know, your wife is very important because, you know |
| [S1] But, you know, what can you do in a country town? You don't have a lot of choices. You just, eh. [S2] (laughs) Sorry, I just made like a royal mess of new desk, but anyway. [S1] It's not new, it's all right. [S2] Okay, that's all right. Um, so- [S1] But yes. [S2] Oh, no. [S1] So, I mean, that's- |
We use this model to generate transcriptions for all of the VAD-split podcast segments. We also trained what we hoped would be a better version of WhisperD using a more sophisticated method involving bootstrapping, though meaningful evaluation turned out to be challenging (manual inspection seemed to be more informative than many of the eval methods we devised). We generate transcriptions with this newer WhisperD as well, and randomly sample which transcription to use when training our generative model later.
Lastly, we train two additional WhisperD variants, one producing “fluent” transcriptions and one producing “fuzzy” transcriptions, and use them to transcribe subsets of our dataset.
Later, during generative training, transcriptions are replaced with either their fluent or fuzzy versions with some relatively small probability.
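One way to implement this substitution at data-loading time (the probabilities here are illustrative, not the values we used):

```python
import random

def pick_transcription(variants: dict, p_fluent: float = 0.05, p_fuzzy: float = 0.05) -> str:
    # `variants` holds the standard WhisperD transcription(s) for a segment,
    # plus optional "fluent" and "fuzzy" versions.
    r = random.random()
    if "fluent" in variants and r < p_fluent:
        return variants["fluent"]
    if "fuzzy" in variants and r < p_fluent + p_fuzzy:
        return variants["fuzzy"]
    # Otherwise, randomly sample among the available WhisperD transcriptions.
    return random.choice(variants["standard"])
```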
LibriVox is an archive of public domain audiobooks. We take ~30,000 hours, and we attempt to balance speaker durations during sampling. We again perform VAD to split the audio into 30-second (maximum) segments and use Whisper-v2-medium to generate transcriptions. Interestingly, we find that the pretrained Whisper models omit text such as “This is a LibriVox recording…” (likely because Whisper was trained on these samples with ground-truth transcriptions that only contain book text), so we very briefly fine-tune Whisper-medium with examples containing these sorts of phrases.
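Transcribing the resulting segments can be sketched with the openai-whisper package (shown with the medium checkpoint; illustrative rather than our exact transcription setup):

```python
import whisper

model = whisper.load_model("medium")

def transcribe_segment(wav_path: str) -> str:
    # Returns the plain-text transcription for one VAD-split segment.
    return model.transcribe(wav_path)["text"]
```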
We also use the English subset of Common Voice 14.
Rather than predicting raw audio directly, we instead predict latent audio codes / tokens, using the pre-trained Descript-Audio-Codec (DAC) as our autoencoder. DAC is a residual-quantized VAE that encodes audio into 9 residual codebooks at roughly 86 frames per second.
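For reference, encoding audio into DAC codes with the descript-audio-codec package looks roughly like this (a sketch following the library’s documented usage, assuming the pretrained 44 kHz model, which uses 9 codebooks):

```python
import dac
from audiotools import AudioSignal

# Download and load the pretrained codec weights.
model_path = dac.utils.download(model_type="44khz")
model = dac.DAC.load(model_path)

signal = AudioSignal("segment.wav")
x = model.preprocess(signal.audio_data, signal.sample_rate)
z, codes, latents, _, _ = model.encode(x)  # codes: (batch, n_codebooks, n_frames)
reconstruction = model.decode(z)           # waveform reconstructed from the latents
```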
We now train a model to autoregressively predict these DAC tokens, conditioned on raw text. In the model featured in this blog post, we use an encoder-decoder transformer, though we’ve also experimented with decoder-only models.
Given that there are 9 residual codebooks, we need to decide how to predict the 86 × 9 (= 774) codes per second. Though flattening would be an option, this would leave us with 30-second sequence lengths of 23,220 (the ratio of attention FLOPs to MLP FLOPs would be high with a vanilla setup given our relatively small model size), and the model might be inclined to spend more-than-optimal compute on predicting less important residual codebooks (though this depends on training dynamics, as it’s possible the model would learn to “think ahead” in the residual streams of less-important codebook levels).
There are two more practical options with which we experimented: predicting codebooks with a delay pattern, and using a hierarchical transformer setup.
We found that delay pattern prediction seemed to match hierarchical transformer setups in performance. We also tried combining delay-pattern with hierarchical prediction, which we found improved results over hierarchical prediction without any codebook delay but still was not obviously superior to the non-hierarchical approach. Given this, we choose to train a non-hierarchical model with a delay pattern. Note that if the conditional independence assumed by the delay-pattern proves to be a bottleneck in terms of quality, we could convert our model to a hierarchical one by training a small transformer to replace the linear output projections.
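A delay pattern simply shifts codebook level k right by k timesteps, so that at decoding position t the model predicts level k for frame t - k. A small sketch (our actual implementation may differ in details such as padding):

```python
import numpy as np

def apply_delay_pattern(codes: np.ndarray, pad_token: int) -> np.ndarray:
    # codes: (n_codebooks, T) integer DAC tokens for one segment.
    # Returns (n_codebooks, T + n_codebooks - 1), where codebook level k has been
    # shifted right by k steps, with pad_token filling the introduced gaps.
    n_q, T = codes.shape
    delayed = np.full((n_q, T + n_q - 1), pad_token, dtype=codes.dtype)
    for k in range(n_q):
        delayed[k, k:k + T] = codes[k]
    return delayed
```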
We train a 3B parameter encoder-decoder transformer model. The encoder’s input is raw bytes of text, which closely corresponds to English characters, as the majority of our data is English aside from a small portion of podcast data.
The encoder (text) sequence length is 768, while the decoder (DAC token) length is 2048. To enable classifier-free guidance (see later section) at inference, we drop out the text condition for 15% of samples during training.
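The condition dropout can be implemented by replacing the text tokens with a null condition for a random subset of the batch (a sketch; using a dedicated null token is one of several equivalent choices, e.g. masking cross-attention would also work):

```python
import jax
import jax.numpy as jnp

def drop_text_condition(rng, text_tokens: jax.Array,
                        null_token: int, p_drop: float = 0.15) -> jax.Array:
    # text_tokens: (batch, text_len). Replace the entire text sequence with
    # null_token for roughly p_drop of the batch items.
    drop = jax.random.bernoulli(rng, p_drop, (text_tokens.shape[0],))
    return jnp.where(drop[:, None], null_token, text_tokens)
```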
The encoder has model dimension 1536 with 12 layers and 16 heads. The decoder has model dimension 2560 with 32 layers and uses GQA, with 32 query heads and 8 KV heads. The model uses rotary positional embeddings in the self-attention modules and SwiGLU activations. We train our model with batch size 256 for 110,000 steps with a cosine learning rate decay and a peak learning rate of 2e-4. We train on a v3-256 TPU pod, graciously provided by the TPU Research Cloud, implementing parallel training in JAX.
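The optimizer setup might look like the following in optax (a sketch: the choice of AdamW, the warmup length, and the weight decay are assumptions; the post only specifies batch size, step count, peak learning rate, and cosine decay):

```python
import optax

# Linear warmup followed by cosine decay over the 110k training steps
# (warmup length is an assumption).
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=2e-4,
    warmup_steps=1_000,
    decay_steps=110_000,
)
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)  # optimizer assumed
```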
After training, we fine-tune the model on a small, higher-quality subset of our dataset (we filter using PESQ, a perceptual speech quality metric).
First introduced in the context of diffusion models, classifier-free guidance (CFG) can also be applied to autoregressive models.
CFG in autoregressive models involves training both a conditional model \(P_c(x_t \vert x_{1,\ldots,t-1}, c)\), where c represents the text condition, and an unconditional model \(P_u(x_t \vert x_{1,\ldots,t-1})\). In practice the unconditional model is learned by dropping out the text condition for a small portion of batch samples during training. Then, when sampling a particular code, given conditional logits \(l_c\) and unconditional logits \(l_u\), the final logits \(l_{cfg}\) are obtained with:
\[l_{cfg} = l_c + \alpha (l_c - l_u)\]

where \(\alpha\) is a hyperparameter scalar controlling the degree of guidance. After CFG, we can apply top-k or top-p sampling; we choose top-k with k = 50.
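Concretely, one guided sampling step for a single position’s logits might look like this in JAX (a sketch; top-p would be a drop-in alternative to top-k):

```python
import jax

def cfg_sample(rng, l_c: jax.Array, l_u: jax.Array,
               alpha: float, k: int = 50) -> jax.Array:
    # Guided logits, then sample from the k most likely tokens.
    l_cfg = l_c + alpha * (l_c - l_u)
    topk_vals, topk_idx = jax.lax.top_k(l_cfg, k)
    return topk_idx[jax.random.categorical(rng, topk_vals)]
```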
When we apply CFG to Parakeet sampling, quality is significantly improved. However, on inspecting generations, there tends to be a dramatic speed-up over the duration of the sample (i.e. the rate of speaking increases significantly over time). Our intuition for this problem is as follows: say that our model is (at some level) predicting phonemes, and that at a given timestep the ground-truth probability of the next phoneme occurring is 25%. Our conditional model may predict 20%, but because our unconditional model cannot see the text transcription, its prediction for the correct next phoneme will be much lower, say 5%. With a reasonable level of CFG, because \((l_c - l_u)\) will be large for the correct next phoneme, we’ll obtain a much higher final probability, say 50%, which biases our generation towards faster speech. And this effect compounds! After an initial speed-up, a good model would now “think” it’s predicting for a fast speaker, so the conditional model will predict even higher probabilities for the next phoneme occurring soon, leading to an even larger bias after CFG is applied.
Here is an audio sample with CFG (with \(\alpha = 3\)), and for reference a sample without any CFG.
The quality and text-alignment of the CFG sample is higher (and the difference is often larger than the above examples), but the speed-up is a significant issue.
To address this, we introduce CFG-filter, a modification to CFG that mitigates the speed drift. The idea is to first apply the CFG calculation to obtain a new set of logits \(l_{cfg}\) as before, but rather than use these logits to sample, we use them to obtain a top-k mask that we apply to our original conditional logits. Intuitively, this constrains the space of possible “phonemes” to text-aligned phonemes without heavily biasing their relative probabilities (for example, whether to start the next word or to pause).
Here is the same text prompt with CFG-filter:
The pseudocode is as follows:
import jax
import jax.numpy as jnp

def mask_by_top_k(l_c: jax.Array, cfg_logits: jax.Array, k: int) -> jax.Array:
    # Keep l_c[i] where cfg_logits[i] is among the top k values of cfg_logits; else -inf.
    kth_largest = jax.lax.top_k(cfg_logits, k)[0][..., -1:]
    return jnp.where(cfg_logits >= kth_largest, l_c, -jnp.inf)

def cfg_filter(l_c: jax.Array, l_u: jax.Array, alpha: float, k: int) -> jax.Array:
    cfg_logits = l_c + alpha * (l_c - l_u)
    sample_logits = mask_by_top_k(l_c, cfg_logits, k)
    return sample_logits
Note that we can optionally reapply a smaller level of CFG after masking (i.e. rather than returning the sample logits as is, we could set \(l_{\text{new}} = l_{\text{sample}} + \beta (l_{\text{sample}} - l_u)\), where intuitively we should have \(\beta < \alpha\)). We can also again apply top-k or top-p sampling before returning the final sample logits.
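In code, this variant is a small extension of the cfg_filter block above (beta, like alpha, would need to be tuned; this is a sketch):

```python
def cfg_filter_rescaled(l_c, l_u, alpha: float, beta: float, k: int):
    # Mask with the stronger guidance (alpha), then reapply weaker guidance (beta < alpha).
    # Masked (-inf) entries remain -inf after the rescale.
    sample_logits = cfg_filter(l_c, l_u, alpha, k)
    return sample_logits + beta * (sample_logits - l_u)
```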
A reasonable next step for improving autoregressive versions of Parakeet would be to test replacing DAC with a different autoencoder. It’s possible that, given DAC’s relatively high number of codes per second (86 base codes per second, and 774 codes including residual levels), it may not be particularly conducive to autoregressive generation. Training a new RQ-VAE with some combination of a higher downsampling rate and fewer residual levels (possibly with a larger vocabulary size) would be a sensible starting point. It also might be worth exploring FSQ (finite scalar quantization).
Though higher autoencoder compression might result in worse decoder reconstruction quality, one could train a more powerful decoder (for example, a diffusion decoder) to mitigate this. However, a caveat is that ideally the latents / codes should contain most of the perceptible audio information, as leaving decision-making / ambiguity to the decoder might complicate audio-prompt continuation.
Some of our more recent work has involved transitioning from autoregression to diffusion. A downside of the current autoregressive approach is that mistakes in the sampling process cannot be corrected. Diffusion allows (hand-wavingly) for such correction, but a vanilla approach would involve either generating 30-second segments regardless of text prompt length or training a duration predictor, both of which have disadvantages. Additionally, for text-to-speech, the approach of generating entire long segments at once may have theoretical downsides having to do with a certain time-wise asymmetry (e.g. it’s easier to predict the phoneme occurring at the 2-second mark than at the 28-second mark); a more thorough exploration is left to future blog posts.
Some of our work on diffusion involved splitting audio latent segments into “blocks” to allow for autoregressive sampling of blocks; concurrently, Aran Komatsuzaki proposed a particularly elegant block-wise approach.
We hope to make more progress on this and share our work in the future.
While Parakeet could be adapted for AI-human conversational TTS, a much more important project would be a fully-streaming TTS, something I have given thought to since the development of our initial Parakeet model. Unlike current conversational TTS systems, in which a human talks, the AI waits for the human to finish talking, the AI responds, and so on, a fully-streaming TTS would involve a constant stream of AI audio output, which could include laughter, interruptions, etc. A well-performing streaming TTS model might be able to pass the audio version of a Turing test: a human might not be able to tell whether they’re speaking to another human or an AI system. Though it would require significant engineering work, I believe this project is relatively low-hanging fruit (with a high ratio of impressiveness to engineering capital). I’ve been thinking about different approaches and am excited to see what the future holds in this area.
To cite this blog post, please use:
@misc{darefsky2024parakeet,
  author = {Darefsky, Jordan and Zhu, Ge and Duan, Zhiyao},
  title  = {Parakeet},
  year   = {2024},
  url    = {https://jordandarefsky.com/blog/2024/parakeet/}
}