First of all, let's take a look at the implementation:
import numpy as np
import torch
from torch import nn

class FF(nn.Module):
    def __init__(self, band_limit: int, p: float):
        super().__init__()  # call this before assigning attributes on the module
        # b_s: the frequencies, equally spaced integer multiples of the lowest mode
        self.b_s = torch.arange(1, band_limit + 1)
        # a_s: the spectral weights, a power-law decay controlled by p
        self.a_s = 1 / torch.pow(self.b_s, p)

    def forward(self, x):
        return torch.cat([
            self.a_s * torch.sin(2. * np.pi * x * self.b_s),
            self.a_s * torch.cos(2. * np.pi * x * self.b_s)
        ], dim=-1) / torch.norm(self.a_s)
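As a quick sanity check (an illustrative snippet, not from the original text), the module maps each scalar input to 2 × band_limit features:

net = FF(8, 1)
x = torch.rand(4, 1)   # four scalar inputs
print(net(x).shape)    # torch.Size([4, 16]), i.e. 8 sine and 8 cosine features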
Now there are a few crucial points. First of all, there are the two parameters a_s and b_s. In Fourier features, the parameter b_s corresponds to the frequencies. Here the frequencies are equally spaced: integer multiples of the lowest mode.
net = FF(8, 0)
doc.print("frequencies (in 2π):", net.b_s)
frequencies (in 2π): tensor([1, 2, 3, 4, 5, 6, 7, 8])
An important property in the examples used in Tancik et al. is that if your input lies in [-0.5, 0.5], you need to make sure that the lowest mode (the first octave) spans your entire input range. Otherwise you will get aliasing.
Important Considerations
So for the range [-0.5, 0.5], this means the lowest mode spans exactly one full period over that range.
Now suppose we only take one Fourier component (both the sine and the cosine). This range is mapped onto a circle in phase space. If a data point falls outside of this range, it wraps back around the circle and becomes aliased with a point inside the range. Therefore we need to make sure the lowest frequency component has a wavelength at least as long as the range.
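To make the aliasing concrete (a small sketch, not from the original text): with integer frequencies the lowest mode has period 1, so two inputs that differ by exactly 1 produce identical features and the network cannot tell them apart.

net = FF(8, 0)
x = torch.tensor([[0.3]])
# x + 1 lands on the same point of the phase circle, hence the same features
print(torch.allclose(net(x), net(x + 1.0)))  # True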
Controlling the Spectral Bias
Now the second parameter is a_s. This is the set of weights for each of the Fourier components in b_s. We can use the exponent p to specify these weights, via a_s = 1 / b_s^p. The weights decay differently for different p, as shown below:
plt.title("Spectral Weights")
for p in [0, 0.5, 1, 1.5, 2, float('inf')]:
    a_s = 1 / np.arange(1, 9) ** p   # weights for the 8 components
    plt.plot(a_s, label=f"p={p}")
plt.legend()
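For concreteness (an illustrative check, not part of the original text), these are the weights the module itself produces for two of these settings:

print(FF(8, 1).a_s)             # tensor([1.0000, 0.5000, 0.3333, ...]): the 1/b decay
print(FF(8, float('inf')).a_s)  # tensor([1., 0., 0., ...]): only the lowest mode survives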
The Spectrum of Random Fourier Kernel
We can directly visualize the spectrum of this kernel. Before we start, let's first remind ourselves of the basic concepts.
Definition 1: a similarity function k(x, x') that is continuous and symmetric is called a kernel.
Mercer's Theorem: k is a kernel if and only if for all data {x_1, ..., x_n}, the Gramian matrix G given by G_ij = k(x_i, x_j) is positive semi-definite (PSD). This is what makes the kernel trick possible. For a more precise formalism, refer to Wikipedia.
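As a quick numerical illustration (a sketch, not from the original text, using a Gaussian kernel as the example), the Gramian matrix of a valid kernel has no negative eigenvalues:

rng = np.random.default_rng(0)
x = rng.uniform(-0.5, 0.5, size=32)
gram = np.exp(-(x[:, None] - x[None, :]) ** 2 / 0.1)   # RBF kernel Gramian matrix
print(np.linalg.eigvalsh(gram).min() >= -1e-9)          # True, up to numerical precision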
The neural tangent kernel (NTK) is similar to the conjugate kernel (CK). The conjugate kernel looks at the output of the neural network at initialization, whereas the neural tangent kernel looks at the gradient of the network around its initialization. This corresponds to the first-order Taylor expansion of the network function, and can be computed as the Gramian matrix of the gradient vectors at the initialization point θ_0 of the network:

K_NTK(x, x') = ⟨∇_θ f(x; θ_0), ∇_θ f(x'; θ_0)⟩
There are two important details:
1. The gradient vector needs to be normalized, so that for the same datapoint x the kernel evaluates to K(x, x) = 1 on the diagonal. This is the reason behind the line below:
grad.append(grad_vec / np.linalg.norm(grad_vec))
2. The Gramian matrix for a single network instantiation can be quite noisy, so we need to average over multiple instantiations.
def get_ntk(net, xs):
    grad = []
    out = net(torch.FloatTensor(xs)[:, None])
    for o in tqdm(out, desc="NTK", leave=False):
        net.zero_grad()
        o.backward(retain_graph=True)
        # flatten the parameter gradients into a single vector
        grad_vec = torch.cat([p.grad.view(-1) for p in net.parameters() if p.grad is not None]).numpy()
        # normalize, so that the kernel is 1 on the diagonal
        grad.append(grad_vec / np.linalg.norm(grad_vec))
        net.zero_grad()
    grad = np.stack(grad)
    # Gramian matrix of the normalized gradient vectors
    gram_matrix = grad @ grad.T
    return gram_matrix
xs = np.linspace(-0.5, 0.5, 32)
for p in tqdm([0, 0.5, 1, 1.5, 2, float('inf')], desc="kernels"):
    ntk_kernel = 0
    # average the kernel over ten independently initialized networks
    for i in trange(10, desc="averaging networks", leave=False):
        net = nn.Sequential(FF(16, p), MLP(32, 1024, 4, 1))
        ntk_kernel += 0.1 * get_ntk(net, xs)
    # cross section of the kernel through the center point
    plt.figure('cross section')
    plt.plot(ntk_kernel[16], label=f'p={p}')
    # spectrum of that cross section
    plt.figure('spectrum')
    fft = np.fft.fft(ntk_kernel[16])
    plt.plot(np.abs(np.fft.fftshift(fft)), label=f'p={p}')
plt.imshow(ntk_kernel, cmap='inferno')
Kernel | Cross Section | Spectrum
---|---|---
The Band-limited Fourier Features
An important feature of our implementation above is that the Fourier features are band-limited from above. For band_limit=16, if we visualize the spectrum beyond the band limit, we see a sharp drop. This happens because we are effectively clipping the higher-order components to zero:
plt.figure("Full A's")
for p in [0, 0.5, 1, 1.5, 2, float('inf')]:
    ff = FF(16, p)
    # pad the 16 spectral weights with zeros up to 32 components
    full_a = np.zeros(32)
    full_a[:16] = ff.a_s
    plt.plot(full_a, label=f"p={p}")
plt.legend()
xs = np.linspace(-0.5, 0.5, 64)
for p in tqdm([0, 0.5, 1, 1.5, 2, float('inf')], desc="kernels"):
    ntk_kernel = 0
    for i in trange(10, desc="averaging networks", leave=False):
        net = nn.Sequential(FF(16, p), MLP(32, 1024, 4, 1))
        ntk_kernel += 0.1 * get_ntk(net, xs)
    plt.figure('Full Spectrum')
    fft = np.fft.fft(ntk_kernel[32])
    plt.plot(np.abs(np.fft.fftshift(fft)), label=f'p={p}')
Full A's | Spectrum
---|---