Extracting Playable Instrument Models from Short Audio Examples
In this micro-experiment, I ask whether it's possible to factor short audio segments into two distinct parts: first, a
control signal, an abstract representation of the actions taken by a human performer, and second, an "instrument",
representing the resonances of the physical object(s) being played, and the room in which the performance occurs.
The weights, or learned parameters, of the instrument artifact can be used to produce new musical performances,
while the two factors together constitute a (lossy) compressive and interpretable representation of the original signal.
To make this more concrete, here's a video of me using a learned instrument to produce a new audio segment via a
hand-tracking interface.
If you're excited to just listen to reconstructions and play with the "instrument" artifacts, feel free to
jump ahead to the audio examples!
Extracting Instruments from Audio Recordings
My primary focus over the last two years has been on finding
sparse representations of audio signals.
One promising approach has been to build models that factor audio signals into two components:
- a control signal, describing the way energy is injected into some system: the work of a human performer.
- the acoustic resonances of the system itself, describing how energy injected via the control signal is stored by the system and emitted over time.
Human musicians aren't engaged in wiggling physical objects many hundreds or thousands of times per second, but they
do learn to skillfully manipulate these resonant properties to create engaging music. This mental model naturally leads to
sparser, score-like representations of musical audio.
I've also begun to investigate whether this approach can extract useful artifacts from very small datasets, and in
this case, even individual audio segments.
The Model
Instead of compiling a massive dataset and worshipping at the altar of the
"Bitter Lesson", I cram this experiment full of inductive
biases, imposing my own rudimentary mental model of a musical instrument as a strong form of regularization.
My hypothesis is that we can learn something useful from even a few seconds of audio by insisting that it fits a
simple model of physics and resonance.
I've attempted a similar experiment before, using a
single-layer recurrent neural network to model the instrument. This yielded some
limited success, but the frame-based approach caused many audible artifacts that I wasn't happy with.
In this new experiment, the control signal is a 16-dimensional time-varying vector running at ~20 Hz. This signal is first
convolved with a small set of "attack envelopes", which are multiplied by uniform (white) noise. These attacks are then
routed to and convolved with some number of resonances, parameterized by a set of
damped harmonic oscillators.
Finally, the resonance output is multiplied by a learnable gain, followed by a non-linearity (tanh) to simulate subtle
distortions.
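To make that chain concrete, here's a minimal, self-contained sketch of the signal flow, with made-up shapes and plain exponentially-damped sinusoids standing in for the learned attack envelopes and oscillator parameters; the full implementation appears in the code at the end of this page.

```python
import torch

def fft_conv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # convolution via the FFT, trimmed back to the length of the first argument
    n = a.shape[-1] + b.shape[-1]
    spec = torch.fft.rfft(a, n=n) * torch.fft.rfft(b, n=n)
    return torch.fft.irfft(spec, n=n)[..., :a.shape[-1]]

n_samples, n_frames = 2 ** 15, 128
control_dim, n_res = 16, 16
step = n_samples // n_frames

# sparse, frame-rate control signal: a single "pluck" on channel 3
control = torch.zeros(control_dim, n_frames)
control[3, 10] = 1.0

# route control-plane energy onto the resonance plane
routing = torch.randn(control_dim, n_res)
routed = routing.T @ control                       # (n_res, n_frames)

# upsample to audio rate by placing each frame's value at its sample position
excitation = torch.zeros(n_res, n_samples)
excitation[:, ::step] = routed

# short decaying attack envelope multiplied by white noise
t = torch.linspace(1, 0, 2048)
attack = (t ** 8) * torch.empty(2048).uniform_(-1, 1)
excitation = fft_conv(excitation, attack[None, :])

# stand-ins for the damped harmonic oscillators: decaying sinusoids
time = torch.arange(n_samples) / 22050.0
freqs = torch.linspace(110.0, 880.0, n_res)[:, None]
resonances = torch.exp(-4 * time) * torch.sin(2 * torch.pi * freqs * time)
wet = fft_conv(excitation, resonances)

# per-resonance gain followed by a gentle tanh non-linearity, summed to mono
gains = torch.full((n_res, 1), 0.5)
audio = torch.tanh(wet * gains).sum(dim=0)         # shape: (n_samples,)
```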
We also model deformations, or changes from the resting state of
the physical object. A deformation in the real-world might be the turn of a tuning peg on a guitar or a change in the
position of a slide on a trombone. Deformations are modeled as a time-varying change in the weights that determine
how energy is routed to different resonances. As a concrete example, the control signal might become non-zero,
modeling a transient pluck of a guitar string. Then the deformation mix might oscillate between two different resonances,
modeling the change in bridge height induced by a whammy/tremolo bar.
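As a rough illustration of this idea (shapes and names here are illustrative, not the exact ones used in the code below), the deformation signal can be seen as a slowly varying set of softmax weights that blend between the alternative versions of each resonance:

```python
import torch

n_res, expressivity, n_samples, def_frames = 16, 2, 2 ** 15, 256

# two variants of each resonance, e.g. two different bridge heights on a guitar
res_variants = torch.randn(n_res, expressivity, n_samples)

# a low-rate deformation signal; softmax over the variant axis yields mix weights
deformation = torch.zeros(expressivity, def_frames)
deformation[1] = torch.linspace(-3, 3, def_frames)    # slowly fade toward variant 2
weights = torch.softmax(deformation, dim=0)           # (expressivity, def_frames)

# upsample the weights to audio rate and blend the resonance variants
weights = torch.nn.functional.interpolate(
    weights[None, :, :], size=n_samples, mode='linear', align_corners=True)[0]
blended = (res_variants * weights[None, :, :]).sum(dim=1)   # (n_res, n_samples)
```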
Importantly, assuming we've chosen a model with sufficient capacity to describe the resonances, the control signal
scales with the length of the audio segment, but the model of the resonances does not; it is constant.
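A quick bit of bookkeeping makes the asymmetry clear (the numbers here are illustrative, matching the ~20 Hz control rate and 16-dimensional control plane used on this page):

```python
control_plane_dim, control_rate_hz = 16, 20

for seconds in (6, 12, 24):
    control_values = control_plane_dim * control_rate_hz * seconds
    print(f'{seconds:>2}s of audio -> {control_values} control values; instrument parameters unchanged')
```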
The Loss
We fit the model using a simple L1 loss on a multi-resolution short-time Fourier transform of the target and reconstruction.
I have a strong intuition that a more perceptually-informed loss would work better, and require even less overall
model capacity, but that's an experiment for another day!
We encourage sparsity in the control signal with an additional L1 loss on control-signal magnitudes, pushing most entries
to zero. The hope is that the model will rely heavily on resonances in the object, rather than on a high-energy, busy control
signal.
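Here's a minimal sketch of this loss, assuming a simple multi-resolution spectrogram helper; the `compute_loss` function further down this page is structured along the same lines.

```python
import torch

def multires_spec(x: torch.Tensor, window_sizes=(512, 1024, 2048)) -> list:
    # magnitude spectrograms at a few different time/frequency resolutions
    return [
        torch.abs(torch.stft(
            x, n_fft=ws, hop_length=ws // 4,
            window=torch.hann_window(ws), return_complex=True))
        for ws in window_sizes]

def loss(target: torch.Tensor, recon: torch.Tensor,
         control: torch.Tensor, sparsity_weight: float = 0.5) -> torch.Tensor:
    spec_loss = sum(
        torch.abs(a - b).sum()
        for a, b in zip(multires_spec(target), multires_spec(recon)))
    # L1 penalty pushing most control-signal entries toward zero
    return spec_loss + sparsity_weight * torch.abs(control).sum()
```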
For the examples on this page, we train the model for 10,000 iterations using the Adam optimizer.
User Interface
Our control signal is defined as a sparse, N-dimensional vector, varying at ~20 Hz. For this work,
I chose a familiar and natural interface: the human hand. We project hand-tracking landmarks onto the control
signal's input space. I should note that MediaPipe
proved indispensable for this first pass, making this part of the implementation straightforward and quick.
I'm certain that the current interface is not ideal, and I'm very excited about alternative possibilities.
We're literally swimming in multi-dimensional, time-varying sensor data, streaming from smartphones, watches,
WiFi signals and much more!
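To sketch how the mapping works (the matrix here is random; in the code below it is learned from the control signal with `DictionaryLearning`), each video frame's 21 MediaPipe landmarks are flattened into a 63-dimensional vector and projected onto the 16-dimensional control plane:

```python
import numpy as np

control_plane_dim, n_landmarks = 16, 21
mapping = np.random.randn(control_plane_dim, n_landmarks * 3)

def landmarks_to_control(landmarks: np.ndarray) -> np.ndarray:
    # landmarks: (21, 3) array of normalized hand coordinates for one video frame
    flat = landmarks.reshape(-1)            # (63,)
    control = mapping @ flat                # (16,)
    return np.clip(control, 0, None)        # control-plane energy is non-negative
```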
After overfitting a single audio segment of ~12 seconds, weights for the model are persisted, and then used by a
WebAudio implementation of the decoder that runs
in the browser.
To ensure that the browser-based implementation can perform in realtime, the model we learn is tiny, with a control-plane
dimension of 16, a resonance count of 16, and an "expressivity" of 2, the latter defining how many alternative resonances
make up our deformation space.
The Results
Reconstruction Quality
The reconstruction quality leaves much to be desired, but there's no doubt that the model captures many key aspects of
the sound. What's more, the "instrument" artifact allows us to play in the resonance space of the original recording,
producing the same tones and timbres in novel arrangements.
Model Size and Compression Ratio
Counting the full model size, including the control signal and deformations, which vary over time, we end up with a
representation about 14% the size of the original WAV audio. Clearly, this compression is lossy, and there's
much room for improvement.
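A back-of-the-envelope version of that ratio (placeholder numbers only; the exact bookkeeping lives in `compression_ratio` in the code below) counts the fixed instrument parameters plus the time-varying values, divided by the number of audio samples:

```python
def approx_compression_ratio(
        instrument_params: int,
        deformation_values: int,
        active_control_values: int,
        n_samples: int) -> float:
    # each non-zero control value needs roughly (channel, frame, amplitude),
    # hence the factor of three
    stored = instrument_params + deformation_values + active_control_values * 3
    return stored / n_samples

# purely illustrative numbers
print(approx_compression_ratio(
    instrument_params=15_000,
    deformation_values=2 * 1024,
    active_control_values=128,
    n_samples=2 ** 17))
```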
Size of Web Audio Parameters
We fully "materialize" both the attack envelopes and resonances to our full sample rate (22050hz, in this case) to avoid
implementing the interpolation and damped-harmonic oscillator code in TypeScript/JavaScript. This increases the size
of the stored model significantly and could be replaced with a one-time upsampling operation prior to decoding in
future versions, to save disk space.
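Here's a sketch of that one-time materialization step, assuming simple linear interpolation: a low-resolution attack envelope is stretched to the full sample rate once, so the browser decoder never has to interpolate.

```python
import torch

# (control_plane_dim, attack_n_frames) -> (control_plane_dim, attack_full_size)
low_res_envelope = torch.rand(16, 256)
full_rate = torch.nn.functional.interpolate(
    low_res_envelope[None, :, :], size=2048, mode='linear', align_corners=True)[0]
print(full_rate.shape)   # torch.Size([16, 2048])
```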
Examples
Example 1
Original Audio
Reconstruction
Sparse Control Signal
Random Audio
Hand-Controlled Instrument
Example 2
Original Audio
Reconstruction
Sparse Control Signal
Random Audio
Hand-Controlled Instrument
Example 3
Original Audio
Reconstruction
Sparse Control Signal
Random Audio
Hand-Controlled Instrument
Thanks for reading!
Future Work
The approach is promising, but there are many questions left to explore:
- How can we create an efficient implementation using JavaScript and the WebAudio API, such that we can use larger
control planes, more resonances, and more expressivity/deform-ability?
- The control-plane representation is overly simplistic. Are there better models?
- This work seems to be driving toward physical modelling synthesis. Can we just create a differentiable physics simulation?
- What is a good, natural, intuitive set of sensor data, readily available from smartphones or some other
pervasive technology, that can be mapped onto the control-plane dimensions in natural and fun ways?
Citation
If you'd like to cite this article, you can use the following BibTeX block.
from base64 import b64encode
from typing import Tuple, Callable, Union, Dict, Any
import numpy as np
import torch
from attr import dataclass
from sklearn.decomposition import DictionaryLearning
from torch import nn
from torch.optim import Adam
import conjure
from conjure import serve_conjure, SupportedContentType, NumpyDeserializer, NumpySerializer, Logger, MetaData, \
CompositeComponent, AudioComponent, ConvInstrumentComponent, conjure_article, CitationComponent, S3Collection, \
VideoComponent, ImageComponent
from data import get_one_audio_segment
from modules import max_norm, interpolate_last_axis, sparsify, unit_norm, stft
from modules.transfer import fft_convolve
from modules.upsample import upsample_with_holes, ensure_last_axis_length
from util import device, encode_audio, make_initializer
import argparse
LossFunc = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
MaterializeResonances = Callable[..., torch.Tensor]
init_weights = make_initializer(0.02)
# TODO: How to package up and share these params
n_samples = 2 ** 17
resonance_window_size = 256
step_size = resonance_window_size // 2
n_frames = n_samples // step_size
# KLUDGE: control_plane_dim and n_resonances
# must have the same value (for now)
control_plane_dim = 16
n_resonances = 16
expressivity = 2
n_to_keep = 128
do_sparsify = False
sparsity_coefficient = 0.5
n_oscillators = 64
attack_full_size = 2048
attack_n_frames = 256
deformation_frame_size = 128
web_components_version = '0.0.101'
use_learned_deformations_for_random = False
# TODO: Move this into conjure
def encode_array(arr: Union[np.ndarray, torch.Tensor], serializer: NumpySerializer) -> str:
if isinstance(arr, torch.Tensor):
arr = arr.data.cpu().numpy()
return b64encode(serializer.to_bytes(arr)).decode()
def decaying_noise(
n_items: int,
n_samples: int,
low_exp: int,
high_exp: int,
device: torch.device,
include_noise: bool = True):
t = torch.linspace(1, 0, n_samples, device=device)
pos = torch.zeros(n_items, device=device).uniform_(low_exp, high_exp)
if include_noise:
noise = torch.zeros(n_items, n_samples, device=device).uniform_(-1, 1)
return (t[None, :] ** pos[:, None]) * noise
else:
return (t[None, :] ** pos[:, None])
def make_ramp(n_samples: int, ramp_length: int, device: torch.device) -> torch.Tensor:
ramp = torch.ones(n_samples, device=device)
ramp[:ramp_length] = torch.linspace(0, 1, ramp_length, device=device)
return ramp
def materialize_attack_envelopes(
low_res: torch.Tensor,
window_size: int,
is_fft: bool = False,
add_noise: bool = True) -> torch.Tensor:
if low_res.shape[-1] == window_size:
return low_res * torch.zeros_like(low_res).uniform_(-1, 1)
if is_fft:
low_res = torch.view_as_complex(low_res)
low_res = torch.fft.irfft(low_res)
# impulse = fft_resample(low_res[None, ...], desired_size=window_size, is_lowest_band=True)[0]
impulse = interpolate_last_axis(low_res, desired_size=window_size)
if add_noise:
impulse = impulse * torch.zeros_like(impulse).uniform_(-1, 1)
# ramp = make_ramp(impulse.shape[-1], ramp_length=10, device=impulse.device)
# impulse = impulse * ramp
return impulse
def execute_layer(
control_signal: torch.Tensor,
attack_envelopes: torch.Tensor,
mix: torch.Tensor,
routing: torch.Tensor,
res: torch.Tensor,
deformations: torch.Tensor,
gains: torch.Tensor,
window_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch, n_events, control_plane_dim, frames = control_signal.shape
batch, n_events, expressivity, def_frames = deformations.shape
# route control plane to resonance plane at control-rate
routed = (control_signal.permute(0, 1, 3, 2) @ routing).permute(0, 1, 3, 2)
before_upsample = routed
_, _, n_resonances, expressivity, n_samples = res.shape
# ensure that resonances all have unit norm so that control
# plane energy drives overall loudness/response
res = unit_norm(res)
# "route" energy from control plane to resonances
routed = routed.view(batch, n_events, n_resonances, 1, frames)
# convolve with noise impulse
routed = upsample_with_holes(routed, n_samples)
impulse = materialize_attack_envelopes(attack_envelopes, window_size)
impulse = ensure_last_axis_length(impulse, n_samples)
impulse = impulse.view(1, 1, control_plane_dim, 1, n_samples)
routed = fft_convolve(impulse, routed)
cs = routed.view(1, control_plane_dim, n_samples).sum(dim=1, keepdim=True)
# convolve control plane with all resonances
conv = fft_convolve(routed, res)
# interpolate between variations on each resonance
base_deformation = torch.zeros_like(deformations)
base_deformation[:, :, 0:1, :] = 1
d = base_deformation + deformations
d = torch.softmax(d, dim=-2)
d = d.view(batch, n_events, 1, expressivity, def_frames)
d = interpolate_last_axis(d, n_samples)
x = d * conv
x = torch.sum(x, dim=-2)
mixes = mix.view(1, 1, n_resonances, 1, 1, 2)
mixes = torch.softmax(mixes, dim=-1)
stacked = torch.stack([
routed,
x.reshape(*routed.shape)
], dim=-1)
x = mixes * stacked
x = torch.sum(x, dim=-1)
x = x.view(1, 1, n_resonances, -1)
summed = torch.tanh(x * torch.abs(gains.view(1, 1, n_resonances, 1)))
summed = torch.sum(summed, dim=-2, keepdim=True)
return summed, before_upsample, cs
@torch.jit.script
def damped_harmonic_oscillator(
time: torch.Tensor,
mass: torch.Tensor,
damping: torch.Tensor,
tension: torch.Tensor,
initial_displacement: torch.Tensor,
initial_velocity: float,
) -> torch.Tensor:
x = (damping / (2 * mass))
# if torch.isnan(x).sum() > 0:
# print('x first appearance of NaN')
omega = torch.sqrt(torch.clamp(tension - (x ** 2), 1e-12, np.inf))
# if torch.isnan(omega).sum() > 0:
# print('omega first appearance of NaN')
phi = torch.atan2(
(initial_velocity + (x * initial_displacement)),
(initial_displacement * omega)
)
a = initial_displacement / torch.cos(phi)
z = a * torch.exp(-x * time) * torch.cos(omega * time - phi)
return z
class DampedHarmonicOscillatorBlock(nn.Module):
def __init__(
self,
n_samples: int,
n_oscillators: int,
n_resonances: int,
expressivity: int):
super().__init__()
self.n_samples = n_samples
self.n_oscillators = n_oscillators
self.n_resonances = n_resonances
self.expressivity = expressivity
self.damping = nn.Parameter(
torch.zeros(n_oscillators, n_resonances, expressivity) \
.uniform_(0.5, 1.5))
self.mass = nn.Parameter(
torch.zeros(n_oscillators, n_resonances, expressivity) \
.uniform_(-2, 2))
self.tension = nn.Parameter(
torch.zeros(n_oscillators, n_resonances, expressivity) \
.uniform_(4, 9))
self.initial_displacement = nn.Parameter(
torch.zeros(n_oscillators, n_resonances, expressivity) \
.uniform_(-1, 2))
self.amplitudes = nn.Parameter(
torch.zeros(n_oscillators, n_resonances, expressivity, 1) \
.uniform_(-1, 1))
def _materialize_resonances(self, device: torch.device):
time = torch.linspace(0, 10, self.n_samples, device=device) \
.view(1, 1, 1, self.n_samples)
x = damped_harmonic_oscillator(
time=time,
mass=torch.sigmoid(self.mass[..., None]),
damping=torch.sigmoid(self.damping[..., None]) * 30,
tension=10 ** self.tension[..., None],
initial_displacement=self.initial_displacement[..., None],
initial_velocity=0
)
x = x.view(self.n_oscillators, self.n_resonances, self.expressivity, self.n_samples)
x = x * self.amplitudes ** 2
x = torch.sum(x, dim=0)
# ramp = make_ramp(self.n_samples, ramp_length=10, device=x.device)
return x.view(1, 1, self.n_resonances, self.expressivity, self.n_samples) #* ramp[None, None, None, None, :]
def forward(self) -> torch.Tensor:
x = self._materialize_resonances(self.damping.device)
x = unit_norm(x)
return x
class ResonanceLayer(nn.Module):
def __init__(
self,
n_samples: int,
resonance_window_size: int,
control_plane_dim: int,
n_resonances: int,
expressivity: int,
base_resonance: float = 0.5):
super().__init__()
self.expressivity = expressivity
self.n_resonances = n_resonances
self.control_plane_dim = control_plane_dim
self.resonance_window_size = resonance_window_size
self.n_samples = n_samples
self.base_resonance = base_resonance
self.attack_full_size = attack_full_size
resonance_coeffs = resonance_window_size // 2 + 1
self.attack_envelopes = nn.Parameter(
# decaying_noise(self.control_plane_dim, 256, 4, 20, device=device, include_noise=False)
torch.zeros(self.control_plane_dim, attack_n_frames).uniform_(-1, 1)
)
self.router = nn.Parameter(
torch.zeros((self.control_plane_dim, self.n_resonances)).uniform_(-1, 1))
self.resonance = DampedHarmonicOscillatorBlock(
n_samples, n_oscillators, n_resonances, expressivity
)
self.mix = nn.Parameter(torch.zeros(self.n_resonances, 2).uniform_(-1, 1))
self.gains = nn.Parameter(torch.zeros((n_resonances, 1)).uniform_(0.01, 1.1))
def get_mixes(self):
return self.mix
def get_attack_envelopes(self):
# The web audio component adds random noise each time, instead of "baking" a single
# uniform sampling into the attack envelopes
return materialize_attack_envelopes(self.attack_envelopes, self.attack_full_size, add_noise=False)
def get_materialized_resonance(self):
return self.resonance.forward()
def get_gains(self):
return self.gains
def get_router(self):
return self.router
def forward(
self,
control_signal: torch.Tensor,
deformations: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
res = self.resonance.forward()
# print(res.shape)
output, fwd, cs = execute_layer(
control_signal,
self.attack_envelopes,
self.mix,
self.router,
res,
deformations,
self.gains,
self.attack_full_size,
)
return output, fwd, cs
class ResonanceStack(nn.Module):
def __init__(
self,
n_layers: int,
n_samples: int,
resonance_window_size: int,
control_plane_dim: int,
n_resonances: int,
expressivity: int,
base_resonance: float = 0.5):
super().__init__()
self.expressivity = expressivity
self.n_resonances = n_resonances
self.control_plane_dim = control_plane_dim
self.resonance_window_size = resonance_window_size
self.n_samples = n_samples
self.base_resonance = base_resonance
self.mix = nn.Parameter(torch.zeros(n_layers))
self.layers = nn.ModuleList([ResonanceLayer(
n_samples,
resonance_window_size,
control_plane_dim,
n_resonances,
expressivity,
base_resonance
) for _ in range(n_layers)])
def get_mix(self, layer: int):
layer = self.layers[layer]
return layer.get_mixes()
def get_materialized_resonance(self, layer: int) -> torch.Tensor:
layer = self.layers[layer]
return layer.get_materialized_resonance()
def get_gains(self, layer: int) -> torch.Tensor:
layer = self.layers[layer]
return layer.get_gains()
def get_router(self, layer: int) -> torch.Tensor:
layer = self.layers[layer]
return layer.get_router()
def get_attack_envelopes(self, layer: int) -> torch.Tensor:
layer = self.layers[layer]
return layer.get_attack_envelopes()
def forward(self, control_signal: torch.Tensor, deformations: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, n_events, cpd, frames = control_signal.shape
outputs = []
cs = control_signal
fcs = None
for layer in self.layers:
output, cs, fcs = layer(cs, deformations)
outputs.append(output)
final = torch.stack(outputs, dim=-1)
mx = torch.softmax(self.mix, dim=-1)
final = final @ mx[:, None]
return final.view(batch_size, n_events, self.n_samples), fcs
class OverfitResonanceStack(nn.Module):
def __init__(
self,
n_layers: int,
n_samples: int,
resonance_window_size: int,
control_plane_dim: int,
n_resonances: int,
expressivity: int,
n_frames: int,
base_resonance: float = 0.5,
n_to_keep: int = 1024):
super().__init__()
self.expressivity = expressivity
self.n_resonances = n_resonances
self.control_plane_dim = control_plane_dim
self.resonance_window_size = resonance_window_size
self.n_samples = n_samples
self.base_resonance = base_resonance
self.n_frames = n_frames
self.n_to_keep = n_to_keep
# self.resonance_frame_size = 1024
self.deformation_frame_size = deformation_frame_size
self.deformation_frames = n_samples // self.deformation_frame_size
control_plane = torch.zeros(
(1, 1, control_plane_dim, n_frames)) \
.uniform_(-0.01, 0.01)
self.control_plane = nn.Parameter(control_plane)
deformations = torch.zeros(
(1, 1, expressivity, self.deformation_frames)).uniform_(-0.01, 0.01)
self.deformations = nn.Parameter(deformations)
self.network = ResonanceStack(
n_layers=n_layers,
n_samples=n_samples,
resonance_window_size=resonance_window_size,
control_plane_dim=control_plane_dim,
n_resonances=n_resonances,
expressivity=expressivity,
base_resonance=base_resonance
)
self.apply(init_weights)
@property
def flattened_deformations(self):
return self.deformations.view(self.expressivity, self.deformation_frames)
def _get_mapping(self, n_components: int) -> np.ndarray:
cs = self.control_signal.data.cpu().numpy() \
.reshape(self.control_plane_dim, self.n_frames).T
pca = DictionaryLearning(n_components=n_components)
pca.fit(cs)
# this will be of shape (n_components, control_plane_dim)
return pca.components_
def get_hand_tracking_mapping(self) -> np.ndarray:
mapping = self._get_mapping(n_components=21 * 3)
print('PCA Weight Shape', mapping.shape)
return mapping
def get_mixes(self, layer: int) -> torch.Tensor:
return self.network.get_mix(layer)
def get_materialized_resonance(self, layer: int) -> torch.Tensor:
return self.network.get_materialized_resonance(layer)
def get_gains(self, layer: int) -> torch.Tensor:
return self.network.get_gains(layer)
def get_router(self, layer: int) -> torch.Tensor:
return self.network.get_router(layer)
def get_attack_envelopes(self, layer: int) -> torch.Tensor:
return self.network.get_attack_envelopes(layer)
def _process_control_plane(
self,
cp: torch.Tensor,
n_to_keep: int = None,
do_sparsify: bool = do_sparsify) -> torch.Tensor:
cp = cp.view(1, self.control_plane_dim, self.n_frames)
# I _think_ this allows gradients to flow back to all elements, rather than
# just the top-k
cp = cp + cp.mean()
if do_sparsify:
cp = sparsify(cp, n_to_keep=n_to_keep or self.n_to_keep)
cp = cp.view(1, 1, self.control_plane_dim, self.n_frames)
cp = torch.relu(cp)
return cp
@property
def control_signal(self):
cp = self.control_plane
cp = self._process_control_plane(cp)
return cp
@property
def active_elements(self) -> torch.Tensor:
return (self.control_signal > 0).sum()
@property
def sparsity(self) -> torch.Tensor:
return self.active_elements / self.control_signal.numel()
def random(self, use_learned_deformations: bool = False):
# print(self.control_plane.min().item(), self.control_plane.max().item())
rcp = torch \
.zeros_like(self.control_plane) \
.uniform_(
self.control_plane.min().item(),
self.control_plane.max().item())
rcp = self._process_control_plane(rcp, n_to_keep=16, do_sparsify=True)
x = self.forward(
rcp, self.deformations if use_learned_deformations else torch.zeros_like(self.deformations))
return x
def forward(
self,
cp: torch.Tensor = None,
deformations: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
cp = cp if cp is not None else self.control_signal
deformations = deformations if deformations is not None else self.deformations
x, cs = self.network.forward(cp, deformations)
return x, cs, cp
def compression_ratio(self, n_samples: int):
# thanks to https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# Note: This approach should work regardless of whether we have a "hard" number
# of non-zero control-plan elements to keep, or if we have a sparsity penalty
# that encourages a sparse control plane
return (self.deformations.numel() + count_parameters(self.network) + (self.active_elements * 3)) / n_samples
def produce_web_audio_params(self) -> dict:
hand = self.get_hand_tracking_mapping().T
assert hand.shape == (self.control_plane_dim, 21 * 3)
router = self.get_router(0).T
assert router.shape == (self.control_plane_dim, self.n_resonances)
gains = self.get_gains(0).view(-1)
assert gains.shape == (self.n_resonances,)
resonances = self.get_materialized_resonance(0).reshape(-1, self.n_samples)
assert resonances.shape == (self.n_resonances * self.expressivity, self.n_samples)
attacks = self.get_attack_envelopes(0)
mixes = self.get_mixes(0)
serializer = NumpySerializer()
params = dict(
gains=encode_array(gains, serializer),
router=encode_array(router, serializer),
resonances=encode_array(resonances, serializer),
hand=encode_array(hand, serializer),
attacks=encode_array(attacks, serializer),
mix=encode_array(mixes, serializer)
)
return params
def generate_param_dict(
key: str,
logger: Logger,
model: OverfitResonanceStack) -> Tuple[dict, MetaData]:
params = model.produce_web_audio_params()
_, meta = logger.log_json(key, params)
print('WEIGHTS URI', meta.public_uri.geturl())
return params, meta
@dataclass
class OverfitModelResult:
model: OverfitResonanceStack
target: torch.Tensor
recon: torch.Tensor
rand: torch.Tensor
control_signal: torch.Tensor
def transform(x: torch.Tensor) -> torch.Tensor:
return stft(x, 2048, 256, pad=True)
def compute_loss(
x: torch.Tensor,
y: torch.Tensor,
cp: torch.Tensor,
sparsity_loss: float = sparsity_coefficient) -> torch.Tensor:
x = transform(x)
y = transform(y)
return torch.abs(x - y).sum() + (torch.abs(cp).sum() * sparsity_loss)
def produce_overfit_model(
n_samples: int,
resonance_window_size: int,
control_plane_dim: int,
n_resonances: int,
expressivity: int,
n_to_keep: int,
n_iterations: int,
loss_func: LossFunc) -> OverfitModelResult:
"""
Overfit a resonance model on a single audio segment and return
artifacts, including the trained model itself
"""
step_size = resonance_window_size // 2
n_frames = n_samples // step_size
target = get_one_audio_segment(n_samples)
target = max_norm(target)
model = OverfitResonanceStack(
n_layers=1,
n_samples=n_samples,
resonance_window_size=resonance_window_size,
control_plane_dim=control_plane_dim,
n_resonances=n_resonances,
expressivity=expressivity,
base_resonance=0.01,
n_frames=n_frames,
n_to_keep=n_to_keep
).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
for i in range(n_iterations):
optimizer.zero_grad()
recon, fcs, cp = model.forward()
loss = loss_func(target, recon, cp)
print(i, loss.item())
loss.backward()
optimizer.step()
with torch.no_grad():
rnd = model.random(use_learned_deformations=use_learned_deformations_for_random)[0]
return OverfitModelResult(
target=max_norm(target),
recon=max_norm(recon),
rand=max_norm(rnd),
model=model,
control_signal=cp
)
def produce_audio_component(
logger: conjure.Logger,
key: str,
audio: Union[torch.Tensor, np.ndarray]) -> AudioComponent:
_, meta = logger.log_sound(key, audio)
component = AudioComponent(meta.public_uri, height=200)
return component
def produce_control_signal_component(logger: conjure.Logger, key: str, control_signal: torch.Tensor) -> ImageComponent:
_, meta = logger.log_matrix_with_cmap(key, control_signal.view(control_plane_dim, n_frames), cmap='gray')
component = ImageComponent(meta.public_uri, height=200)
return component
def produce_content_section(
logger: conjure.Logger,
n_samples: int,
resonance_window_size: int,
control_plane_dim: int,
n_resonances: int,
expressivity: int,
n_to_keep: int,
n_iterations: int,
loss_func: LossFunc,
example_number: int = 1
) -> CompositeComponent:
result = produce_overfit_model(
n_samples,
resonance_window_size,
control_plane_dim,
n_resonances,
expressivity,
n_to_keep,
n_iterations,
loss_func
)
orig_component = produce_audio_component(logger, 'orig', result.target)
recon_component = produce_audio_component(logger, 'recon', result.recon)
rand_component = produce_audio_component(logger, 'rand', result.rand)
cs_component = produce_control_signal_component(logger, 'cs', result.control_signal)
_, meta = generate_param_dict('model', logger, result.model)
print('INSTRUMENT URI', meta.public_uri)
instr_component = ConvInstrumentComponent(meta.public_uri)
return CompositeComponent(
header=f'# Example {example_number}',
orig=f'## Original Audio',
orig_component=orig_component,
recon=f'## Reconstruction',
recon_component=recon_component,
control_signal_header=f'## Sparse Control Signal',
control_signal=cs_component,
rand=f'## Random Audio',
rand_component=rand_component,
conv=f'## Hand-Controlled Instrument',
conv_component=instr_component,
)
def conv_instrument_dict(
n_examples: int,
logger: conjure.Logger,
n_samples: int,
resonance_window_size: int,
control_plane_dim: int,
n_resonances: int,
expressivity: int,
n_to_keep: int,
n_iterations: int,
loss_func: LossFunc) -> Dict[str, Any]:
examples = {
f'example_{i + 1}': produce_content_section(
logger,
n_samples,
resonance_window_size,
control_plane_dim,
n_resonances,
expressivity,
n_to_keep,
n_iterations,
loss_func,
example_number=i + 1)
for i in range(n_examples)}
return dict(
videoexample=VideoComponent(
src='https://resonancemodel.s3.us-east-1.amazonaws.com/resonancemodel.mp4',
width=500,
height=500,
start_time=1.4
),
examples=CompositeComponent(**examples),
citation=CitationComponent(
tag='johnvinyardresonancemodel',
author='Vinyard, John',
url='https://blog.cochlea.xyz/resonance-model.html',
header='Factoring Audio Into Playable Resonance Models',
year='2025',
)
)
def generate_article(n_iterations: int, n_examples: int):
collection = S3Collection('resonancemodel', is_public=True, cors_enabled=True)
logger = conjure.Logger(collection)
content = conv_instrument_dict(
logger=logger,
n_samples=n_samples,
resonance_window_size=resonance_window_size,
control_plane_dim=control_plane_dim,
n_resonances=n_resonances,
expressivity=expressivity,
n_to_keep=n_to_keep,
n_iterations=n_iterations,
loss_func=compute_loss,
n_examples=n_examples
)
conjure_article(
__file__,
'html',
title='Extracting Playable Instrument Models from Short Audio Examples',
web_components_version=web_components_version,
**content
)
def overfit_model():
"""
function to support live training and monitoring
"""
target = get_one_audio_segment(n_samples)
target = max_norm(target)
model = OverfitResonanceStack(
n_layers=1,
n_samples=n_samples,
resonance_window_size=resonance_window_size,
control_plane_dim=control_plane_dim,
n_resonances=n_resonances,
expressivity=expressivity,
base_resonance=0.01,
n_frames=n_frames,
n_to_keep=n_to_keep
).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
collection = conjure.LmdbCollection(path='resonancemodel')
remote_collection = conjure.S3Collection('resonancemodel', is_public=True, cors_enabled=True)
remote_logger = conjure.Logger(remote_collection)
t, r, rand, res, att = conjure.loggers(
['target', 'recon', 'random', 'resonance', 'att'],
'audio/wav',
encode_audio,
collection,
store_history=True)
def to_numpy(x: torch.Tensor) -> np.ndarray:
return x.data.cpu().numpy()
c, deformations, routing = conjure.loggers(
['control', 'deformations', 'routing'],
SupportedContentType.Spectrogram.value,
to_numpy,
collection,
serializer=NumpySerializer(),
deserializer=NumpyDeserializer(),
store_history=True)
serve_conjure(
[t, r, c, rand, res, deformations, routing, att],
port=9999,
n_workers=1,
web_components_version=web_components_version)
t(max_norm(target))
def train():
iteration = 0
while True:
optimizer.zero_grad()
recon, fcs, cp = model.forward()
r(max_norm(recon))
c(model.control_signal[0, 0])
loss = compute_loss(target, recon, cp)
loss.backward()
optimizer.step()
print(
iteration,
loss.item(),
model.compression_ratio(n_samples).item(),
model.active_elements.item(),
model.sparsity.item())
deformations(model.flattened_deformations)
routing(torch.abs(model.get_router(0)))
fcs = max_norm(fcs)
att(fcs)
with torch.no_grad():
rand(max_norm(model.random(use_learned_deformations=use_learned_deformations_for_random)[0]))
rz = model.get_materialized_resonance(0).view(-1, n_samples)
res(max_norm(rz[np.random.randint(0, n_resonances * expressivity)]))
iteration += 1
if iteration > 0 and iteration % 10000 == 0:
print('Serializing')
generate_param_dict('resonancemodelparams', remote_logger, model)
input('Continue?')
train()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'article'], default='train')
parser.add_argument('--iterations', type=int, default=10000)
parser.add_argument('--examples', type=int, default=2)
args = parser.parse_args()
if args.mode == 'train':
overfit_model()
elif args.mode == 'article':
generate_article(n_examples=args.examples, n_iterations=args.iterations)
else:
raise ValueError(f'Unknown mode {args.mode}')