Iterative Decomposition Model V3

This article covers the continuation of work I've been pursuing in the area of sparse, interpretable audio models.

Some previous iterations of this work:

All training and model code can be found here.

Our goal is to decompose a musical audio signal into a small number of "events", roughly analogous to a musical score, but carrying information about the resonant characteristics of the instrument being played and the room it is being played in. Each event is represented by a low-dimensional vector (32 dimensions, in this case) and a time at which the event occurs in the "score".

We seek to achieve this goal by iteratively guessing at the next-most-informative event, removing it from the original signal, and repeating the process until no informative or interesting signal remains.

This is very similar to the matching pursuit algorithm, in which we repeatedly convolve an audio signal with a dictionary of audio atoms, pick the most highly-correlated atom at each step, remove it from the residual, and repeat the process until the norm of the residual falls below some acceptable threshold.
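
For contrast, a minimal sketch of that classic matching-pursuit loop over a fixed dictionary might look like the following; the dictionary, threshold, step budget, and the assumption of unit-norm atoms are all illustrative, and none of this is the model described in this article.

```python
import numpy as np


def matching_pursuit(
        signal: np.ndarray,
        dictionary: np.ndarray,
        threshold: float = 1e-3,
        max_steps: int = 64):
    """Greedily decompose `signal` using unit-norm atoms (rows of `dictionary`)."""
    residual = signal.copy()
    events = []
    for _ in range(max_steps):
        # correlate every atom with the residual at every time shift
        correlations = np.stack(
            [np.correlate(residual, atom, mode='valid') for atom in dictionary])
        atom_index, position = np.unravel_index(
            np.abs(correlations).argmax(), correlations.shape)
        coefficient = correlations[atom_index, position]
        # subtract the best-matching atom at the best-matching position
        atom = dictionary[atom_index]
        residual[position:position + len(atom)] -= coefficient * atom
        events.append((int(atom_index), int(position), float(coefficient)))
        if np.linalg.norm(residual) < threshold:
            break
    return events, residual
```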

The Algorithm

In this work, we replace the convolution of a large dictionary of audio "atoms" with analysis via a "deep" neural network, which uses a short-time Fourier transform (STFT) followed by cascading dilated convolutions to efficiently analyze relatively long audio segments. At each iteration, the network proposes an event vector and a time-of-occurrence. The event is then rendered by an event-generator network and scheduled by convolving the rendered event with a Dirac function (a unit-valued spike) at the desired time. The scheduled event is subtracted from the original audio spectrogram[^1], and the process is repeated. During training, this process runs for a fixed number of steps, but one could imagine a modification in which some other stopping condition is used to improve efficiency.
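
To make just the scheduling step concrete: a rendered event can be placed at its proposed time by convolving it with a one-hot spike, which is cheap to do with FFTs. This is a hedged sketch of the idea only; the shapes, the helper name, and the FFT-based convolution are assumptions rather than the model's actual code.

```python
import torch


def schedule_event(rendered: torch.Tensor, time_index: int, n_samples: int) -> torch.Tensor:
    # a "Dirac function": all zeros except a unit-valued spike at the event time
    dirac = torch.zeros(n_samples)
    dirac[time_index] = 1.0

    # zero-pad the rendered event to the full segment length and convolve via the FFT
    # (circular convolution, which is harmless as long as the event fits before the end)
    padded = torch.zeros(n_samples)
    padded[:rendered.shape[-1]] = rendered
    spec = torch.fft.rfft(padded) * torch.fft.rfft(dirac)
    return torch.fft.irfft(spec, n=n_samples)


# place a 2048-sample event so that it begins halfway through a ~3 second segment
event = torch.randn(2048)
scheduled = schedule_event(event, time_index=2 ** 15, n_samples=2 ** 16)
```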

The decoder makes some simple physics-based assumptions about the underlying signal, and uses convolutions with long kernels to model the transfer functions of the instruments and the rooms in which they are performed.
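
As a rough, hypothetical illustration of that assumption (not the model's actual decoder): a short burst of enveloped noise can stand in for the excitation, and two long decaying kernels can stand in for the instrument's resonance and the room's reverb; convolving the three together yields a plausible-sounding "event".

```python
import torch


def fft_convolve(a: torch.Tensor, b: torch.Tensor, n: int) -> torch.Tensor:
    # circular convolution via the FFT; fine for a sketch when both signals decay well before n
    return torch.fft.irfft(torch.fft.rfft(a, n=n) * torch.fft.rfft(b, n=n), n=n)


n_samples = 2 ** 16
samplerate = 22050
t = torch.arange(n_samples) / samplerate

# excitation: a few milliseconds of enveloped noise (the "strike" or "pluck")
excitation = torch.randn(n_samples) * torch.exp(-t * 200)

# instrument: a handful of damped sinusoidal modes, i.e. a long resonant kernel
modes = [220.0, 440.0, 660.0, 880.0]
instrument = sum(torch.sin(2 * torch.pi * f * t) * torch.exp(-t * 4) for f in modes)

# room: a long, exponentially-decaying noise tail standing in for reverb
room = torch.randn(n_samples) * torch.exp(-t * 3)

event = fft_convolve(fft_convolve(excitation, instrument, n_samples), room, n_samples)
```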

The Training Process

We train the model on ~three-second segments of audio from the MusicNet dataset, which comprises approximately 33 hours of public-domain classical music. We optimize the model via gradient descent using the following training objectives (a rough sketch of the first two follows the list):

  1. An iterative reconstruction loss, which asks the model to maximize the energy it removes from the signal at each step
  2. A sparsity loss, which asks the model to minimize the l1 norm of the event vectors, ideally leading to a sparse (few event) solution
  3. An adversarial loss, which masks about 50% of events and asks a discriminator network (trained in parallel) to judge them; this is intended to encourage the events to be independent and stand on their own as believable, musical events.
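
A minimal sketch of the first two objectives is below; the tensor shapes, the stft_magnitude helper, and the weighting are assumptions, and the adversarial term is omitted for brevity.

```python
import torch


def stft_magnitude(x: torch.Tensor) -> torch.Tensor:
    # magnitude spectrogram; energy is measured in this domain rather than on raw samples
    return torch.stft(x, n_fft=2048, hop_length=256, return_complex=True).abs()


def decomposition_loss(residuals, event_vectors, sparsity_weight: float = 1e-3):
    """residuals: one (batch, n_samples) tensor per decomposition step, holding the
    signal left over after that step's event has been subtracted.
    event_vectors: (batch, n_events, 32) event vectors proposed by the encoder."""
    # 1. iterative reconstruction: penalize whatever energy remains after each step,
    #    which is equivalent to rewarding the energy removed at that step
    recon = sum(stft_magnitude(r).pow(2).mean() for r in residuals)
    # 2. sparsity: an l1 penalty on the event vectors
    sparsity = event_vectors.abs().sum(dim=-1).mean()
    return recon + sparsity_weight * sparsity
```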

Improvements over the Previous Model

While the previous model only operated on around 1.5 seconds of audio, this model doubles that window, ultimately driving toward a fully streaming algorithm that can handle signals of arbitrary length. It also makes progress toward a much simpler decoder, which generates each event as a linear combination of lookup tables for the following elements:

Future Work

Model Size, Training Time and Dataset

Firstly, this model is relatively small, weighing in at ~26M parameters (~117 MB on disk), and it has only been trained for around 24 hours, so there seems to be a lot of room to increase the model size, dataset size, and training time to further improve results. The reconstruction quality of the examples on this page is not amazing, certainly not good enough even for a lossy audio codec, but the structure the model extracts seems like it could be used for many interesting applications.

Streaming and/or Arbitrary Lengths

Ultimately, the model should be able to handle audio segments of arbitrary lengths, adhering to some event "budget" to find the sparsest-possible explanation of the audio segment.

A Better Sparsity Loss

Some of the examples lead me to believe that my current sparsity loss is too aggressive; the model sometimes prefers to leave events out entirely, forgoing the "win" of reducing overall signal energy. Using the l1 norm penalty seems like a sledgehammer, and a more nuanced loss would probably do better.

Different Event Generator Variants

The decoder side of the model is very interesting, and all sorts of physical modelling-like approaches could yield better, more realistic, and sparser renderings of the audio.

Cite this Article

If you'd like to cite this article, you can use the following BibTeX block.

Event Scatterplot

Here is a scatterplot mapping events from four different audio segments onto a 2D plane using t-SNE.
Each 32-dimensional event vector encodes information about attack, resonance, and room impulse response.

Examples

Example 1

Original Audio

Reconstruction

Randomized

Here, we generate random event vectors with the original event times.

Here we use the original event vectors, but generate random times.

Event Vectors

Event Scatterplot

Events clustered using t-SNE

Individual Audio Events

Example 2

Original Audio

Reconstruction

Randomized

Here, we generate random event vectors with the original event times.

Here we use the original event vectors, but generate random times.

Event Vectors

Event Scatterplot

Events clustered using t-SNE

Individual Audio Events

Example 3

Original Audio

Reconstruction

Randomized

Here, we generate random event vectors with the original event times.

Here we use the original event vectors, but generate random times.

Event Vectors

Event Scatterplot

Events clustered using t-SNE

Individual Audio Events

Example 4

Original Audio

Reconstruction

Randomized

Here, we generate random event vectors with the original event times.

Here we use the original event vectors, but generate random times.

Event Vectors

Event Scatterplot

Events clustered using t-SNE

Individual Audio Events

Example 5

Original Audio

Reconstruction

Randomized

Here, we generate random event vectors with the original event times.

Here we use the original event vectors, but generate random times.

Event Vectors

Event Scatterplot

Events clustered using t-SNE

Individual Audio Events
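
Article Source Code

The Python code that loads the trained model and produces the examples and scatterplots above is reproduced below; the script that assembles the page itself appears after the Notes section.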

```python
# the size, in samples of the audio segment we'll overfit
n_samples = 2 ** 16
samples_per_event = 2048
n_events = n_samples // samples_per_event
context_dim = 32

# the samplerate, in hz, of the audio signal
samplerate = 22050

# derived, the total number of seconds of audio
n_seconds = n_samples / samplerate

transform_window_size = 2048
transform_step_size = 256

n_frames = n_samples // transform_step_size

from argparse import ArgumentParser
from typing import Dict, Tuple

import numpy as np
import torch
from sklearn.manifold import TSNE
from torch import nn

from conjure import S3Collection, \
    conjure_article, CitationComponent, AudioComponent, ImageComponent, \
    CompositeComponent, Logger, ScatterPlotComponent
from data import get_one_audio_segment, AudioIterator
from iterativedecomposition import Model as IterativeDecompositionModel
from modules.eventgenerators.overfitresonance import OverfitResonanceModel
from modules import max_norm, sparse_softmax

remote_collection_name = 'iterative-decomposition-v3'


def to_numpy(x: torch.Tensor):
    return x.data.cpu().numpy()


# thanks to https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def process_events(
        vectors: torch.Tensor,
        times: torch.Tensor,
        total_seconds: float) -> Tuple:

    positions = torch.argmax(times, dim=-1, keepdim=True) / times.shape[-1]
    times = [float(x) for x in (positions * total_seconds).view(-1).data.cpu().numpy()]

    normalized = vectors.data.cpu().numpy().reshape((-1, context_dim))
    normalized = normalized - normalized.min(axis=0, keepdims=True)
    normalized = normalized / (normalized.max(axis=0, keepdims=True) + 1e-8)

    tsne = TSNE(n_components=2)
    points = tsne.fit_transform(normalized)

    proj = np.random.uniform(0, 1, (2, 3))
    colors = points @ proj
    colors -= colors.min()
    colors /= (colors.max() + 1e-8)
    colors *= 255
    colors = colors.astype(np.uint8)
    colors = [f'rgb({c[0]} {c[1]} {c[2]})' for c in colors]

    return points, times, colors


def load_model(wavetable_device: str = 'cpu') -> nn.Module:
    hidden_channels = 512

    model = IterativeDecompositionModel(
        in_channels=1024,
        hidden_channels=hidden_channels,
        resonance_model=OverfitResonanceModel(
            n_noise_filters=32,
            noise_expressivity=8,
            noise_filter_samples=128,
            noise_deformations=16,
            instr_expressivity=8,
            n_events=1,
            n_resonances=4096,
            n_envelopes=256,
            n_decays=32,
            n_deformations=32,
            n_samples=n_samples,
            n_frames=n_frames,
            samplerate=samplerate,
            hidden_channels=hidden_channels,
            wavetable_device=wavetable_device,
            fine_positioning=True
        ))

    with open('iterativedecomposition4.dat', 'rb') as f:
        model.load_state_dict(torch.load(f, map_location=lambda storage, loc: storage))

    print('Total parameters', count_parameters(model))
    print('Encoder parameters', count_parameters(model.encoder))
    print('Decoder parameters', count_parameters(model.resonance))

    return model


def scatterplot_section(logger: Logger) -> ScatterPlotComponent:
    model = load_model()

    ai = AudioIterator(
        batch_size=4,
        n_samples=n_samples,
        samplerate=22050,
        normalize=True,
        as_torch=True)
    batch = next(iter(ai))
    batch = batch.view(-1, 1, n_samples).to('cpu')
    events, vectors, times = model.iterative(batch)

    total_seconds = n_samples / samplerate

    points, times, colors = process_events(vectors, times, total_seconds)

    events = events.view(-1, n_samples)
    events = {f'event{i}': events[i: i + 1, :] for i in range(events.shape[0])}

    scatterplot_srcs = []
    event_components = {}

    for k, v in events.items():
        _, e = logger.log_sound(k, v)
        scatterplot_srcs.append(e.public_uri)
        event_components[k] = AudioComponent(e.public_uri, height=35, controls=False)

    scatterplot_component = ScatterPlotComponent(
        scatterplot_srcs,
        width=500,
        height=500,
        radius=0.3,
        points=points,
        times=times,
        colors=colors,
    )

    return scatterplot_component


def generate_multiple_events(
        model: nn.Module,
        vectors: torch.Tensor,
        times: torch.Tensor) -> torch.Tensor:

    generation_result = torch.cat(
        [model.generate(vectors[:, i:i + 1, :], times[:, i:i + 1, :]) for i in range(n_events)], dim=1)
    generation_result = torch.sum(generation_result, dim=1, keepdim=True)
    generation_result = max_norm(generation_result)
    return generation_result


def generate(
        model: nn.Module,
        vectors: torch.Tensor,
        times: torch.Tensor,
        randomize_events: bool,
        randomize_times: bool) -> torch.Tensor:

    batch, n_events, _ = vectors.shape

    if randomize_events:
        vectors = torch.zeros_like(vectors).uniform_(vectors.min().item(), vectors.max().item())

    if randomize_times:
        times = torch.zeros_like(times).uniform_(-1, 1)
        times = sparse_softmax(times, dim=-1, normalize=True) * times

    generation_result = generate_multiple_events(model, vectors, times)
    return generation_result


def reconstruction_section(logger: Logger) -> CompositeComponent:
    model = load_model()

    # get a random audio segment
    samples = get_one_audio_segment(n_samples, samplerate, device='cpu').view(1, 1, n_samples)

    events, vectors, times = model.iterative(samples)

    # generate audio with the same times, but randomized event vectors
    randomized_events = generate(model, vectors, times, randomize_events=True, randomize_times=False)
    _, random_events = logger.log_sound('randomizedevents', randomized_events)
    random_events_component = AudioComponent(random_events.public_uri, height=100, controls=True)

    # generate audio with the same events, but randomized times
    randomized_times = generate(model, vectors, times, randomize_events=False, randomize_times=True)
    _, random_times = logger.log_sound('randomizedtimes', randomized_times)
    random_times_component = AudioComponent(random_times.public_uri, height=100, controls=True)

    total_seconds = n_samples / samplerate

    points, times, colors = process_events(vectors, times, total_seconds)

    # sum together all events
    summed = torch.sum(events, dim=1, keepdim=True)

    _, original = logger.log_sound(f'original', samples)
    _, reconstruction = logger.log_sound(f'reconstruction', summed)

    orig_audio_component = AudioComponent(original.public_uri, height=100)
    recon_audio_component = AudioComponent(reconstruction.public_uri, height=100)

    events = {f'event{i}': events[:, i: i + 1, :] for i in range(events.shape[1])}

    scatterplot_srcs = []
    event_components = {}

    for k, v in events.items():
        _, e = logger.log_sound(k, v)
        scatterplot_srcs.append(e.public_uri)
        event_components[k] = AudioComponent(e.public_uri, height=25, controls=False)

    scatterplot_component = ScatterPlotComponent(
        scatterplot_srcs,
        width=300,
        height=300,
        radius=0.04,
        points=points,
        times=times,
        colors=colors,
    )

    _, event_vectors = logger.log_matrix_with_cmap('latents', vectors[0].T, cmap='hot')
    latents = ImageComponent(event_vectors.public_uri, height=200, title='latent event vectors')

    composite = CompositeComponent(
        orig_audio=orig_audio_component,
        recon_audio=recon_audio_component,
        latents=latents,
        scatterplot=scatterplot_component,
        random_events=random_events_component,
        random_times=random_times_component,
        **event_components
    )
    return composite
```

Notes

This blog post is generated from a Python script using conjure.

[^1]: While the STFT (short-time fourier transform) doesn't capture everything of perceptual import, it does a fairly good job, better than the "raw", time-domain audio signal, at least. In the time domain, we get into trouble when we begin to try to represent and remove the noisier parts of the signal; here the statistics and relationships between different auditory bandpass filters become more important than the precise amplitude values.
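
A small illustration of this point, assuming two independent bursts of white noise: they sound essentially identical, yet a raw sample-wise error treats them as completely different, while the error between their time-averaged STFT magnitudes is orders of magnitude smaller.

```python
import torch

a = torch.randn(2 ** 16)
b = torch.randn(2 ** 16)


def avg_spectrum(x: torch.Tensor) -> torch.Tensor:
    # average STFT magnitude per frequency bin, a crude summary of the noise's statistics
    mags = torch.stft(x, n_fft=2048, hop_length=256, return_complex=True).abs()
    return mags.mean(dim=-1)


time_error = (a - b).pow(2).mean() / a.pow(2).mean()
spectral_error = (avg_spectrum(a) - avg_spectrum(b)).pow(2).mean() / avg_spectrum(a).pow(2).mean()
print(time_error, spectral_error)
```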

```python
def demo_page_dict() -> Dict[str, any]:
    print(f'Generating article...')

    remote = S3Collection(
        remote_collection_name, is_public=True, cors_enabled=True)
    logger = Logger(remote)

    print('Creating large scatterplot')
    large_scatterplot = scatterplot_section(logger)

    print('Creating reconstruction examples')
    example_1 = reconstruction_section(logger)
    example_2 = reconstruction_section(logger)
    example_3 = reconstruction_section(logger)
    example_4 = reconstruction_section(logger)
    example_5 = reconstruction_section(logger)

    citation = CitationComponent(
        tag='johnvinyarditerativedecompositionv3',
        author='Vinyard, John',
        url='https://blog.cochlea.xyz/iterative-decomposition-v3.html',
        header='Iterative Decomposition V3',
        year='2024',
    )

    return dict(
        large_scatterplot=large_scatterplot,
        example_1=example_1,
        example_2=example_2,
        example_3=example_3,
        example_4=example_4,
        example_5=example_5,
        citation=citation
    )


def generate_demo_page():
    display = demo_page_dict()
    conjure_article(
        __file__,
        'html',
        title='Iterative Decomposition Model V3',
        **display)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--clear', action='store_true')
    parser.add_argument('--list', action='store_true')
    args = parser.parse_args()

    if args.list:
        remote = S3Collection(
            remote_collection_name, is_public=True, cors_enabled=True)
        print(remote)
        print('Listing stored keys')
        for key in remote.iter_prefix(start_key=b'', prefix=b''):
            print(key)

    if args.clear:
        remote = S3Collection(
            remote_collection_name, is_public=True, cors_enabled=True)
        remote.destroy(prefix=b'')

    generate_demo_page()
```