Audio Splatting With Physics-Inspired Event Generators

Table of Contents

Audio Splatting

Gaussian Splatting

In Gaussian Splatting, a large number of three-dimensional Gaussians are randomly initialized and then fit via backpropagation to several two-dimensional views of a three-dimensional scene/environment.

Application to Audio

We draw inspiration from this field of research and apply a similar process to audio "atoms" used to compose a reconstruction of acoustic instruments in a physical space.

In this work, we use a (roughly) physics-inspired event generator, similar to the one used in *Toward a Sparse Interpretable Audio Codec*, to overfit a single audio segment drawn from the MusicNet dataset.

Training Process

We randomly initialize a tree of event vectors that expands into 64 events, and then iteratively minimize a differentiable, perceptually-inspired loss function for 2,000 iterations.

Finally, we randomly perturb the learned/overfit event vectors to get a sense of some of the ways we might manipulate and edit the sparse representation.

A Sparse, Interpretable Representation

The sparse, event-based representation shows promise for interpretability and manipulability, even without a trained encoder network.

Future Work

The current work deals only with mono audio at 22,050 Hz, but it's possible to imagine extending it to stereo, or even to multi-microphone setups where a three-dimensional sound "field" needs to be approximated.

A previous version of this article can be found here.

All code for this experiment can be found here.

Cite this Work

If you'd like to cite this article, you can use the following BibTeX block.

Example 1

Original Audio

Reconstruction

Perturbed Audio

Timeline

Example 2

Original Audio

Reconstruction

Perturbed Audio

Timeline

import argparse
from itertools import count
from typing import Tuple, Dict, List

import numpy as np
import torch
from torch import nn
from torch.optim import Adam

from conjure import (
    LmdbCollection, loggers, serve_conjure, SupportedContentType, NumpySerializer, NumpyDeserializer,
    conjure_article, CompositeComponent, S3Collection, AudioComponent, AudioTimelineComponent, CitationComponent,
)
from conjure.logger import encode_audio, Logger
from data import get_one_audio_segment
from modules import (
    unit_norm, flattened_multiband_spectrogram, sparse_softmax, max_norm, amplitude_envelope, iterative_loss,
)
from modules.eventgenerators.splat import SplattingEventGenerator
from modules.infoloss import CorrelationLoss
from modules.multiheadtransform import MultiHeadTransform
from util import device, make_initializer
from sklearn.manifold import TSNE

initializer = make_initializer(0.05)

article_title = 'Audio Splatting With Physics-Inspired Event Generators'


class OverfitHierarchicalEvents(nn.Module):
    """Directly-optimized (encoder-free) tree of event vectors and event times.

    Two root event vectors are refined through ``log2(n_events) - 1`` levels.
    At every level each event splits into two children: the parent vector plus
    one of two learned per-level offset vectors. The result is
    ``2 ** log2(n_events)`` leaf events.

    Event times follow the same tree. Each event's time is a tensor of shape
    ``(log2(n_samples), 2)``: one pair of logits per binary digit. Callers in
    this file decode it with ``sparse_softmax`` and then take the second entry
    of each pair as that digit's value.

    Leaf vectors are mapped to generator parameters by ``MultiHeadTransform``
    and rendered by ``SplattingEventGenerator``.
    """

    def __init__(
            self,
            n_samples: int,
            samplerate: int,
            n_events: int,
            context_dim: int):
        """
        Args:
            n_samples: length of the audio to reconstruct, in samples.
            samplerate: sample rate of that audio, in Hz.
            n_events: number of leaf events. Must be a power of two >= 2,
                because the tree doubles the event count at every level.
            context_dim: dimensionality of each event vector.

        Raises:
            ValueError: if ``n_events`` is not a power of two >= 2.
        """
        super().__init__()
        if n_events < 2 or n_events & (n_events - 1):
            raise ValueError(f'n_events must be a power of two >= 2, got {n_events}')

        self.n_samples = n_samples
        self.samplerate = samplerate
        self.n_events = n_events
        self.context_dim = context_dim
        self.n_frames = n_samples // 256

        # number of tree levels; leaves = 2 ** event_levels
        event_levels = int(np.log2(n_events))
        # number of binary digits used to encode an event's position
        total_levels = int(np.log2(n_samples))
        self.event_levels = event_levels

        self.event_generator = SplattingEventGenerator(
            n_samples=n_samples,
            samplerate=samplerate,
            n_resonance_octaves=64,
            n_frames=self.n_frames,
            hard_reverb_choice=False,
            hierarchical_scheduler=True,
            wavetable_resonance=True,
        )

        # maps each event vector onto the parameter shapes the generator expects
        self.transform = MultiHeadTransform(
            self.context_dim,
            hidden_channels=128,
            shapes=self.event_generator.shape_spec,
            n_layers=1)

        self.event_time_dim = int(np.log2(self.n_samples))

        rng = 0.1

        # root level: two event vectors and two binary-coded times
        self.event_vectors = nn.Parameter(
            torch.zeros(1, 2, self.context_dim).uniform_(-rng, rng))
        self.times = nn.Parameter(
            torch.zeros(1, 2, total_levels, 2).uniform_(-rng, rng))

        # level i: two offsets shared by every parent event vector...
        self.hierarchical_event_vectors = nn.ParameterDict(
            {str(i): torch.zeros(1, 2, self.context_dim).uniform_(-rng, rng)
             for i in range(event_levels - 1)})
        # ...and one time offset per child (2 ** (i + 2) children at level i)
        self.hierarchical_time_vectors = nn.ParameterDict(
            {str(i): torch.zeros(1, 2 ** (i + 2), total_levels, 2).uniform_(-rng, rng)
             for i in range(event_levels - 1)})

        # NOTE(review): make_initializer lives in util; whether it overwrites the
        # uniform initialization above is not visible from this file.
        self.apply(initializer)

    def _forward(
            self,
            events: torch.Tensor,
            times: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Expand root events/times through the tree and render the leaves.

        Args:
            events: root event vectors, viewable as ``(1, 2, context_dim)``.
            times: root time logits, shape ``(1, 2, total_levels, 2)``.

        Returns:
            A tuple ``(rendered, event_vectors, times)``:
            - ``rendered``: the generator output. Callers sum it over dim 1,
              presumably ``(1, n_events, n_samples)``; TODO confirm.
            - ``event_vectors``: leaf vectors, ``(1, n_events, context_dim)``.
            - ``times``: leaf time logits, ``(1, n_events, total_levels, 2)``.
        """
        for i in range(self.event_levels - 1):
            # scale = 1 / (i + 1)
            scale = 1  # TODO: consider bringing back scaling as we approach the leaves of the tree

            # each parent splits into two children: parent + one of two per-level offsets
            events = \
                events.view(1, -1, 1, self.context_dim) \
                + (self.hierarchical_event_vectors[str(i)].view(1, 1, 2, self.context_dim) * scale)
            events = events.view(1, -1, self.context_dim)

            # TODO: Consider masking lower bits as we approach the leaves of the tree, so that
            #  new levels can only _refine_, and not completely move entire branches
            batch, n_events, n_bits, _ = times.shape
            # duplicate each parent's time for its two children, then add per-child offsets
            times = times \
                .view(batch, n_events, 1, n_bits, 2) \
                .repeat(1, 1, 2, 1, 1) \
                .view(batch, n_events * 2, n_bits, 2)
            times = times + (self.hierarchical_time_vectors[str(i)] * scale)

        event_vectors = events
        params = self.transform.forward(events)
        events = self.event_generator.forward(**params, times=times)
        return events, event_vectors, times

    def perturbed(self):
        """Render the tree after adding uniform noise in [-0.5, 0.5) to the root event vectors.

        Because every leaf descends from a root, the perturbation propagates
        to all events. Times are left untouched.
        """
        events = self.event_vectors.clone()
        times = self.times.clone()
        perturbation = torch.zeros_like(events).uniform_(-0.5, 0.5)
        return self._forward(events + perturbation, times)

    def forward(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Render the learned tree. See ``_forward`` for the return value."""
        events = self.event_vectors.clone()
        times = self.times.clone()
        return self._forward(events, times)


def loss_transform(x: torch.Tensor) -> torch.Tensor:
    """Flattened multi-band spectrogram representation used by the reconstruction loss."""
    return flattened_multiband_spectrogram(
        x,
        stft_spec={
            'long': (128, 64),
            'short': (64, 32),
            'xs': (16, 8),
        },
        smallest_band_size=512)


def reconstruction_loss(target: torch.Tensor, recon: torch.Tensor) -> torch.Tensor:
    """Sum of absolute differences between the ``loss_transform`` of target and reconstruction."""
    target_spec = loss_transform(target)
    recon_spec = loss_transform(recon)
    loss = torch.abs(target_spec - recon_spec).sum()
    return loss


def to_numpy(x: torch.Tensor):
    """Detached CPU numpy copy of ``x``."""
    return x.data.cpu().numpy()


def overfit():
    """Overfit the model to one audio segment, streaming progress to a local conjure server.

    Intermediate audio and event summaries are written to a freshly-cleared
    LMDB collection and served on port 9999. The training loop iterates over
    ``itertools.count()``, so it never terminates on its own.
    """
    n_samples = 2 ** 16
    samplerate = 22050
    n_events = 64
    event_dim = 16

    # Begin: this would be a nice little helper to wrap up
    collection = LmdbCollection(path='hierarchical')
    collection.destroy()
    collection = LmdbCollection(path='hierarchical')

    recon_audio, orig_audio, perturbed_audio = loggers(
        ['recon', 'orig', 'perturbed'],
        'audio/wav',
        encode_audio,
        collection)

    eventvectors, eventtimes = loggers(
        ['eventvectors', 'eventtimes'],
        SupportedContentType.Spectrogram.value,
        to_numpy,
        collection,
        serializer=NumpySerializer(),
        deserializer=NumpyDeserializer()
    )

    audio = get_one_audio_segment(n_samples, samplerate, device='cpu')
    target = audio.view(1, 1, n_samples).to(device)
    orig_audio(target)

    serve_conjure([
        orig_audio,
        recon_audio,
        perturbed_audio,
        eventvectors,
        eventtimes,
    ], port=9999, n_workers=1)
    # end proposed helper function

    model = OverfitHierarchicalEvents(
        n_samples, samplerate, n_events, context_dim=event_dim).to(device)
    optim = Adam(model.parameters(), lr=1e-3)
    loss_model = CorrelationLoss(n_elements=512).to(device)

    # selects the second entry of each (2,) time pair; loop-invariant, so built once
    bit_weights = torch.tensor([0.0, 1.0], device=device)

    for i in count():
        optim.zero_grad()
        recon, vectors, times = model.forward()

        times = sparse_softmax(times, dim=-1)
        eventvectors(max_norm(vectors[0]))
        t = times[0] @ bit_weights
        eventtimes((t > 0).float())

        recon_summed = torch.sum(recon, dim=1, keepdim=True)
        recon_audio(max_norm(recon_summed))

        perturbed, _, _ = model.perturbed()
        perturbed_summed = max_norm(torch.sum(perturbed, dim=1, keepdim=True))
        perturbed_audio(perturbed_summed)

        loss = loss_model.multiband_noise_loss(target, recon_summed, 128, 32)
        loss.backward()
        optim.step()
        print(i, loss.item())


def process_events2(
        logger: Logger,
        events: torch.Tensor,
        vectors: torch.Tensor,
        times: torch.Tensor,
        total_seconds: float,
        n_events: int,
        context_dim: int) -> Tuple[List[Dict], Dict]:
    """Log each event as its own sound and build timeline metadata for the article.

    Args:
        logger: used to upload each individual event via ``log_sound``.
        events: rendered events. Indexed as ``events[:, i:i+1, :]`` and its
            envelopes reshaped to ``(n_events, -1)``; presumably
            ``(1, n_events, n_samples)``, TODO confirm.
        vectors: event vectors, reshaped to ``(-1, context_dim)``.
        times: event start times in seconds (numpy array or tensor). It is
            flattened, and the first ``n_events`` entries are used.
        total_seconds: duration reported for every event.
        n_events: number of events to emit.
        context_dim: dimensionality of each event vector.

    Returns:
        A tuple ``(timeline_events, event_components)``. ``timeline_events``
        holds one dict per event for ``AudioTimelineComponent``.
        ``event_components`` maps ``'event{i}'`` to an ``AudioComponent``.
    """
    # compute amplitude envelopes
    envelopes = amplitude_envelope(events, 128).data.cpu().numpy().reshape((n_events, -1))

    # opacity: each event's L2 norm relative to the largest one
    norms = torch.norm(events, dim=-1).reshape((-1))
    max_event_norm = torch.max(norms)
    opacities = (norms / (max_event_norm + 1e-12)).detach().cpu().tolist()

    # compute event positions/times, in seconds
    times = [float(x) for x in times.reshape((-1,))]

    # normalize each event-vector dimension to [0, 1]
    normalized = vectors.data.cpu().numpy().reshape((-1, context_dim))
    normalized = normalized - normalized.min(axis=0, keepdims=True)
    normalized = normalized / (normalized.max(axis=0, keepdims=True) + 1e-8)

    # embed onto a single dimension (the timeline's y axis), rescaled to [0, 1].
    # TSNE requires perplexity < number of samples; 30 is sklearn's default.
    perplexity = float(min(30, max(n_events - 1, 1)))
    tsne = TSNE(n_components=1, perplexity=perplexity)
    points = tsne.fit_transform(normalized).reshape(-1)
    points = points - points.min()
    points = points / (points.max() + 1e-8)

    # create a random projection to map colors
    proj = np.random.uniform(0, 1, (context_dim, 3))
    colors = normalized @ proj
    colors -= colors.min()
    colors /= (colors.max() + 1e-8)
    colors *= 255
    colors = colors.astype(np.uint8)
    colors = [f'rgba({c[0]}, {c[1]}, {c[2]}, {opacities[i]})' for i, c in enumerate(colors)]

    evts = {f'event{i}': events[:, i: i + 1, :] for i in range(events.shape[1])}

    # upload each event separately so the timeline can play it on its own
    event_components = {}
    scatterplot_srcs = []
    for k, v in evts.items():
        _, e = logger.log_sound(k, v)
        scatterplot_srcs.append(e.public_uri)
        event_components[k] = AudioComponent(e.public_uri, height=15, controls=False)

    return [{
        'eventTime': times[i],
        'offset': times[i],
        'y': float(points[i]),
        'color': colors[i],
        'audioUrl': scatterplot_srcs[i].geturl(),
        'eventEnvelope': envelopes[i].tolist(),
        'eventDuration': total_seconds,
    } for i in range(n_events)], event_components
def _event_times_in_seconds(times: torch.Tensor, samplerate: int) -> np.ndarray:
    """Decode binary-coded time logits into per-event start times, in seconds.

    Args:
        times: leaf time logits, shape ``(1, n_events, n_bits, 2)``.
        samplerate: sample rate used to convert samples to seconds.

    Returns:
        A float array of shape ``(n_events,)``.
    """
    # first, ensure only one element of each pair is activated
    probs = sparse_softmax(times, dim=-1)
    # project each pair onto [0, 1]: the second entry is the value of that binary digit
    weights = torch.tensor([0.0, 1.0], device=probs.device)
    bits = (probs[0] @ weights).detach().cpu().numpy()  # (n_events, n_bits)
    # digit k contributes 2 ** k samples
    place_values = 2.0 ** np.arange(bits.shape[-1])
    sample_times = bits @ place_values  # (n_events,)
    return sample_times / samplerate


def reconstruction_section(
        logger: Logger,
        samplerate: int,
        context_dim: int,
        n_iterations: int = 1000,
        n_samples: int = 2 ** 16,
        n_events: int = 64) -> CompositeComponent:
    """Overfit a fresh model to one audio segment and build an article section.

    Args:
        logger: uploads the original, reconstructed and perturbed audio, and
            each individual event.
        samplerate: sample rate of the audio segment, in Hz.
        context_dim: dimensionality of each event vector.
        n_iterations: number of optimization steps.
        n_samples: segment length, in samples.
        n_events: number of events. The model's tree produces
            ``2 ** int(log2(n_events))`` events, so this should be a power of two.

    Returns:
        A ``CompositeComponent`` with labelled original, reconstruction,
        perturbed audio and an event timeline.
    """
    total_seconds = n_samples / samplerate

    audio = get_one_audio_segment(n_samples, samplerate, device='cpu')
    target = audio.view(1, 1, n_samples).to(device)

    model = OverfitHierarchicalEvents(
        n_samples, samplerate, n_events, context_dim=context_dim).to(device)
    optim = Adam(model.parameters(), lr=1e-3)

    for i in range(n_iterations):
        optim.zero_grad()
        recon, _, _ = model.forward()
        recon_summed = torch.sum(recon, dim=1, keepdim=True)
        loss = reconstruction_loss(target, recon_summed)
        loss.backward()
        optim.step()
        print(i, loss.item())

    # final renders are only logged, never back-propagated
    with torch.no_grad():
        recon, vectors, times = model.forward()
        perturbed, _, _ = model.perturbed()

    recon_summed = max_norm(torch.sum(recon, dim=1, keepdim=True))
    # normalize the perturbed mix the same way as the reconstruction
    perturbed_summed = max_norm(torch.sum(perturbed, dim=1, keepdim=True))

    times_in_seconds = _event_times_in_seconds(times, samplerate)

    _, orig_audio = logger.log_sound('orig', target)
    orig_component = AudioComponent(orig_audio.public_uri, height=100)

    _, recon_audio = logger.log_sound('recon', recon_summed)
    recon_component = AudioComponent(recon_audio.public_uri, height=100)

    _, p_audio = logger.log_sound('perturbed', perturbed_summed)
    p_component = AudioComponent(p_audio.public_uri, height=100)

    events, _ = process_events2(
        logger, recon, vectors, times_in_seconds, total_seconds, n_events, context_dim)

    timeline = AudioTimelineComponent(duration=total_seconds, width=1000, height=500, events=events)

    return CompositeComponent(
        orig='Original Audio',
        orig_audio=orig_component,
        recon='Reconstruction',
        recon_audio=recon_component,
        perturbed='Perturbed Audio',
        perturbed_audio=p_component,
        timeline='Timeline',
        timeline_component=timeline
    )


def demo_page_dict() -> Dict[str, object]:
    """Build the article's components: two independently-trained examples and a citation.

    Audio is uploaded to the public ``audio-splatting`` S3 collection.
    """
    remote = S3Collection('audio-splatting', is_public=True, cors_enabled=True)
    logger = Logger(remote)

    n_iterations = 2000
    samplerate = 22050
    context_dim = 16

    example_1 = reconstruction_section(logger, samplerate, context_dim, n_iterations)
    example_2 = reconstruction_section(logger, samplerate, context_dim, n_iterations)
    # example_3 = reconstruction_section(logger, samplerate, context_dim, n_iterations)
    # example_4 = reconstruction_section(logger, samplerate, context_dim, n_iterations)

    citation = CitationComponent(
        tag='johnvinyardaudiosplatting',
        author='Vinyard, John',
        url='https://blog.cochlea.xyz/audio-splatting.html',
        header=article_title,
        year='2025',
    )

    return dict(
        example_1=example_1,
        example_2=example_2,
        # example_3=example_3,
        # example_4=example_4,
        citation=citation
    )


def generate_demo_page():
    """Render the article to HTML via ``conjure_article``."""
    display = demo_page_dict()
    conjure_article(
        __file__,
        'html',
        title=article_title,
        **display)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--mode',
        type=str,
        required=True,
        choices=['train', 'demo'])
    args = parser.parse_args()
    if args.mode == 'train':
        overfit()
    else:
        generate_demo_page()
Back to Top