Automatic music transcription is the process of converting audio files like MP3 and WAV into sheet music, guitar tablature, or any other format a musician might use to learn a song on their instrument.
We’ll go over the current best tools for doing this, which happen to be deep learning-based, and then a novel approach to it.
The current state-of-the-art for this task comes from Magenta, an open-source research project developed by the now defunct (as of April 2023) Google Brain Team.
They released a paper, Sequence-to-Sequence Piano Transcription with Transformers, in 2021, which used a T5-inspired transformer model (similar to “t5-small”) with 54 million parameters, trained on the Maestro dataset, and achieved strong results. The problem is framed as a sequence-to-sequence task using an encoder-decoder Transformer architecture. The encoder processes mel spectrogram frames as input and produces embeddings, while the decoder uses these embeddings via cross-attention to autoregressively generate a sequence of MIDI-like tokens. Their vocabulary consisted of four types of tokens:
- Note tokens (128 values for MIDI pitches)
- Velocity tokens (128 values including zero for note-off)
- Time tokens (6,000 values in 10ms bins for absolute timing)
- EOS token (to mark sequence end)
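To make the vocabulary concrete, here is a minimal Python sketch that gives each token type its own contiguous block of integer IDs. The block sizes come from the list above; the block ordering, the helper names (`token_id`, `time_token`) and the example event sequence are assumptions for illustration, not the paper’s actual implementation.

```python
# Hypothetical ID layout for the four token types listed above.
# Block sizes follow the text; the order of the blocks is an assumption.
VOCAB_SIZES = {
    "note": 128,      # MIDI pitches 0-127
    "velocity": 128,  # velocities 0-127, where 0 means note-off
    "time": 6000,     # 10 ms bins, covering up to 60 s
    "eos": 1,         # end-of-sequence marker
}

# Give each token type a contiguous, half-open range of IDs [start, end).
VOCAB_RANGES = {}
_offset = 0
for _kind, _size in VOCAB_SIZES.items():
    VOCAB_RANGES[_kind] = (_offset, _offset + _size)
    _offset += _size
VOCAB_SIZE = _offset  # 128 + 128 + 6000 + 1 = 6257


def token_id(kind: str, value: int = 0) -> int:
    """Map a (token type, value) pair to a single integer ID."""
    start, end = VOCAB_RANGES[kind]
    if not 0 <= value < end - start:
        raise ValueError(f"{kind} value {value} is out of range")
    return start + value


def time_token(seconds: float) -> int:
    """Quantise an absolute time in seconds to the nearest 10 ms time token."""
    return token_id("time", round(seconds * 100))


# Illustrative encoding: "at 0.50 s, note-on for middle C (MIDI 60) at velocity 80, then EOS".
events = [time_token(0.50), token_id("velocity", 80), token_id("note", 60), token_id("eos")]
print(VOCAB_SIZE, events)  # 6257 [306, 208, 60, 6256]
```

With a flat layout like this, the decoder predicts a single softmax over all 6,257 IDs at each step.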
See the image below for a visualisation of the architecture and an example sequence of their custom MIDI tokens:
“Our model is a generic encoder-decoder Transformer architecture where each input position contains a single spectrogram frame and each output position contains an event from our MIDI-like vocabulary. Output tokens are autoregressively sampled from the decoder, at each step taking the token with maximum probability.”
In 2022, they released a paper, MT3: Multi-Task Multitrack Music Transcription. It used the same approach as the previous one but added instrument tokens to represent the different instruments. Again, they used a similar T5 model and achieved strong performance on many of the datasets they trained on, notably Slakh, Maestro and MusicNet.
MR-MT3 was released the following year as a slight improvement to MT3.
Compute/GPU resources
Huge resources were needed to train this from scratch, even though the model is far smaller than typical large language models. The 2021 paper noted:
“We trained all models on 32 TPUv3 cores, resulting in a per-core batch size of 8. Based on validation set results, overfitting did not seem to be a problem, so we allowed training to progress for 400K steps, which took about 2.5 days for our baseline models.”
The MT3 paper gives fewer training details, stating only that they trained for 1 million steps.
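To put the 2021 setup in perspective, the arithmetic below derives the global batch size, the total number of training examples processed, and the implied training throughput. It uses only the figures quoted from the paper and assumes each batch element is one training example (an input segment and its target tokens).

```python
# Figures quoted from the 2021 paper.
tpu_cores = 32
per_core_batch = 8
train_steps = 400_000
wall_clock_days = 2.5

global_batch = tpu_cores * per_core_batch       # examples per optimiser step
examples_seen = global_batch * train_steps      # total examples processed in training
steps_per_second = train_steps / (wall_clock_days * 24 * 60 * 60)

print(global_batch)                # 256
print(examples_seen)               # 102400000
print(round(steps_per_second, 2))  # 1.85
```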
Other limitations
These models have some inherent limitations in their output flexibility. While language models typically have large vocabularies (often 30,000+ tokens) and are extensively pre-trained on diverse natural language data, MT3 and similar music transcription models use a much smaller, specialised token vocabulary (only a few thousand tokens) focused solely on musical events. This specialisation means that adding new tokens, such as for new instruments or for playing techniques like palm muting on guitar or pizzicato on violin, is unlikely to be easy: it requires significant retraining to integrate the new tokens with the existing vocabulary, and often substantial training data demonstrating these techniques. Large language models, by contrast, can often describe such musical nuances in natural language without modification, as they’ve encountered these concepts during their broad pre-training.
Transfer learning and zero-shot
We can leverage transfer learning from large open-source pre-trained audio and language models. Examples of music generation models include OpenAI’s Jukebox and Meta’s MusicGen.
GPT-4o is designed to handle text, audio and images “natively”. Although OpenAI has not released technical details on this, it’s assumed that some weights in the network process all modalities. The model may use a decoder-only architecture, like language-only GPT models, without needing encoder components to first convert each modality into a dense representation. Such a design would allow the model to seamlessly process and interpret inputs like text and images together, potentially offering benefits both computationally and in terms of model understanding.
Many multi-modal models take a simpler approach reminiscent of the encoder-decoder architecture: they combine two pre-trained models — an encoder for the specific input modality (like ViT for vision or an audio encoder for sound) and a Large Language Model (such as LLaMA, Gemma, or Qwen). These models are connected through projection layers that align their representations in a shared latent space, often using just a single linear layer. These projection layers learn to convert the encoder’s output into a format that matches the LLM’s expected input dimensions and characteristics. The projection creates new embeddings/tokens from the input modality that can then be injected into the LLM’s input sequence. LLaVA is a prime example of this architecture for vision-language tasks, while Spotify’s Llark and Qwen-Audio apply the same principle using audio encoders instead of vision encoders.
Here’s a PyTorch-style sketch of how the models are stitched together:
import torch
from torch import nn
# Assumes audio_model (an audio encoder), llm (a Hugging Face causal LM),
# audio_input and text_input_ids are already defined.
# Extract features from the final layer of the audio encoder
# Shape: [batch_size, audio_seq_len, encoder_dim=1024]
audio_features = audio_model(audio_input)
# Project audio features to match the LLM's embedding dimension
# Shape: [batch_size, audio_seq_len, llm_embed_dim=4096]
projection_layer = nn.Linear(1024, 4096)
audio_embeddings = projection_layer(audio_features)
# Get text embeddings from the LLM's embedding layer
# Shape: [batch_size, text_seq_len, llm_embed_dim=4096]
text_embeddings = llm.get_input_embeddings()(text_input_ids)
# Concatenate along the sequence length dimension
# Shape: [batch_size, audio_seq_len + text_seq_len, llm_embed_dim=4096]
combined_input = torch.cat([audio_embeddings, text_embeddings], dim=1)
# Feed the combined sequence into the LLM as normal
output = llm(inputs_embeds=combined_input)
Overview of architecture
Llark uses OpenAI’s Jukebox and Qwen2-Audio uses OpenAI’s Whisper as their audio towers. Jukebox is a music generation model, but it can also take an audio clip as input and output a continuation of it. Whisper is a speech recognition model used to transcribe speech to text.
Given their purpose, the choice of audio module is clear: Llark specialises in music analysis, while Qwen2Audio primarily focuses on responding to voice instructions with some basic audio and music analysis capabilities.
Determining the optimal source for extracting embeddings from large pre-trained models involves research and experimentation. Additionally, deciding whether to fine-tune the entire module or freeze parts of it is a crucial design choice. For instance, LLaVA’s training strategy involves freezing the vision tower and focusing on fine-tuning the projection layer and language model. We’ll go over this aspect of each model below.
Llark: why Jukebox? Are these embeddings the best as of September 2024?
Determining the optimal location to extract embeddings from large models typically requires extensive probing. This involves testing various activations or extracted layers of the model on different classification tasks through a process of trial and error. For music generation models, this could include tasks like genre recognition, instrument detection, emotion detection, as well as analysis of harmonic structures and temporal patterns. Many commercial embedding models (like OpenAI’s embedding models) are trained specifically for embedding generation with specialised architectures and training objectives, rather than being fine-tuned versions of existing language models.
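The probing loop itself is straightforward: extract embeddings from each candidate layer, fit a simple classifier on one labelled task, and compare held-out accuracy. Below is a minimal dependency-free sketch. It swaps in a nearest-centroid classifier as the probe for brevity, whereas probing studies typically train linear or small MLP probes, and the layer names and toy data are hypothetical.

```python
import math
from collections import defaultdict


def centroid_probe_accuracy(train, test):
    """Fit a nearest-centroid classifier on (vector, label) pairs and return test accuracy."""
    sums, counts = {}, defaultdict(int)
    for vec, label in train:
        acc = sums.setdefault(label, [0.0] * len(vec))
        for i, x in enumerate(vec):
            acc[i] += x
        counts[label] += 1
    centroids = {label: [s / counts[label] for s in acc] for label, acc in sums.items()}

    correct = 0
    for vec, label in test:
        predicted = min(centroids, key=lambda lab: math.dist(vec, centroids[lab]))
        correct += predicted == label
    return correct / len(test)


def probe_layers(layer_embeddings, labels, train_fraction=0.8):
    """Score every candidate layer on the same labelled task; return (best_layer, scores)."""
    n_train = int(len(labels) * train_fraction)
    scores = {}
    for layer, vectors in layer_embeddings.items():
        pairs = list(zip(vectors, labels))
        scores[layer] = centroid_probe_accuracy(pairs[:n_train], pairs[n_train:])
    return max(scores, key=scores.get), scores


# Toy example: one uninformative layer and one layer that separates the two genres.
labels = ["rock", "jazz"] * 5
layers = {
    "layer_12": [[0.5, 0.5] for _ in labels],
    "layer_36": [[1.0, 0.0] if g == "rock" else [0.0, 1.0] for g in labels],
}
best, scores = probe_layers(layers, labels)
print(best, scores)  # layer_36 {'layer_12': 0.5, 'layer_36': 1.0}
```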
The two largest publicly available music generation and music continuation (i.e.: able to take in audio as input) models are Jukebox and MusicGen. MusicGen is newer and faster, and therefore seemed like it would be the obvious choice to me. However, according to this paper on probing MusicGen, embeddings extracted from Jukebox appear to outperform MusicGen on average in classification tasks. The findings from this paper led to the authors of Llark using the following approach for extracting embeddings:
- Embeddings are derived from the output of the 36th layer of the Jukebox encoder, following the approach described in Castellon et al. (2021)
- Original Jukebox encoding:
* 4800-dimensional vectors at 345Hz
* For a 25s clip: about 4.14 × 10⁷ floating-point values
- The authors downsample by mean-pooling within 100ms frames, resulting in:
* Downsampled frequency: 10Hz
* Embedding size: 1.2 × 10⁶ for a 25s audio clip, i.e. a 2D array with shape [250, 4800]
* Temporal information is retained (unlike Castellon et al., who average over the time dimension)
(The downsampled embedding size is approximately 6x larger than CLIP ViT-L14 models used in many multimodal vision models)
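The pooling step can be sketched in a few lines of plain Python. This illustrates mean-pooling to 10Hz using the figures above (taking the 345Hz frame rate at face value); it is not the Llark authors’ code, and `mean_pool_to_rate` is a hypothetical helper. It also shows where the sizes come from: 25 s × 345 Hz × 4800 values before pooling, and 25 s × 10 Hz = 250 frames of 4800 values after.

```python
def mean_pool_to_rate(frames, src_hz, dst_hz):
    """Average frame vectors sampled at src_hz into bins of 1 / dst_hz seconds.

    Frame i (at time i / src_hz) goes to bin (i * dst_hz) // src_hz, which also
    handles windows that are not a whole number of frames (345 Hz -> 10 Hz is
    34.5 frames per 100 ms). Assumes integer rates with dst_hz <= src_hz.
    """
    if dst_hz > src_hz:
        raise ValueError("pooling cannot increase the frame rate")
    sums, counts = [], []
    for i, vec in enumerate(frames):
        b = (i * dst_hz) // src_hz
        if b == len(sums):  # first frame of a new bin
            sums.append([0.0] * len(vec))
            counts.append(0)
        for j, x in enumerate(vec):
            sums[b][j] += x
        counts[b] += 1
    return [[s / c for s in row] for row, c in zip(sums, counts)]


clip_seconds, src_hz, dst_hz, dim = 25, 345, 10, 4800
print(clip_seconds * src_hz * dim)  # 41400000 values before pooling
print(clip_seconds * dst_hz * dim)  # 1200000 values after pooling

# Low-dimensional stand-in for the 4800-dim embeddings, just to check the shape.
frames = [[float(i), 1.0] for i in range(clip_seconds * src_hz)]
pooled = mean_pool_to_rate(frames, src_hz, dst_hz)
print(len(pooled), len(pooled[0]))  # 250 2
```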
Qwen2Audio: Whisper
The embedding extraction for Qwen2Audio isn’t mentioned in detail in the paper. Whisper is an encoder-decoder architecture where the encoder generates deeply learned representations of the audio and the decoder decodes the representations to text (the transcription). In Qwen2Audio, it appears they extract embeddings from the final layer of Whisper’s encoder, although they don’t mention whether they freeze it during training.
Pre-trained weights, training data and datasets
Unfortunately Spotify has not provided any datasets or their trained model weights to the public, noting:
“With respect to inputs: the inputs to our model are public, open-source, Creative Commons-licensed audio and associated annotations. However, each individual audio file can have its own, potentially more restrictive license. Many of the audio files include “no derivatives” licenses. We encourage users of the datasets to familiarize themselves with the restrictions of these licenses; in order to honor such licenses, we do not release any derivatives from the training data in this paper (including query-response pairs or trained model weights).”
They used the following datasets:
- MusicCaps (Agostinelli et al., 2023)
- YouTube8M-MusicTextClips (McKee et al., 2023)
- MusicNet (Thickstun et al., 2017)
- FMA (Defferrard et al., 2017)
- MTG-Jamendo (Bogdanov et al., 2019)
- MagnaTagATune (Law et al., 2009)
Llark details its training data generation process in the following extract:
“We use variants of ChatGPT to extract the instruction-tuning data for all experiments. However, the exact language model used varies by dataset. We select the OpenAI model as follows: We use GPT-4 for all reasoning tasks. We found that GPT-4 was much more adept at following the complex instructions in the Reasoning task family. For datasets with more than 25k samples, we limit Reasoning data to a random subsample of 25k tracks.”
This results in Q&A data like this:
The datasets used for training Qwen2Audio are not shared either, but the trained model is widely available and is also implemented in the transformers library.
For this project, fine-tuning a pre-trained Llark model would have been optimal, given its reportedly good performance on the evaluation benchmarks in Spotify’s paper.
However, since they didn’t release the weights, training a model like this from scratch isn’t feasible without a fair bit of expertise and money. The paper notes:
“Our model is trained on 4 80GB NVIDIA A100 GPUs. Training takes approximately 54 hours.”
This would cost around $700 using a provider like Lambda Labs.
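The estimate follows from GPU-hours multiplied by an hourly rate. The short calculation below shows the arithmetic; the hourly price is the only free parameter, so plug in your provider’s current A100 80GB rate, and note that it ignores storage, data preparation and any repeated runs.

```python
gpus = 4
hours = 54
estimated_cost_usd = 700              # the estimate quoted above

gpu_hours = gpus * hours              # 216 A100-hours
implied_rate = estimated_cost_usd / gpu_hours
print(gpu_hours, round(implied_rate, 2))  # 216 3.24
```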
Because of the above, I went with Qwen2-Audio. However, Qwen2-Audio doesn’t perform that well on basic music tasks like tempo and instrument detection, as I detail in the evaluation section below. This suggests the model is probably not large enough, or not pre-trained enough, to achieve this task, but my hope is that I could at least set a starting point and framework for fine-tuning on this task in the future. As Alibaba state in their Qwen2-Audio blog post:
“We also plan to build larger Qwen2-Audio models to explore the scaling laws of audio language models.”
For my own learning, though, I did have a go at re-creating the model using torch and pre-trained models from the transformers library.
I also created datasets for Q&A data and embeddings. I generated short form Q&A data for the URMP dataset, e.g.: “What is the tempo of this track”, “What instruments are playing in this audio”.
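Short-form Q&A like this can be templated directly from each track’s metadata. Below is a minimal sketch; the `make_qa_pairs` helper, the metadata keys and the answer wording are hypothetical, since the exact format of the published dataset isn’t described here.

```python
def make_qa_pairs(track):
    """Build short-form question/answer pairs from one track's metadata.

    `track` is a hypothetical dict with "tempo" (BPM) and "instruments" keys,
    e.g. derived from a dataset's annotations or MIDI files.
    """
    return [
        {"question": "What is the tempo of this track?",
         "answer": f"{track['tempo']} BPM"},
        {"question": "What instruments are playing in this audio?",
         "answer": ", ".join(track["instruments"])},
    ]


example_track = {"tempo": 67, "instruments": ["Violin", "Cello"]}
for pair in make_qa_pairs(example_track):
    print(pair["question"], "->", pair["answer"])
```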
Here’s a notebook for running Jukebox in a Colab environment to take advantage of the cheap T4 GPUs. I uploaded both the Q&A and embeddings datasets to Hugging Face here.
Here’s a notebook with Llark replicated.
Transcription format
I chose ABC music notation as the output format that the language model is expected to transcribe the music in. Here’s an example of it:
X:1
M:4/4
L:1/16
K:none
Q:67
V:1 name="Electric Bass (finger)"
%%octave-default C4
GAA^2E3A2<A^2 | D^D^2E2A2A^4 A^2E2 | A2A^4A^2E2 A2A^4 | A^2E2A2A^4A^2E2A2 |
A^4 A^2E2 A2A^4A^2 E2 | A2A^4 |
V:2 name="Bright Acoustic Piano"
%%octave-default C5
[E3C3][E3C3][E3C3] [E3C3][A^,2E2A^2] | [E3A^3][E3A^3][E3A^3][E3A^3][E3A^3] |
[E3A^3][E3A^3][E3A^3] [E3A^3][E3A^3] | [E3A^3][E3A^3][E3A^3][E3A^3][E3A^3] |
[E3A^3][E3A^3][E3A^3] [E3A^3][E3A^3] | [E3A^3] |
V:3 name="Electric Guitar (jazz)"
%%octave-default C5
E'3C'3A^4E'3C'3 | A^4E'3 C'3A^4E'3C'3 | A^4 E'3C'3A^4 E'3C'3 | A^4E'3C'3A^4E'3C'3 |
A^4E'3C'3 A^4E'3C'3 | A^4 |
In this notation, the time signature and tempo are defined at the top by the ‘M’ and ‘Q’ fields. The ‘L’ field sets the default note length, in this case a sixteenth note (when L: is omitted, ABC derives the default from the metre, which gives an eighth note for 4/4). We then define each instrument as a voice (‘V’), along with the default octave to use when writing its notes. Here’s a summary of the key syntactical points for writing notes in ABC music notation:
- Notes are represented by the letters A-G; lowercase letters are an octave higher than uppercase, and ' or , after a note raises or lowers it by a further octave
- Sharps are denoted by ^ before the note, flats by _
- Natural signs are represented by =
- Note length is indicated by a number after the note (C2 is twice as long as C, and C/2 is half as long)
- Dotted notes are written by scaling the length (with L:1/8, C3 is a dotted quarter note); a . before a note marks staccato instead
- Rests are represented by z, with numbers for duration (z2 is a rest twice the default note length)
- Chords are enclosed in square brackets: [CEG]
- Ties are shown with a hyphen: C-C
- Bar lines are represented by |
- Broken rhythms use > or < between notes (with L:1/8, C>D is a dotted eighth-note C followed by a sixteenth-note D)
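To make the length rules concrete, here is a minimal sketch that computes the duration of a single ABC note or rest under a given L: unit, defaulting to an eighth note. It handles accidentals, octave marks and length multipliers only (chords, ties and broken rhythms are left out), and `note_duration` is a hypothetical helper rather than part of an existing ABC library.

```python
import re
from fractions import Fraction

# Optional accidental, note letter (or z for a rest), octave marks, optional length.
NOTE_RE = re.compile(r"^(\^\^|\^|__|_|=)?([A-Ga-gz])([',]*)(\d*)(/\d+|/+)?$")


def note_duration(token, unit=Fraction(1, 8)):
    """Duration of a single ABC note or rest, as a fraction of a whole note.

    `unit` is the default note length from the L: field.
    """
    match = NOTE_RE.match(token)
    if match is None:
        raise ValueError(f"not a simple ABC note or rest: {token!r}")
    _, _, _, number, fraction = match.groups()
    multiplier = Fraction(int(number)) if number else Fraction(1)
    if fraction:
        # "/n" divides by n; a bare "/" halves, "//" quarters, and so on.
        divisor = int(fraction[1:]) if fraction[1:].isdigit() else 2 ** len(fraction)
        multiplier /= divisor
    return unit * multiplier


for token in ["C", "C2", "C3", "^F/", "c'3/2", "z2"]:
    print(token, note_duration(token))
```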
Why ABC?
The reasons for choosing this notation are:
- It’s a minimalist format for writing music
- It’s widely used and popular; language models already have good comprehension of ABC notation due to extensive pre-training on it.
- It’s flexible and can easily be extended to include tempo changes, time signature changes, additional playing styles such as those mentioned above, etc.
I converted the MIDI files provided by the datasets to ABC notation using this library. A notebook for creating the datasets is here.
To evaluate both the original model and each stage of fine-tuning I performed thereafter, I randomly selected 30 samples of varying complexity from the URMP dataset and ran the model three times on each sample, manually examining all responses.
Through manual testing, I found the best decoding parameters to be a temperature of 0.7 and a top_p of 1.2. The maximum number of tokens to return was capped at 2048; adjusting this cap made little difference to performance.
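For reference, here is a small dependency-free sketch of how temperature and nucleus (top-p) sampling reshape the next-token distribution. Libraries such as transformers implement the same idea over tensors and may differ in details like tie-breaking; the logits are made up. Because top-p keeps tokens until their cumulative probability reaches p, any value of 1 or above keeps the whole distribution.

```python
import math


def sampling_distribution(logits, temperature=0.7, top_p=1.0):
    """Distribution that temperature + nucleus (top-p) sampling would draw from.

    logits: dict mapping each candidate token to its raw score.
    """
    # Temperature scales logits before the softmax: < 1 sharpens, > 1 flattens.
    scaled = {tok: score / temperature for tok, score in logits.items()}
    max_score = max(scaled.values())  # subtract the max for numerical stability
    weights = {tok: math.exp(s - max_score) for tok, s in scaled.items()}
    total = sum(weights.values())
    ranked = sorted(weights.items(), key=lambda kv: kv[1], reverse=True)

    # Top-p keeps the smallest set of most likely tokens whose cumulative
    # probability reaches top_p, then renormalises over that set.
    kept, cumulative = [], 0.0
    for tok, w in ranked:
        kept.append((tok, w / total))
        cumulative += w / total
        if cumulative >= top_p:
            break
    norm = sum(p for _, p in kept)
    return {tok: p / norm for tok, p in kept}


logits = {"120": 2.0, "100": 1.0, "60": -1.0}  # hypothetical next-token scores
print(sampling_distribution(logits, top_p=0.5))       # {'120': 1.0}
print(len(sampling_distribution(logits, top_p=1.2)))  # 3: nothing is filtered
```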
The original model performed poorly on this evaluation set. While it occasionally predicted the tempo and instruments correctly, it mostly failed to do so. A text file with the evaluation results is available here.
Given this starting point, it’s unlikely that we’ll see strong results from this experiment without a robust pre-trained model. However, the goal is to develop strategies that can be applied in the future as more advanced pre-trained models become available.
I first attempted fine-tuning with a basic cross-entropy loss. Supervised fine-tuning (SFT) with cross-entropy loss is a quick way to start teaching the model, but a basic loss function like this has limitations, as we will see below. The intuition behind this stage of training was that it would nudge the model in the right direction and let it pick up any patterns, or any customised ABC notation in the dataset, that it may not have seen before.
Cross-entropy loss with teacher forcing
First, we trained it in the typical supervised fine-tuning manner for language models. I used the SFTTrainer from the trl library for this, which uses cross-entropy loss with teacher forcing, defined step by step below:
- The model predicts the next token in the sequence.
- The loss is calculated from the predicted probabilities (the softmax of the logits) and the actual next token: it is the negative log-probability the model assigned to the correct token.
- For the next prediction, the model is given the actual correct token (ground truth) rather than its own prediction. This is known as teacher forcing; it helps stabilise training and speeds it up significantly, especially in the early stages.
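The steps above can be sketched without any framework: the “model” is any function that returns next-token probabilities for a prefix, and teacher forcing simply means the prefix always comes from the ground truth. The toy probability table and its sub-word splits (“V”, “iolin”) are hypothetical; trainers such as SFTTrainer compute the equivalent loss from logits over batched tensors.

```python
import math


def teacher_forced_loss(next_token_probs, target):
    """Mean cross-entropy over a target sequence, computed with teacher forcing.

    next_token_probs(prefix) returns a dict of token -> probability for the next
    token. At step i the model sees the ground-truth prefix target[:i], never its
    own earlier predictions.
    """
    total = 0.0
    for i, token in enumerate(target):
        probs = next_token_probs(target[:i])         # condition on the true prefix
        total += -math.log(probs.get(token, 1e-12))  # NLL of the true next token
    return total / len(target)


# Hypothetical toy "model": next-token probabilities keyed on the previous token.
TABLE = {
    None: {"V": 0.9, "Q": 0.1},
    "V": {"iolin": 0.6, "iola": 0.4},
}


def toy_model(prefix):
    return TABLE[prefix[-1] if prefix else None]


print(round(teacher_forced_loss(toy_model, ["V", "iolin"]), 3))  # 0.308
print(round(teacher_forced_loss(toy_model, ["V", "iola"]), 3))   # 0.511
```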
The results from this training phase were poor: it degraded the performance of the original model. The model, which previously recognised tempo and instruments correctly some of the time, now mostly got these wrong. It also began producing garbled text with endless repetition. This occurred even with a low learning rate, gradient clipping, and low LoRA ranks to limit how much the model could change. Overall, the model seemed very sensitive to the training applied.
However, even if this training phase offered some improvements, it wouldn’t lead to optimal performance because of the limitations of our basic loss function, which struggles to capture the nuances of the model’s performance. For example, with teacher forcing, instrument predictions can yield deceptively low loss across certain token sections: once an instrument name beginning with “V” is underway, the model can confidently predict “Violin” or “Viola” based on our dataset, regardless of accuracy. The loss function also may not reflect near-misses accurately, such as predicting a tempo of 195 instead of 200. That is a small, reasonably accurate error, but it can be penalised heavily depending on how probability is distributed among the logits: cross-entropy only looks at the probability of the exact correct token, so probability placed on neighbouring numbers earns no credit.
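The near-miss problem is easy to see in isolation. In the sketch below, two hypothetical predictions put the same probability on the correct tempo but spread the rest very differently, and cross-entropy scores them identically. Treating each tempo value as a single token is a simplifying assumption, since real tokenizers often split numbers into several tokens.

```python
import math


def cross_entropy(probs, target):
    """Cross-entropy at one position: only the probability of the target token matters."""
    return -math.log(probs[target])


# Two hypothetical predictions for a tempo token when the true tempo is "200".
near_miss = {"200": 0.3, "195": 0.6, "100": 0.1}  # most mass on a close tempo
far_miss = {"200": 0.3, "195": 0.1, "100": 0.6}   # most mass on a distant tempo

print(cross_entropy(near_miss, "200") == cross_entropy(far_miss, "200"))  # True
```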
RLHF with PPO
Because of these limitations, we can create our own custom loss function that can more accurately score the response from the model. That is, given a predicted sequence from the model, the loss function could give it a score between 0 and 1 on how good it is.
However, integrating this custom loss function into supervised fine-tuning presents a significant challenge: the score is computed from generated text, so it is not differentiable with respect to the model’s parameters, and gradients can’t be calculated directly. Let’s break this down:
In traditional SFT with cross-entropy loss:
- The model outputs logits (raw scores) for each token in its vocabulary
- Applying a softmax to these logits gives the model’s predicted probabilities
- The loss function compares these probabilities to the ground truth
- Gradients can be computed directly through this comparison
- The chain rule of calculus allows us to propagate these gradients back through the model
With our custom loss function:
- The model must first generate complete text output
- This generation process involves sampling from probability distributions
- Our loss function then analyses this text output (checking tempo, notes, etc.)
- This creates a non-differentiable step between the model’s logits and our loss calculation
- The sampling and text analysis steps break the gradient chain needed for backpropagation
To overcome this, reinforcement learning techniques like Proximal Policy Optimisation (PPO) can be employed. PPO treats the score as a reward, which doesn’t need to be differentiable: gradients flow only through the log-probabilities that the policy (the model’s output distribution) assigned to the tokens it sampled, weighted by how good the resulting response was, rather than through the scoring function itself.
Note: there are plenty of great articles explaining PPO in more depth.
The key insight of PPO is that instead of trying to directly backpropagate through the non-differentiable steps, it:
- Treats the model’s outputs as actions in a reinforcement learning framework
- Uses the custom loss function as a reward signal
- Updates the model’s policy (its probability distributions over tokens) to maximise expected reward
- Does this while ensuring the updated policy doesn’t deviate too far from the current one
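The last point, not deviating too far from the current policy, is what PPO’s clipped surrogate objective enforces. Below is a minimal per-token sketch in plain Python; clip_eps=0.2 is a common default rather than a value from this project, and full implementations typically also add a value-function loss and, in RLHF setups, a KL penalty against a reference model.

```python
import math


def ppo_clipped_objective(logp_new, logp_old, advantages, clip_eps=0.2):
    """Mean PPO clipped surrogate objective (to be maximised) over sampled tokens.

    logp_new / logp_old: log-probabilities of the sampled tokens under the current
    policy and under the policy that generated them.
    advantages: how much better than expected each token's outcome was, derived
    from the reward (here, the custom score).
    """
    total = 0.0
    for new, old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(new - old)                            # pi_new / pi_old
        clipped = min(max(ratio, 1 - clip_eps), 1 + clip_eps)  # keep ratio near 1
        total += min(ratio * adv, clipped * adv)               # pessimistic bound
    return total / len(advantages)


# Doubling a token's probability with a positive advantage earns at most 1 + clip_eps.
print(ppo_clipped_objective([math.log(2.0)], [0.0], [1.0]))  # 1.2
```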
This approach lets us train the model with the custom loss function directly, aiming for performance improvements without disrupting the core training dynamics. The PPO algorithm’s conservative update strategy helps maintain stability during training, which is particularly important when working with large language models.
Usually, this scoring function would be implemented as a separate model, a “reward model”, as is common when fine-tuning models via RLHF, a technique popularised by ChatGPT. Due to the nature of this task, we can instead write code to score the responses directly, which uses fewer resources and is quicker.
For time signature and tempo recognition this is easy to calculate. We extract all predicted items with regex, for example extracting the metre:
import re
def extract_metre(self, abc_string):
    match = re.search(r'M:(\S+)', abc_string)  # None if there is no "M:" field
    return match.group(1) if match else None
The model should learn the syntax and structure we want it to output in the SFT stage. If it outputs something that will cause our regex to not find anything or error, we can just skip that sample, assuming it’s a small minority of the dataset.
We extract the predicted tempo and write a function that is more forgiving for small errors but penalises larger errors more heavily:
- For small differences (≤10 BPM), it uses linear scaling.
- For larger differences, it switches to exponential scaling.
- The final loss is capped between 0 and 1.
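A minimal sketch of the tempo extraction and loss described above. The regex, the function names and the constants (linear_max, growth) are assumptions chosen for illustration; the linked code may use different values.

```python
import math
import re


def extract_tempo(abc_string):
    """Return the BPM in a 'Q:' field ('Q:67' or 'Q:1/4=120'), or None if absent."""
    match = re.search(r"Q:(?:\d+/\d+=)?(\d+)", abc_string)
    return int(match.group(1)) if match else None


def tempo_loss(pred_bpm, true_bpm, linear_range=10, linear_max=0.1, growth=0.05):
    """Forgiving for small tempo errors, steep for large ones; always in [0, 1].

    The constants are hypothetical; the two pieces meet at linear_range so the
    loss is continuous.
    """
    diff = abs(pred_bpm - true_bpm)
    if diff <= linear_range:
        loss = linear_max * diff / linear_range                       # linear scaling
    else:
        loss = linear_max * math.exp(growth * (diff - linear_range))  # exponential
    return min(max(loss, 0.0), 1.0)                                   # cap to [0, 1]


pred = extract_tempo("X:1\nM:4/4\nL:1/16\nQ:195\nK:none")
print(pred, tempo_loss(pred, 200))  # 195 0.05
```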
Let’s break down the key components of this custom loss:
Code for the custom loss is here
1. Metre Loss
The metre loss focuses on the time signature of the piece. It compares the predicted metre with the ground truth, considering both the numerator and denominator separately, as well as their ratio. This approach allows for a nuanced evaluation that can handle various time signatures accurately.
The metre loss uses a combination of linear and exponential scaling to penalise differences. Small discrepancies result in a linear increase in loss, while larger differences lead to an exponential increase, capped at a maximum value of 1.
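A minimal sketch of a metre loss along these lines. The split into numerator, denominator and ratio follows the description above; the log2 comparison of denominators, the constants and the equal weighting of the three parts are assumptions, not the linked implementation.

```python
import math


def scaled_penalty(diff, linear_range, linear_max=0.1, growth=0.5):
    """Linear up to linear_range, exponential beyond it, capped at 1 (hypothetical constants)."""
    if diff <= linear_range:
        penalty = linear_max * diff / linear_range
    else:
        penalty = linear_max * math.exp(growth * (diff - linear_range))
    return min(penalty, 1.0)


def metre_loss(pred, true):
    """Compare time signatures such as '6/8' by numerator, denominator and their ratio.

    Raises ValueError for unparsable metres (e.g. 'C'), so the sample can be skipped.
    """
    p_num, p_den = (int(x) for x in pred.split("/"))
    t_num, t_den = (int(x) for x in true.split("/"))
    parts = [
        scaled_penalty(abs(p_num - t_num), linear_range=1),
        # Denominators are powers of two, so compare them on a log2 scale.
        scaled_penalty(abs(math.log2(p_den) - math.log2(t_den)), linear_range=1),
        scaled_penalty(abs(p_num / p_den - t_num / t_den), linear_range=0.25),
    ]
    return sum(parts) / len(parts)


print(round(metre_loss("3/4", "4/4"), 3), round(metre_loss("7/8", "4/4"), 3))  # 0.067 0.141
```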
2. Tempo Loss
Tempo loss evaluates the accuracy of the predicted beats per minute (BPM). Similar to the metre loss, it uses a combination of linear and exponential scaling.
For small tempo differences (≤10 BPM), the function applies linear scaling. Larger differences trigger exponential scaling, ensuring that significant tempo mismatches are penalised more heavily.
3. Pitch Loss
The pitch loss is perhaps the most crucial component, as it assesses the accuracy of the transcribed notes. This function uses the Levenshtein distance to compare the sequence of notes in each voice.
The pitch loss calculation accounts for multiple voices, matching each predicted voice to the closest ground truth voice. This approach allows for flexibility in voice ordering while still maintaining accuracy in the overall pitch content.
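A minimal sketch of that idea: a standard dynamic-programming Levenshtein distance over note tokens, normalised by sequence length, with each predicted voice matched greedily to its closest ground-truth voice. Tokenising voices into note lists, the normalisation and the 1.0 fallback for empty output are assumptions; as described, this matching also doesn’t penalise ground-truth voices the model leaves out entirely.

```python
def levenshtein(a, b):
    """Edit distance between two sequences (insertions, deletions, substitutions)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        cur = [i]
        for j, y in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,              # delete x
                           cur[j - 1] + 1,           # insert y
                           prev[j - 1] + (x != y)))  # substitute (free if equal)
        prev = cur
    return prev[-1]


def pitch_loss(pred_voices, true_voices):
    """Mean normalised edit distance, matching each predicted voice to its closest true voice.

    Each voice is a list of note tokens, e.g. ["E", "C", "^A"].
    """
    if not pred_voices or not true_voices:
        return 1.0
    losses = []
    for pred in pred_voices:
        losses.append(min(
            levenshtein(pred, true) / max(len(pred), len(true), 1)
            for true in true_voices
        ))
    return sum(losses) / len(losses)


truth = [["C", "E", "G"], ["D"]]
print(pitch_loss([["D"], ["C", "E", "G"]], truth))  # 0.0 (voice order doesn't matter)
print(pitch_loss([["C", "E", "A"]], truth))         # 0.3333333333333333
```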
4. Instrument Loss
The instrument loss evaluates the accuracy of instrument selection for each voice.
This function considers exact matches, instruments from the same family, and uses string similarity for more nuanced comparisons. It provides a comprehensive assessment of how well the model identifies and assigns instruments to each voice.
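A minimal sketch of such a scoring scheme. The three tiers follow the description above (exact match, same family, string similarity via Python’s difflib); the family table, the 0.5 partial-credit values and the positional alignment of voices are assumptions for illustration.

```python
from difflib import SequenceMatcher

# Hypothetical, deliberately partial instrument-family table.
FAMILIES = {
    "violin": "strings", "viola": "strings", "cello": "strings", "contrabass": "strings",
    "flute": "woodwinds", "oboe": "woodwinds", "clarinet": "woodwinds", "bassoon": "woodwinds",
    "trumpet": "brass", "horn": "brass", "trombone": "brass", "tuba": "brass",
}


def instrument_score(pred, true):
    """Similarity in [0, 1]: exact match, then same family, then fuzzy string match."""
    p, t = pred.strip().lower(), true.strip().lower()
    if p == t:
        return 1.0
    if p in FAMILIES and FAMILIES.get(p) == FAMILIES.get(t):
        return 0.5                                    # partial credit: right family
    return 0.5 * SequenceMatcher(None, p, t).ratio()  # kept below family credit


def instrument_loss(pred_instruments, true_instruments):
    """1 - mean similarity over voices aligned by position; unmatched voices score 0."""
    n = max(len(pred_instruments), len(true_instruments))
    if n == 0:
        return 0.0
    scores = [instrument_score(p, t) for p, t in zip(pred_instruments, true_instruments)]
    return 1.0 - sum(scores) / n


print(instrument_loss(["Viola"], ["Violin"]))            # 0.5
print(instrument_loss(["Violin"], ["Violin", "Cello"]))  # 0.5
```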
5. Combining the Losses
The final loss is a weighted combination of these individual components:
total_loss = (0.5 * pitch_loss +
0.15 * metre_loss +
0.15 * tempo_loss +
0.2 * instrument_loss)
This weighting scheme prioritises pitch accuracy while still considering other important aspects of music transcription.
PPO training generally requires a lot more memory than SFT for a few reasons:
- Multiple policy evaluations: PPO compares the current policy against the “old” policy that generated the samples (via the probability ratio) and, in RLHF-style setups, against a frozen reference model (via the KL penalty). Keeping a reference copy of the model can effectively double the model parameters in memory.
- Experience buffer: PPO stores a buffer of experiences (states, actions, rewards, etc.) to perform updates in mini-batches. This buffer can be quite large and takes significant memory.
- Advantage estimation: computing advantages requires keeping track of value estimates and returns across trajectories, adding another layer of memory overhead.
- Additional optimisation objectives: PPO tracks multiple loss components (policy loss, value loss, entropy bonus) and their gradients, whereas SFT has a single loss.
Because of the above, we’re more limited than with SFT in the size of the models we can train and in cost. Whereas I could run the SFT training above on an A100 40GB in Colab, PPO training needed more memory: I trained on an H100 80GB, which could fit a LoRA with a rank of 128 and a batch size of 8.
My hyperparameter sweep was narrow: I went with what seemed most intuitive, using batch sizes from 1 to 16 and learning rates from 2e-5 to 2e-4.
The model showed no improvement on the task. The text file with the results is here.
I tracked various training metrics using Weights & Biases (WandB). Key metrics included the policy loss, value loss, total loss, KL divergence, and the reward model’s score.
Across all hyperparameter runs, the logs showed no improvement in rewards or loss over time. The KL divergence remained within the predefined threshold.