Wed Mar 1 13:24:27 UTC 2023
Project to explore writing an FFT from scratch and an opportunity to flex the latest C++. The aspiration is to keep up with realtime audio without depending on third-party libraries or firmware. See the pipeline for this repo.
The output of gprof is pretty unfriendly for a person to parse, but piping it through gprof2dot makes it much clearer.
However, in this case it's highlighting routines in the guts of the thread library.
The challenge is to keep the code readable and maintainable while still running quickly, i.e., reserve the weird stuff for the hot spots. Conventional wisdom conditions you to always pass by const reference, but here I've simply passed const values for simplicity and only introduced references where I could measure an improvement.
See the Cooley-Tukey FFT algorithm for the O(n log n) approach; the code below computes the DFT directly, which is O(n^2).
- g++ over clang++
- -Ofast
- float instead of double (also lets you fit a larger twiddle matrix in memory)

The results summary below documents the performance of the single-core pipeline where this job runs.
Full-gas runs are built thus:
g++ test.cxx -std=c++23 -Ofast -DNDEBUG
It’s quite interesting how the profile changes as you increase core count: single threaded processes suddenly become the bottleneck. For instance, generating a twiddle matrix took 3s on my laptop for 24K bins, but 44s on a higher-powered VM with the DFT increased to 64K bins, simply because that part of the processing is single-threaded.
In the results below, the DFT has far more bins than DVD-quality audio requires, and it is processing at 1.5x realtime. It was generated on a 16-core Google Cloud VM with 62GB of RAM.
The results below are written by the GitLab runner that builds this web page. The VM can just about keep up with 24576 bins at realtime. The twiddle matrix holds 12,288 x 24,576 single-precision complex values, which is about 2.25 GiB.
#include "fmt/core.h" // for print
#include "gtest/gtest.h" // for Test, AssertionResult, Message, TestPartResult
#include <algorithm> // for __copy_fn, copy
#include <assert.h> // for assert
#include <bits/chrono.h> // for microseconds, steady_clock, duration_cast
#include <chrono> // for steady_clock, duration_cast, microseconds
#include <cmath> // for sin
#include <complex> // for operator*, complex, operator""i, abs, exp
#include <concepts> // for floating_point
#include <execution> // for for_each, seq
#include <filesystem> // for directory_iterator, path, operator!=
#include <fstream> // for ifstream
#include <functional> // for identity, less
#include <iterator> // for ostream_iterator
#include <map> // for map
#include <numbers> // for pi
#include <numeric> // for iota
#include <ranges> // for views
#include <span> // for span
#include <stddef.h> // for size_t
#include <stdint.h> // for uint32_t
#include <string> // for string, allocator, char_traits, operator+
#include <thread> // for thread::hardware_concurrency
#include <type_traits> // for is_floating_point_v, is_same_v, decay_t
#include <typeinfo> // for typeid
#include <variant> // for variant
#include <vector> // for vector, vector<>::value_type
/**
Compile-time DFT size: the number of samples in each analysis window.

It is a variable template because the same constant is needed in two roles:
as a container size (bins<size_t>) and inside floating-point arithmetic
(bins<float>, bins<double>). Each instantiation is a constant of the right
type, so no conversion happens at run time.
*/
template <typename T> constexpr T bins = static_cast<T>(24'576);
/**
The most computationally expensive part of this process is the exponent
calculation below, which must be repeated across all samples for every DFT bin.
However, there's nothing stopping us doing this up front as it doesn't change.
"(n^2) is the sweet spot of badly scaling algorithms: fast enough to make it
into production, but slow enough to make things fall down once it gets there."
-- Bruce Dawson
*/
/// Warn the compiler that this routine is going to make your computer hot
template <typename T> [[gnu::hot]] auto generate_twiddle_matrix() {
static_assert(std::is_floating_point_v<T>);
// Initialise matrix
using row_t = std::vector<std::complex<T>>;
const auto row = row_t(bins<size_t>);
auto matrix = std::vector<row_t>(bins<size_t> / 2, row);
// Populate and return matrix
for (auto k = 0uz; k < bins<size_t> / 2; ++k)
for (auto n = 0uz; n < bins<size_t>; ++n) {
// Casts should be ugly and easy to find
const auto _k = static_cast<T>(k);
const auto _n = static_cast<T>(n);
[k][n] = exp(_k * _n * std::complex<T>{0.0f, 2.0f} *
matrixstd::numbers::pi_v<T> / bins<T>);
}
return matrix;
}
/// Read-only accessor for the twiddle matrix of sample type T, indexed as
/// twiddle[bin, sample] with bin < bins/2 and sample < bins.
/// The matrix is generated lazily on the first lookup and then shared by
/// every later lookup for the same T.
template <typename T> struct twiddle_t {
  constexpr auto operator[](const size_t bin, const size_t sample) const {
    assert(bin < bins<size_t> / 2);
    assert(sample < bins<size_t>);

    // Function-local static: built once on first use (initialisation is
    // thread-safe), then reused by every subsequent call
    static const auto table = generate_twiddle_matrix<T>();

    const auto &factors = table[bin];
    return factors[sample];
  }
};
/// Complex response of frequency bin x: the sum over every sample index y of
/// samples[y] * twiddle[x, y]. Expects exactly bins samples and x < bins/2.
auto response(std::span<const std::floating_point auto> samples,
              const size_t x) {
  assert(std::size(samples) == bins<size_t>);
  assert(x < bins<size_t> / 2);

  using T = std::decay_t<decltype(samples)>::value_type;

  // One lookup object per sample type; the matrix behind it is generated the
  // first time it is indexed
  static const twiddle_t<T> twiddle;

  auto sum = std::complex<T>{};
  // Visit the sample indices in ascending order so the floating-point
  // accumulation order is fixed
  for (const auto y : std::views::iota(0uz, bins<size_t>))
    sum += samples[y] * twiddle[x, y];
  return sum;
}
/**
I originally wrote this whole process as one large routine but breaking it up
does make the profile results clearer. It's also easier to unit test, benchmark
and refactor.
*/
/// Calculate the response across the samples in each bin of the DFT.
/// Returns bins/2 real values: the magnitude of each bin's response scaled
/// by 1/bins. Expects exactly bins samples.
auto analyse(std::span<const std::floating_point auto> samples) {
  // std::size (not std::ssize) so both sides of the comparison are unsigned
  assert(std::size(samples) == bins<size_t>);

  // The incoming floating point type is unwieldy so let's define a type alias
  using T = std::decay_t<decltype(samples)>::value_type;

  // Keep the bin indices as integers. Storing them in T and casting back is
  // only exact while bins/2 fits in T's mantissa.
  std::vector<size_t> indices(bins<size_t> / 2);
  std::iota(begin(indices), end(indices), 0uz);

  // Remember you must link against tbb for any of this execution policy
  // stuff or it will quietly execute serially
  std::vector<T> dft(bins<size_t> / 2);
  std::transform(std::execution::par, begin(indices), end(indices),
                 begin(dft), [&](const size_t bin) {
                   // Convert to a real by scaling the absolute value by the
                   // number of bins
                   return std::abs(response(samples, bin)) / bins<T>;
                 });
  return dft;
}
/// Write the DFT results to csv/<stem>.csv, one value per line, for plotting.
/// If the file cannot be opened (for example because the csv/ directory does
/// not exist), nothing is written and no error is reported.
void write_csv(const auto dft, const std::string stem) {
  assert(not stem.empty());
  assert(std::size(dft) == bins<size_t> / 2);

  using value_t = typename decltype(dft)::value_type;

  const auto path = "csv/" + stem + ".csv";
  std::ofstream csv_file{path};
  if (not csv_file.good())
    return;

  std::ranges::copy(dft, std::ostream_iterator<value_t>{csv_file, "\n"});
}
/// Write every statistic to summary.txt, one "- <value> <key>" line each,
/// in the container's iteration order. Each value is a std::variant; only
/// size_t, float and double alternatives are printed (for any other
/// alternative the line contains just the key).
void write_summary(const auto stats) {
  assert(not stats.empty());
  std::ofstream out{"summary.txt"};
  for (const auto &[key, value] : stats) {
    out << "- ";
    if (const auto *count = std::get_if<size_t>(&value))
      out << *count;
    else if (const auto *single = std::get_if<float>(&value))
      out << *single;
    else if (const auto *precise = std::get_if<double>(&value))
      out << *precise;
    out << " " << key << "\n";
  }
}
/// Read a WAV file and return the samples as a container of floating points
template <typename T> auto read_wav(const std::string file_name) {
assert(file_name.ends_with(".wav"));
// Structure of a WAV header
struct {
uint32_t riff_id_;
uint32_t riff_size_;
uint32_t wave_tag_;
uint32_t format_id_;
uint32_t format_size_;
uint32_t format_tag_ : 16;
uint32_t channels_ : 16;
uint32_t sample_rate_;
uint32_t bytes_per_second_;
uint32_t block_align_ : 16;
uint32_t bit_depth_ : 16;
uint32_t data_id_;
uint32_t data_size_;
} header;
assert(sizeof header == 44uz);
// Read WAV header
std::ifstream audio{file_name};
.read(reinterpret_cast<char *>(&header), sizeof header);
audio
// Read a block of raw data to analyse
std::vector<short> raw(bins<size_t>);
const size_t data_size = std::size(raw) * sizeof(decltype(raw)::value_type);
.read(reinterpret_cast<char *>(raw.data()), data_size);
audio
// Convert to target type
return std::vector<T>{raw.cbegin(), raw.cend()};
}
/// Conversion from degrees to radians
constexpr auto deg2rad(const std::floating_point auto degrees) {
return std::numbers::pi_v<decltype(degrees)> * degrees /
decltype(degrees){180.0f};
}
/// Generate a clean sine wave to play with
auto generate_sine_wave(const size_t count,
const std::floating_point auto frequency) {
using T = std::decay_t<decltype(frequency)>;
assert(count > 0uz);
assert(frequency > T{});
// Initialise a container with the element index
std::vector<T> samples(count);
// Populate samples
for (auto i = 0uz; i < std::size(samples); ++i) {
const auto x = static_cast<T>(i);
[i] = std::sin(frequency * deg2rad(x));
samples}
return samples;
}
(dft, initialise_twiddle_matrix) {
TEST// Generate some test samples
const auto samples = generate_sine_wave<float>(bins<size_t>, 2000.0f);
// Twiddle matrix is initialise on first call
const auto dft = analyse<float>(samples);
(dft, "sine_wave");
write_csv
// DFT should be half as long as the sample data
(std::size(dft), std::size(samples) / 2);
EXPECT_EQ}
(dft, profile_all_wavs_as_floats) {
TEST// Get list of files to process -- note we haven't specified the vector type
const std::filesystem::path p{"wav/"};
const std::vector files(std::filesystem::directory_iterator{p}, {});
(files.empty());
ASSERT_FALSE
// Start timer for main process
using namespace std::chrono;
const auto start_timing = high_resolution_clock::now();
// Calculate DFT for each file
::print("Processing {} files
fmt", std::size(files));
for (const auto &file : files) {
// Skip any unsupported file types
if (not(file.path().extension() == ".wav"))
continue;
const std::string file_name = file.path();
::print(" {}
fmt", file_name);
// Get samples for a WAV file and analyse
const auto samples = read_wav<float>(file_name);
const auto dft = analyse<float>(samples);
(std::size(dft), std::size(samples) / 2);
EXPECT_EQ
// Get the base file name and write DFT to disk
(dft, file.path().stem());
write_csv}
// Create summary
const auto end_timing = high_resolution_clock::now();
const auto diff = duration_cast<microseconds>(end_timing - start_timing);
const auto samples_per_second =
1e6f * static_cast<float>(std::size(files) * bins<size_t>) /
static_cast<float>(diff.count());
std::map<std::string, std::variant<size_t, float, double>> stats;
["cores"] = std::thread::hardware_concurrency();
stats["files"] = std::size(files);
stats["GiB twiddle matrix"] =
stats<size_t> * (bins<size_t> / 2) * sizeof(float) / std::pow(2.0f, 30);
bins["s analysis duration"] = static_cast<float>(diff.count()) / 1e6f;
stats["samples per second"] =
stats1e6f * static_cast<float>(std::size(files) * bins<size_t>) /
static_cast<float>(diff.count());
["DFT bins"] = bins<size_t>;
stats["x speed up"] = samples_per_second / bins<float>;
stats
(stats);
write_summary}
(unit_test, basic_conversion) {
TEST// Floats
(deg2rad(0.0f), 0.0f, 0.1f);
EXPECT_NEAR(deg2rad(90.0f), std::numbers::pi_v<float> / 2.0f, 0.1f);
EXPECT_NEAR(deg2rad(180.0f), std::numbers::pi_v<float>, 0.1f);
EXPECT_NEAR
// Doubles
(deg2rad(0.0), 0.0, 0.1);
EXPECT_NEAR(deg2rad(90.0), std::numbers::pi_v<double> / 2.0, 0.1);
EXPECT_NEAR(deg2rad(180.0), std::numbers::pi_v<double>, 0.1);
EXPECT_NEAR
// Compile time long double
constexpr auto x = deg2rad(0.0l);
static_assert(std::is_same_v<decltype(x), const long double>);
(x, 0.0l);
EXPECT_EQ}
(unit_test, sine_wave_generation) {
TEST// Floats
const auto floats = generate_sine_wave(bins<size_t>, 100.0f);
static_assert(std::is_same_v<decltype(floats)::value_type, float>);
(std::size(floats), bins<size_t>);
EXPECT_EQ(*std::ranges::min_element(floats), -1.0f, 0.1f);
EXPECT_NEAR(*std::ranges::max_element(floats), 1.0f, 0.1f);
EXPECT_NEAR::print("float typeid: {}
fmt", typeid(decltype(floats)).name());
// Doubles
const auto doubles = generate_sine_wave(bins<size_t>, 100.0);
static_assert(std::is_same_v<decltype(doubles)::value_type, double>);
(std::size(doubles), bins<size_t>);
EXPECT_EQ(*std::ranges::min_element(doubles), -1.0, 0.1);
EXPECT_NEAR(*std::ranges::max_element(doubles), 1.0, 0.1);
EXPECT_NEAR::print("double typeid: {}
fmt", typeid(decltype(doubles)).name());
// Long doubles
const auto long_doubles = generate_sine_wave(bins<size_t>, 100.0l);
static_assert(
std::is_same_v<decltype(long_doubles)::value_type, long double>);
(std::size(long_doubles), bins<size_t>);
EXPECT_EQ::print("long double typeid: {}
fmt", typeid(decltype(long_doubles)).name());
}