Learning "Playable" State-Space Models from Audio
Table of Contents
- [The Experiment](#the-experiment)
This work attempts to reproduce a short segment of "natural" (i.e., produced by acoustic
instruments or physical objects in the world) audio by decomposing it into two distinct pieces:
- a single-layer RNN simulating the resonances of the system
- a sparse control signal representing energy injected into the system
The control signal can be thought of as roughly corresponding to a musical score, and the RNN
can be thought of as the dynamics/resonances of the musical instrument and the room in which it was played.
It's notable that in this experiment (unlike
my other recent work), there is no learned "encoder". We simply "overfit"
parameters to a single audio segment by minimizing a combination of reconstruction and sparsity losses.
As a sneak peek, here's a novel sound created by feeding a random, sparse control signal into
a state-space model "extracted" from an audio segment.
Feel free to jump ahead if you're curious to hear all the audio examples first!
First, we'll set up the high-level parameters for the experiment.
# the size, in samples, of the audio segment we'll overfit
n_samples = 2 ** 18
# the samplerate, in Hz, of the audio signal
samplerate = 22050
# derived, the total number of seconds of audio
n_seconds = n_samples / samplerate
# the size of each half-lapped audio "frame"
window_size = 128
# the dimensionality of the control plane or control signal
control_plane_dim = 64
# the dimensionality of the state vector, or hidden state for the RNN
state_dim = 128
# the number of (batch, control_plane_dim, frames) elements allowed to be non-zero
n_active_sites = 256
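For a sense of scale, here's a quick sketch of the quantities these settings imply (using the non-overlapping frame count that the model below works with); this is just arithmetic, not part of the model code.

```python
# quick sanity check of the derived quantities (not part of the model itself)
n_frames = n_samples // window_size                         # 262144 // 128 = 2048 frames
total_control_elements = control_plane_dim * n_frames       # 64 * 2048 = 131072
active_fraction = n_active_sites / total_control_elements   # 256 / 131072 ≈ 0.2%
print(f'{n_seconds:.2f} seconds, {n_frames} frames, '
      f'{active_fraction:.2%} of control-plane elements may be non-zero')
```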
The Experiment
We'll build a PyTorch model that will learn a system's resonances, along with a sparse control
signal, by "overfitting" the model to a single segment of ~12 seconds of audio drawn
from my favorite source for acoustic musical signals, the
MusicNet dataset.
Even though we're overfitting a single audio signal, the imposition of sparsity forces the model to generalize in some ways.
Our working theory is that requiring the control signal to be sparse
places certain constraints on the kinds of matrices the model must learn in order to accurately reproduce the audio. If I
strike a piano key, the sound does not die away immediately and I do not have to continue to "drive" the sound by
continually injecting energy; the strings and the body of the piano continue to resonate for quite some time.
In this experiment, we impose sparsity as an l0-style constraint, keeping only a fixed number of non-zero elements in the
control signal, with a straight-through estimator so that the operation remains roughly differentiable.
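Concretely, the hard selection is done by the `sparsify` helper imported below. As a rough illustration only, not the project's actual implementation, a top-k operator with a straight-through estimator might look something like this:

```python
import torch

def topk_sparsify_ste(x: torch.Tensor, n_to_keep: int) -> torch.Tensor:
    # keep only the n_to_keep largest-magnitude elements per batch item
    flat = x.view(x.shape[0], -1)
    values, indices = torch.topk(flat.abs(), k=n_to_keep, dim=-1)
    mask = torch.zeros_like(flat)
    mask.scatter_(-1, indices, 1.0)
    sparse = (flat * mask).view(*x.shape)
    # straight-through estimator: the forward pass uses the hard-thresholded
    # values, while the backward pass lets gradients flow as if no
    # thresholding had happened
    return x + (sparse - x).detach()
```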
While it hasn't shown up in the code we've seen so far, we'll be using
conjure
to monitor the training process while iterating on the code, and
eventually to generate this article once things have settled.
We'll start with some boring imports.
from typing import Dict, Tuple, Union
import numpy as np
import torch
from torch import nn
from itertools import count
from data import get_one_audio_segment
from modules import max_norm, flattened_multiband_spectrogram, sparsify
from torch.optim import Adam
from util import device, encode_audio, make_initializer
from conjure import logger, LmdbCollection, serve_conjure, SupportedContentType, loggers, \
NumpySerializer, NumpyDeserializer, S3Collection, \
conjure_article, CitationComponent, AudioComponent, ImageComponent, \
CompositeComponent, Logger, MetaData, InstrumentComponent
from torch.nn.utils.clip_grad import clip_grad_value_
from argparse import ArgumentParser
from modules.infoloss import CorrelationLoss
from base64 import b64encode
from sklearn.decomposition import PCA
remote_collection_name = 'state-space-model-demo-2'
The InstrumentModel Class
Now, for the good stuff! We'll define our simple State-Space Model as an
nn.Module
-derived class with four parameters
corresponding to each of the four matrices.
Note that there is a slight deviation from the canonical SSM in that we have a fifth matrix, which projects from our
"control plane" for the instrument into the dimension of a single audio frame.
init_weights = make_initializer(0.05)
class InstrumentModel(nn.Module):
"""
A state-space model-like module, with one additional matrix, used to project the control
signal into the shape of each audio frame.
The final output is produced by overlap-adding the windows/frames of audio into a single
1D signal.
"""
def __init__(self, control_plane_dim: int, input_dim: int, state_matrix_dim: int):
super().__init__()
self.state_matrix_dim = state_matrix_dim
self.input_dim = input_dim
self.control_plane_dim = control_plane_dim
self.net = nn.RNN(
input_size=input_dim,
hidden_size=state_matrix_dim,
num_layers=1,
nonlinearity='tanh',
bias=False,
batch_first=True)
print(self.net)
# matrix mapping control signal to audio frame dimension
self.proj = nn.Parameter(
torch.zeros(control_plane_dim, input_dim).uniform_(-0.01, 0.01)
)
self.out_proj = nn.Linear(state_matrix_dim, input_dim, bias=False)
self.apply(init_weights)
def forward(self, control: torch.Tensor) -> torch.Tensor:
"""
(batch, control_plane_dim, frames) -> (batch, 1, frames * window_size)
"""
batch, cpd, frames = control.shape
assert cpd == self.control_plane_dim
control = control.permute(0, 2, 1)
# try to ensure that the input signal only includes low-frequency info
proj = control @ self.proj
# proj = F.interpolate(proj, size=self.input_dim, mode='linear')
# proj = proj * torch.zeros_like(proj).uniform_(-1, 1)
# proj = proj * torch.hann_window(self.input_dim, device=proj.device)
assert proj.shape == (batch, frames, self.input_dim)
result, hidden = self.net.forward(proj)
result = self.out_proj(result)
result = result.view(batch, 1, -1)
result = torch.sin(result)
return result
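Before moving on, a quick shape check makes the data flow concrete. This is purely illustrative and assumes the hyperparameters defined at the top of the article:

```python
model = InstrumentModel(
    control_plane_dim=control_plane_dim,
    input_dim=window_size,
    state_matrix_dim=state_dim)

# a random, very sparse control signal: (batch, control_plane_dim, frames)
n_frames = n_samples // window_size
control = torch.zeros(1, control_plane_dim, n_frames).bernoulli_(p=0.001)

audio = model.forward(control)
print(audio.shape)  # torch.Size([1, 1, 262144]), i.e. (batch, 1, n_frames * window_size)
```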
The OverfitControlPlane Class
This model encapsulates an InstrumentModel
instance, and also has a parameter for the sparse "control plane" that will serve
as the input energy for our resonant model. I think of this as a time-series of vectors that describe the different
ways that energy can be injected into the model, e.g., you might have individual dimensions representing different
keys on a piano, or strings on a cello.
I don't expect the control signals learned here to be quite that clear-cut
and interpretable, but you might notice that the random audio samples produced using the learned models
do seem to disentangle some characteristics of the instruments being played!
# thanks to https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class OverfitControlPlane(nn.Module):
"""
Encapsulates parameters for control signal and state-space model
"""
def __init__(
self,
control_plane_dim: int,
input_dim: int,
state_matrix_dim: int,
n_samples: int):
super().__init__()
self.ssm = InstrumentModel(control_plane_dim, input_dim, state_matrix_dim)
self.n_samples = n_samples
self.n_frames = n_samples // input_dim
self.control_plane_dim = control_plane_dim
self.control = nn.Parameter(
torch.zeros(1, control_plane_dim, self.n_frames).uniform_(0, 1))
def _get_mapping(self, n_components: int) -> np.ndarray:
# cs = self.control_signal.data.cpu().numpy() \
# .reshape(self.control_plane_dim, self.n_frames).T
# pca = PCA(n_components=n_components)
# pca.fit(cs)
# # this will be of shape (n_components, control_plane_dim)
# return pca.components_
return np.random.uniform(-1, 1, (n_components, self.control_plane_dim))
def get_control_plane_mapping(self) -> np.ndarray:
mapping = self._get_mapping(n_components=2)
print(mapping.shape)
rnd = np.random.uniform(0, 1, (27, 2))
cp = rnd @ mapping
print(cp.shape)
print(cp)
print(cp.min(), self.control_signal.min())
print(cp.max(), self.control_signal.max())
return mapping
def get_accelerometer_mapping(self) -> np.ndarray:
return self._get_mapping(n_components=3)
@property
def control_signal_display(self) -> np.ndarray:
return self.control_signal.data.cpu().numpy().reshape((-1, self.n_frames))
@property
def control_signal(self) -> torch.Tensor:
s = sparsify(self.control, n_to_keep=n_active_sites)
return torch.relu(s)
# TODO: This should depend on the time-dimension alone
def random(self, p=0.0001):
"""
Produces a random, sparse control signal, emulating short, transient bursts
of energy into the system modelled by the `SSM`
"""
cp = torch.zeros_like(self.control, device=self.control.device).bernoulli_(p=p)
audio = self.forward(sig=cp)
return max_norm(audio)
def rolled_control_plane(self):
"""
Randomly permute the input control signal, so that the overall pattern of energy injection
is somewhat realistic, but energy is injected differently than in the original performance.
"""
indices = torch.randperm(control_plane_dim)
cp = self.control_signal[:, indices, :]
audio = self.forward(sig=cp)
return max_norm(audio)
def forward(self, sig=None):
"""
Inject energy defined by `sig` (or by the `control` parameters encapsulated by this class)
into the system modelled by `SSM`
"""
return self.ssm.forward(sig if sig is not None else self.control_signal)
def generate_param_dict(
key: str,
model: OverfitControlPlane,
logger: Logger) -> Tuple[dict, MetaData]:
serializer = NumpySerializer()
params = dict()
# note, I'm transposing here to avoid the messiness of dealing with the transpose in Javascript, for now
params['in_projection'] = b64encode(serializer.to_bytes(model.ssm.proj.data.cpu().numpy().T)).decode()
params['out_projection'] = b64encode(
serializer.to_bytes(model.ssm.out_proj.weight.data.cpu().numpy().T)).decode()
named_params = dict(model.ssm.net.named_parameters())
params['rnn_in_projection'] = b64encode(
serializer.to_bytes(named_params['weight_ih_l0'].data.cpu().numpy().T)).decode()
params['rnn_out_projection'] = b64encode(
serializer.to_bytes(named_params['weight_hh_l0'].data.cpu().numpy().T)).decode()
params['control_plane_mapping'] = b64encode(serializer.to_bytes(model.get_control_plane_mapping().T)).decode()
params['accelerometer_mapping'] = b64encode(serializer.to_bytes(model.get_accelerometer_mapping().T)).decode()
_, meta = logger.log_json(key, params)
return params, meta
The Training Process
To train the OverfitControlPlane model, we randomly initialize parameters for the InstrumentModel and the learned
control signal, and minimize a reconstruction loss via gradient descent. For this experiment, we're using the Adam
optimizer with a learning rate of 1e-3.
Reconstruction Loss
The first loss term is a simple reconstruction loss, consisting of the l1 norm of the difference between
two multi-samplerate and multi-resolution spectrograms.
def transform(x: torch.Tensor):
"""
Decompose audio into sub-bands of varying sample rate, and compute spectrogram with
varying time-frequency tradeoffs on each band.
"""
return flattened_multiband_spectrogram(
x,
stft_spec={
'long': (128, 64),
'short': (64, 32),
'xs': (16, 8),
},
smallest_band_size=512)
def reconstruction_loss(recon: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute the l1 norm of the difference between the `recon` and `target`
representations
"""
fake_spec = transform(recon)
real_spec = transform(target)
return torch.abs(fake_spec - real_spec).sum()
Imposed Sparsity
Ideally, we want the model to resonate, or store and "leak" energy slowly in the way that an
acoustic instrument might. This means that the control signal is not dense and continually "driving" the instrument,
but injecting energy infrequently in ways that encourage the natural resonances of the physical object.
I'm not fully satisfied with this approach; for example, it tends to pull away from what might be a nice,
natural control signal for a violin or other bowed instrument. In my mind, that signal might look like a sub-20 Hz sawtooth
wave that "drives" the string, continually catching and releasing as the bow drags across it.
Instead of a sparsity loss via l1 or l2 norm, we only keep the top 256 active elements in the control plane.
For now, the imposition of sparsity does seem to encourage models that resonate, but my intuition is that
there is a better, more nuanced approach that could handle bowed string instruments and wind instruments,
in addition to percussive instruments, where this approach really seems to shine.
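To make the bowed-string intuition concrete, here's a rough sketch (purely illustrative, not used anywhere in this experiment) of the kind of dense, sub-20 Hz sawtooth "driving" channel that the hard top-k constraint currently discourages:

```python
# a hypothetical ~6 Hz sawtooth "bowing" envelope at the control-plane frame rate;
# a dense channel like this would mostly be zeroed out by the top-k sparsifier
frame_rate = samplerate / window_size            # ≈ 172 frames per second
frames = torch.arange(n_samples // window_size)  # 2048 frames
bow_hz = 6.0
phase = frames / frame_rate * bow_hz
sawtooth = phase - torch.floor(phase)            # ramps from 0 to 1, then resets
```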
def sparsity_loss(c: torch.Tensor) -> torch.Tensor:
"""
Compute the l1 norm of the control signal
"""
return torch.abs(c).sum() * 1e-2
def to_numpy(x: torch.Tensor):
return x.data.cpu().numpy()
def construct_experiment_model(state_dict: Union[None, dict] = None) -> OverfitControlPlane:
"""
Construct a randomly initialized `OverfitControlPlane` instance, ready for training/overfitting
"""
model = OverfitControlPlane(
control_plane_dim=control_plane_dim,
input_dim=window_size,
state_matrix_dim=state_dim,
n_samples=n_samples
)
model = model.to(device)
if state_dict is not None:
model.load_state_dict(state_dict)
return model
Examples
Finally, some trained models to listen to! Each example consists of the following:
- the original audio signal from the MusicNet dataset
- the sparse control signal for the reconstruction
- the reconstructed audio, produced using the sparse control signal and the learned state-space model
- a novel, random audio signal produced using the learned state-space model and a random control signal
- a variation produced by randomly permuting the learned control signal along the control-plane dimension
- an interactive "instrument", driven by the learned model, that you can play by clicking
Example 1
Original Audio
A random 11.89-second segment of audio drawn from the MusicNet dataset
Reconstruction
Reconstruction of the original audio after overfitting the model for 5000 iterations
Random Audio
Signal produced by a random, sparse control signal after overfitting the model for 5000 iterations
Random Permutation Along Control Plane Dimension
Control Signal
Sparse control signal for the original audio after overfitting the model for 5000 iterations
Interactive Instrument
We project the 2D coordinates of the click into the 64-dimensional control plane and trigger an event
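For a rough idea of how the interactive instrument turns a click into a control vector: the mapping here is the random 2-by-64 matrix produced by `_get_mapping` above, and the click coordinates below are made up for illustration.

```python
import numpy as np

# a (2, control_plane_dim) mapping, as returned by _get_mapping(n_components=2)
mapping = np.random.uniform(-1, 1, (2, control_plane_dim))

click_xy = np.array([[0.25, 0.8]])   # hypothetical normalized click coordinates
control_vector = click_xy @ mapping  # shape (1, 64): a control-plane vector used to trigger an event
```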
Example 2
Original Audio
A random 11.89-second segment of audio drawn from the MusicNet dataset
Reconstruction
Reconstruction of the original audio after overfitting the model for 5000 iterations
Random Audio
Signal produced by a random, sparse control signal after overfitting the model for 5000 iterations
Random Permutation Along Control Plane Dimension
Control Signal
Sparse control signal for the original audio after overfitting the model for 5000 iterations
Interactive Instrument
We project the 2D coordinates of the click into the 64-dimensional control plane and trigger an event
Example 3
Original Audio
A random 11.89-second segment of audio drawn from the MusicNet dataset
Reconstruction
Reconstruction of the original audio after overfitting the model for 5000 iterations
Random Audio
Signal produced by a random, sparse control signal after overfitting the model for 5000 iterations
Random Permutation Along Control Plane Dimension
Control Signal
Sparse control signal for the original audio after overfitting the model for 5000 iterations
Interactive Instrument
We project the 2D coordinates of the click into the 64-dimensional control plane and trigger an event
Example 4
Original Audio
A random 11.89-second segment of audio drawn from the MusicNet dataset
Reconstruction
Reconstruction of the original audio after overfitting the model for 5000 iterations
Random Audio
Signal produced by a random, sparse control signal after overfitting the model for 5000 iterations
Random Permutation Along Control Plane Dimension
Control Signal
Sparse control signal for the original audio after overfitting the model for 5000 iterations
Interactive Instrument
We project the 2D coordinates of the click into the 64-dimensional control plane and trigger an event
Example 5
Original Audio
A random 11.89-second segment of audio drawn from the MusicNet dataset
Reconstruction
Reconstruction of the original audio after overfitting the model for 5000 iterations
Random Audio
Signal produced by a random, sparse control signal after overfitting the model for 5000 iterations
Random Permutation Along Control Plane Dimension
Control Signal
Sparse control signal for the original audio after overfitting the model for 5000 iterations
Interactive Instrument
We project the 2D coordinates of the click into the 64-dimensional control plane and trigger an event
Code For Generating this Article
What follows is the code used to train the model and produce the article you're reading. It uses
the conjure
Python library, a tool I've been writing
that helps to persist and display images, audio and other code artifacts that are interleaved throughout
this post.
def demo_page_dict(n_iterations: int = 100) -> Dict[str, any]:
print(f'Generating article, training models for {n_iterations} iterations')
remote = S3Collection(
remote_collection_name, is_public=True, cors_enabled=True)
def train_model_for_segment(
target: torch.Tensor,
iterations: int):
loss_model = CorrelationLoss().to(device)
while True:
model = construct_experiment_model()
optim = Adam(model.parameters(), lr=1e-3)
for iteration in range(iterations):
optim.zero_grad()
recon = model.forward()
loss = \
loss_model.multiband_noise_loss(target, recon, window_size=32, step=16) \
+ reconstruction_loss(recon, target)
if torch.isnan(loss).any():
print(f'detected NaN at iteration {iteration}')
break
loss.backward()
clip_grad_value_(model.parameters(), 0.5)
optim.step()
print(iteration, loss.item())
if iteration < iterations - 1:
print('NaN detected, starting anew')
continue
# total SSM parameters
model_param_count = count_parameters(model.ssm)
# non-zero control plane parameters
non_zero = torch.sum(model.control_signal > 0)
total_params = model_param_count + non_zero
compression_ratio = total_params / n_samples
print('COMPRESSION RATIO', compression_ratio * 100)
break
return model.state_dict()
def encode(arr: np.ndarray) -> bytes:
return encode_audio(arr)
conj_logger = Logger(remote)
# define loggers
audio_logger = logger(
'audio', 'audio/wav', encode, remote)
def train_model_for_segment_and_produce_artifacts(
key: str, n_iterations: int):
audio_tensor = get_one_audio_segment(n_samples).view(1, 1, n_samples)
audio_tensor = max_norm(audio_tensor)
state_dict = train_model_for_segment(audio_tensor, n_iterations)
hydrated = construct_experiment_model(state_dict)
with torch.no_grad():
recon = hydrated.forward()
random = hydrated.random()
rolled = hydrated.rolled_control_plane()
_, orig_audio = conj_logger.log_sound('orig', audio_tensor)
_, recon_audio = audio_logger.result_and_meta(recon)
_, random_audio = audio_logger.result_and_meta(random)
_, rolled_audio = audio_logger.result_and_meta(rolled)
_, control_plane = conj_logger.log_matrix_with_cmap('controlplane', hydrated.control_signal[0], cmap='hot')
params, param_meta = generate_param_dict(key, hydrated, conj_logger)
result = dict(
orig=orig_audio,
recon=recon_audio,
control_plane=control_plane,
random=random_audio,
rolled=rolled_audio,
params=param_meta
)
return result
def train_model_and_produce_components(key: str, n_iterations: int):
"""
Produce artifacts/media for a single example section
"""
result_dict = train_model_for_segment_and_produce_artifacts(key, n_iterations)
orig = AudioComponent(result_dict['orig'].public_uri, height=200, samples=512)
recon = AudioComponent(result_dict['recon'].public_uri, height=200, samples=512)
random = AudioComponent(result_dict['random'].public_uri, height=200, samples=512)
rolled = AudioComponent(result_dict['rolled'].public_uri, height=200, samples=512)
control = ImageComponent(result_dict['control_plane'].public_uri, height=200)
instr = InstrumentComponent(result_dict['params'].public_uri)
return dict(
orig=orig,
recon=recon,
control=control,
random=random,
rolled=rolled,
instr=instr
)
def train_model_and_produce_content_section(
n_iterations: int, number: int) -> CompositeComponent:
"""
Produce a single "Examples" section for the post
"""
component_dict = train_model_and_produce_components(f'rnnweights{number}', n_iterations)
composite = CompositeComponent(
header=f'## Example {number}',
orig_header='### Original Audio',
orig_text=f'A random {n_seconds:.2f}-second segment of audio drawn from the MusicNet dataset',
orig_component=component_dict['orig'],
recon_header='### Reconstruction',
recon_text=f'Reconstruction of the original audio after overfitting the model for {n_iterations} iterations',
recon_component=component_dict['recon'],
random_header='### Random Audio',
random_text=f'Signal produced by a random, sparse control signal after overfitting the model for {n_iterations} iterations',
random_component=component_dict['random'],
rolled_header='### Random Permutation Along Control Plane Dimension',
rolled_text=f'',
rolled_component=component_dict['rolled'],
control_header='### Control Signal',
control_text=f'Sparse control signal for the original audio after overfitting the model for {n_iterations} iterations',
control_component=component_dict['control'],
instr_header='### Interactive Instrument',
instr_text='We project the 2D coordinates of the click into the 64-dimensional control plane and trigger an event',
instr_component=component_dict['instr'],
)
return composite
example_1 = train_model_and_produce_content_section(
n_iterations=n_iterations,
number=1
)
example_2 = train_model_and_produce_content_section(
n_iterations=n_iterations,
number=2
)
example_3 = train_model_and_produce_content_section(
n_iterations=n_iterations,
number=3
)
example_4 = train_model_and_produce_content_section(
n_iterations=n_iterations,
number=4
)
example_5 = train_model_and_produce_content_section(
n_iterations=n_iterations,
number=5
)
citation = CitationComponent(
tag='johnvinyardstatespacemodels',
author='Vinyard, John',
url='https://blog.cochlea.xyz/ssm.html',
header='RNN Resonance Modelling for Sparse Decomposition of Audio',
year='2024'
)
return dict(
example_1=example_1,
example_2=example_2,
example_3=example_3,
example_4=example_4,
example_5=example_5,
citation=citation,
)
def generate_demo_page(iterations: int = 500):
display = demo_page_dict(n_iterations=iterations)
conjure_article(
__file__,
'html',
title='Learning "Playable" State-Space Models from Audio',
**display)
Training Code
As I developed this model, I used the following code to pick a random audio segment, overfit a model, and monitor
its progress during training.
def train_and_monitor():
target = get_one_audio_segment(n_samples=n_samples, samplerate=samplerate)
collection = LmdbCollection(path='ssm')
recon_audio, orig_audio, random_audio = loggers(
['recon', 'orig', 'random'],
'audio/wav',
encode_audio,
collection)
envelopes, = loggers(
['envelopes'],
SupportedContentType.Spectrogram.value,
to_numpy,
collection,
serializer=NumpySerializer(),
deserializer=NumpyDeserializer())
orig_audio(target)
serve_conjure([
orig_audio,
recon_audio,
envelopes,
random_audio,
], port=9999, n_workers=1)
loss_model = CorrelationLoss().to(device)
def train(target: torch.Tensor):
model = construct_experiment_model()
optim = Adam(model.parameters(), lr=1e-3)
for iteration in count():
optim.zero_grad()
recon = model.forward()
recon_audio(max_norm(recon))
loss = \
loss_model.multiband_noise_loss(target, recon, window_size=32, step=16) \
+ reconstruction_loss(recon, target)
envelopes(model.control_signal.view(control_plane_dim, -1))
loss.backward()
clip_grad_value_(model.parameters(), 0.5)
optim.step()
print(iteration, loss.item())
with torch.no_grad():
rnd = model.random()
random_audio(rnd)
train(target)
Conclusion
Thanks for reading this far!
I'm excited about the results of this experiment, but am not totally pleased
with the frame-based approach, which leads to very noticeable artifacts in the reconstructions. It runs counter to
one of my guiding principles as I try to design a sparse, interpretable, and easy-to-manipulate audio codec, which
is that there is no place for arbitrary, fixed-size "frames". Ideally, we represent audio as a sparse set of events
or sound sources that are sample-rate independent, i.e., more like a function or operator, and less like a rasterized
representation.
I'm just beginning to learn more about state-space models and was excited when I learned from Albert Gu in his
excellent talk "Efficiently Modeling Long Sequences with Structured State Spaces"
that there are ways to transform state-space models, which strongly resemble
IIR filters, into their
FIR counterpart, convolutions, which I've depended on
heavily to model resonance in other recent work.
I'm looking forward to following this thread and beginning to find where the two different approaches converge!
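To make that equivalence concrete: for a purely linear SSM (dropping the tanh used above), unrolling the recurrence shows that the output is the input convolved with a kernel whose k-th tap is C Aᵏ B. Here's a minimal sketch of materializing that (truncated) kernel; it's my own illustration, not anything from Gu's work or from this experiment.

```python
import torch

# hedged sketch: unrolling a purely *linear* SSM (no tanh nonlinearity)
# into the equivalent FIR/convolution kernel K_k = C @ A^k @ B, following
# h_t = A h_{t-1} + B u_t,  y_t = C h_t
def ssm_to_fir_kernel(A: torch.Tensor,   # (state, state)
                      B: torch.Tensor,   # (state, input)
                      C: torch.Tensor,   # (output, state)
                      kernel_length: int) -> torch.Tensor:
    kernels = []
    A_power = torch.eye(A.shape[0])
    for _ in range(kernel_length):
        kernels.append(C @ A_power @ B)  # contribution of an input k steps back
        A_power = A_power @ A
    return torch.stack(kernels)          # (kernel_length, output, input)
```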
Future Work
- Instead of (or in addition to) a sparsity loss, could we build in more physics-informed losses, such as conservation
of energy, i.e., overall energy can never increase unless it comes from the control signal? (A rough sketch of such a penalty follows this list.)
- Could we use scipy.signal.StateSpace to derive a continuous-time formulation of the model?
- How would a model like this work as an event generator in my sparse, interpretable audio model from other
experiments?
- Could we treat an entire, multi-instrument song as a single, large state-space model, learning a compressed
representation of the audio and a "playable" instrument, all at the same time?
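To make the first item above a bit more concrete, here's a rough, hypothetical sketch of what such an energy-conservation penalty might look like; the frame shapes and the specific form of the penalty are my own assumptions, not something used in this experiment.

```python
import torch

# hypothetical sketch of a physics-flavored penalty: frame-to-frame energy
# should not grow by more than the energy injected by the control signal
def energy_conservation_penalty(frames: torch.Tensor,   # (batch, n_frames, window_size)
                                control: torch.Tensor   # (batch, n_frames, control_plane_dim)
                                ) -> torch.Tensor:
    out_energy = (frames ** 2).sum(dim=-1)
    in_energy = (control ** 2).sum(dim=-1)
    # positive only when a frame gains more energy than the control injected
    gain = out_energy[:, 1:] - out_energy[:, :-1] - in_energy[:, 1:]
    return torch.relu(gain).sum()
```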
Cite this Article
If you'd like to cite this article, you can use the following BibTeX block.
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--mode', type=str, required=True)
parser.add_argument('--iterations', type=int, default=250)
parser.add_argument('--prefix', type=str, required=False, default='')
args = parser.parse_args()
if args.mode == 'train':
train_and_monitor()
elif args.mode == 'demo':
generate_demo_page(args.iterations)
elif args.mode == 'list':
remote = S3Collection(
remote_collection_name, is_public=True, cors_enabled=True)
print('Listing stored keys')
for key in remote.iter_prefix(start_key=args.prefix):
print(key)
elif args.mode == 'clear':
remote = S3Collection(
remote_collection_name, is_public=True, cors_enabled=True)
remote.destroy(prefix=args.prefix)
else:
raise ValueError('Please provide one of train, demo, list, or clear')