Create a directory in which you will place the program files, and save snac-encode.py (source code below) in it.
user@DESKTOP-XXXXXXX:~$ mkdir snac && cd snac
Refresh the local package index and upgrade the installed packages.
user@DESKTOP-XXXXXXX:~/snac$ sudo apt update && sudo apt upgrade
Prepare pip
user@DESKTOP-XXXXXXX:~/snac$ sudo apt install python3-pip
Install the venv package, which provides virtual environments for Python 3 libraries.
user@DESKTOP-XXXXXXX:~/snac$ sudo apt install python3.12-venv
Create your own virtual environment in which SNAC and its required libraries will run, separate from the system's Python installation.
user@DESKTOP-XXXXXXX:~/snac$ python3 -m venv ~/venv
Enter the virtual environment. Your libraries will be placed in this separate environment.
user@DESKTOP-XXXXXXX:~/snac$ source ~/venv/bin/activate
Install the torch and torchaudio Python libraries, which are used for audio processing.
(venv) user@DESKTOP-XXXXXXX:~/snac$ pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
Install required Python libraries and SNAC
(venv) user@DESKTOP-XXXXXXX:~/snac$ pip install pandas numpy snac
Confirm Torch is installed
(venv) user@DESKTOP-XXXXXXX:~/snac$ python -c "import torch; print(torch.__version__)"
Ensure CUDA is available (this should print True).
(venv) user@DESKTOP-XXXXXXX:~/snac$ python -c "import torch; print(torch.cuda.is_available())"
Install sox and FFmpeg, which are used to read and write .wav files.
(venv) user@DESKTOP-XXXXXXX:~/snac$ sudo apt install sox ffmpeg
Usage:
(venv) user@DESKTOP-XXXXXXX:~/snac$ time python snac-encode.py -i input.wav -o encoded.snac -d decoded.wav
Source code (snac-encode.py)
#!/usr/bin/env python3
import torch
import torchaudio
import argparse
import os
import numpy as np
from snac import SNAC
def parse_arguments():
    """Define the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace exposing ``input``, ``output``, ``decoded``,
        ``sample_rate``, ``device`` and ``model``.
    """
    # Prefer the GPU when PyTorch reports one, otherwise fall back to the CPU.
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cli = argparse.ArgumentParser(description='SNAC Audio Encoder/Decoder')

    # The three file paths are mandatory and share the same option shape.
    required_paths = (
        ('-i', '--input', 'Input WAV file'),
        ('-o', '--output', 'Output SNAC file'),
        ('-d', '--decoded', 'Decoded (reconstructed) WAV file'),
    )
    for short_flag, long_flag, help_text in required_paths:
        cli.add_argument(short_flag, long_flag, required=True, help=help_text)

    # Optional tuning knobs, each with a default.
    cli.add_argument(
        '-s', '--sample_rate',
        type=int,
        default=44100,
        help='Target sample rate (default: 44100)',
    )
    cli.add_argument(
        '--device',
        type=str,
        default=default_device,
        help='Device to use (cuda/cpu)',
    )
    cli.add_argument(
        '--model',
        type=str,
        default='hubertsiuzdak/snac_44khz',
        help='SNAC model to use',
    )
    return cli.parse_args()
def load_audio(file_path, target_sr=44100):
    """Read an audio file and return it at the requested sample rate.

    Args:
        file_path: Path handed to ``torchaudio.load``.
        target_sr: Sample rate the returned waveform should have.

    Returns:
        Tuple ``(waveform, target_sr)``. The waveform is what
        ``torchaudio.load`` returned, resampled only when the file's own
        rate differs from ``target_sr``.
    """
    waveform, source_sr = torchaudio.load(file_path)

    # Nothing to convert when the file already has the requested rate.
    if source_sr == target_sr:
        return waveform, target_sr

    print(f"Resampling from {source_sr}Hz to {target_sr}Hz")
    to_target_rate = torchaudio.transforms.Resample(source_sr, target_sr)
    return to_target_rate(waveform), target_sr
def process_audio(model, audio, device):
    """Run a single forward pass of the SNAC model over ``audio``.

    Args:
        model: Callable that returns a ``(reconstruction, codes)`` pair.
        audio: Input tensor; it is moved to ``device`` before the call.
        device: Device the input tensor is placed on.

    Returns:
        The ``(audio_hat, codes)`` pair exactly as produced by ``model``.
    """
    on_device = audio.to(device)

    # Inference mode: no autograd bookkeeping is needed for encode/decode.
    with torch.inference_mode():
        reconstruction, code_layers = model(on_device)

    return reconstruction, code_layers
def save_codes(codes, file_path):
    """Save SNAC codes to ``file_path`` as a compressed NumPy archive.

    Each element of ``codes`` is stored under the key ``layer_<index>``,
    where ``<index>`` is its position in the sequence.

    Args:
        codes: Iterable of tensors (one per codebook layer); each is copied
            to the CPU and converted to a NumPy array before saving.
        file_path: Exact path of the file to create. The file is written
            under this name, whatever its extension.
    """
    codes_data = {
        f'layer_{i}': code_sequence.cpu().numpy()
        for i, code_sequence in enumerate(codes)
    }

    # np.savez_compressed appends ".npz" to a *string* path that lacks it,
    # which would silently write "<file_path>.npz" instead of the requested
    # name. Handing it an open binary file object keeps the name exact, so
    # the message below (and any manifest naming this file) stays truthful.
    with open(file_path, 'wb') as f:
        np.savez_compressed(f, **codes_data)

    print(f"Saved compressed codes to {file_path}")
def load_codes(file_path):
    """Load SNAC codes from a NumPy ``.npz``-format archive.

    The archive is expected to hold arrays named ``layer_0``, ``layer_1``,
    and so on. As many consecutive indices are probed as the archive has
    entries; an index whose key is absent is skipped.

    Args:
        file_path: Path of the archive to read.

    Returns:
        List of torch tensors, ordered by layer index.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it deterministically once every array has been read.
    with np.load(file_path) as data:
        layer_keys = (f'layer_{i}' for i in range(len(data.files)))
        # data[key] materializes the array in memory, so the tensors stay
        # valid after the archive is closed.
        return [torch.from_numpy(data[key]) for key in layer_keys if key in data]
def save_audio(audio, file_path, sample_rate):
    """Write an audio tensor to ``file_path`` and report where it went.

    Args:
        audio: Tensor to save; a CPU copy is what gets written.
        file_path: Destination path passed to ``torchaudio.save``.
        sample_rate: Sample rate recorded for the saved audio.
    """
    host_audio = audio.cpu()
    torchaudio.save(file_path, host_audio, sample_rate)
    print(f"Saved reconstructed audio to {file_path}")
def main():
    """Encode an audio file with SNAC and write the codes and reconstruction.

    Workflow: parse the command line, load the pretrained model onto the
    selected device, load the input at the requested sample rate, push each
    channel (at most two) through the model on its own, save the
    reconstructed audio, then save the codes. A single-channel input stores
    its codes at the output path. Otherwise one code file per channel is
    written next to it and the output path receives a text manifest listing
    those files.
    """
    args = parse_arguments()
    device = args.device

    print(f"Loading SNAC model from {args.model}")
    model = SNAC.from_pretrained(args.model).eval().to(device)

    print(f"Loading audio from {args.input}")
    waveform, sample_rate = load_audio(args.input, args.sample_rate)

    # Anything wider than stereo is cut down to its first two channels.
    num_channels = waveform.shape[0]
    if num_channels > 2:
        print(f"Warning: Audio has {num_channels} channels. Only the first two will be processed.")
        waveform = waveform[:2]
        num_channels = 2

    # The model is fed one channel at a time.
    decoded_parts = []
    codes_per_channel = []
    for index in range(num_channels):
        print(f"Processing channel {index+1}/{num_channels}")
        # Keep the channel axis via slicing, then prepend a batch axis.
        single_channel = waveform[index:index + 1][None]
        decoded, channel_codes = process_audio(model, single_channel, device)
        decoded_parts.append(decoded.squeeze(0))
        codes_per_channel.append(channel_codes)

    # Re-assemble the channels into one tensor for writing.
    if num_channels == 2:
        reconstructed = torch.cat(decoded_parts, dim=0)
    else:
        reconstructed = decoded_parts[0]
    save_audio(reconstructed, args.decoded, sample_rate)

    # Persist the codes.
    output_base, _ = os.path.splitext(args.output)
    if num_channels == 1:
        save_codes(codes_per_channel[0], args.output)
    else:
        # One code file per channel, named after the output path's stem.
        manifest_lines = [f"channels: {num_channels}\n"]
        for number, channel_codes in enumerate(codes_per_channel, start=1):
            save_codes(channel_codes, f"{output_base}_ch{number}.snac")
            manifest_lines.append(
                f"channel_{number}: {os.path.basename(output_base)}_ch{number}.snac\n"
            )
        # The output path itself becomes a manifest naming the channel files.
        with open(args.output, 'w') as manifest:
            manifest.writelines(manifest_lines)

    print("Processing complete!")
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()