A Quick Introduction to Random Fourier Kernels

written by Ge Yang

First of all, let's take a look at the implementation:

import numpy as np
import torch
from torch import nn


class FF(nn.Module):
    def __init__(self, band_limit: int, p: float):
        super().__init__()
        # integer frequencies b_i = 1, 2, ..., band_limit
        self.b_s = torch.arange(1, band_limit + 1)
        # spectral weights a_i = 1 / b_i^p
        self.a_s = 1 / torch.pow(self.b_s, p)

    def forward(self, x):
        # [sin, cos] features, normalized by the overall weight norm
        return torch.cat([
            self.a_s * torch.sin(2. * np.pi * x * self.b_s),
            self.a_s * torch.cos(2. * np.pi * x * self.b_s)
        ], dim=-1) / torch.norm(self.a_s)

There are a few crucial points. First, there are two parameters, $a$ and $b$. In Fourier features, the parameter $b$ corresponds to the frequencies. Our implementation uses equally spaced integer frequencies.

net = FF(8, 0)
doc.print("frequencies (in 2π):", net.b_s)
frequencies (in 2π): tensor([1, 2, 3, 4, 5, 6, 7, 8])

An important property in the examples used in Tancik et al. is that if your input lies in $[0, 1)$, you need to make sure that the frequencies $\mathbb b$ capture the lowest mode (the first octave) of your entire range. Otherwise you will get aliasing.

Important Considerations

So for the range $[0, 1)$, this means the lowest mode spans the full $2\pi$ over that range.

Now suppose we take only one Fourier component (both $\sin$ and $\cos$): the range $[0, 1)$ is mapped to a circle in phase space. If a data point falls outside of this range, it wraps around the circle and becomes aliased. Therefore we need to make sure the lowest-frequency component has a wavelength at least as long as the range.
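We can see this wrap-around concretely. Below is a minimal sanity check (a sketch using the FF module above; the sample inputs are arbitrary): because the frequencies are integers, the features are 1-periodic, so $x$ and $x + 1$ map to the same embedding.

import torch

net = FF(8, 0)
x = torch.rand(5, 1)  # arbitrary inputs inside [0, 1)
# integer frequencies make the embedding 1-periodic: x + 1 aliases onto x
print(torch.allclose(net(x), net(x + 1.), atol=1e-4))  # True, up to float32 error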

Controlling the Spectral Bias

Now, the second parameter is $a$. This is the set of weights for each of the Fourier components in

$$\text{out} = \sum_i a_i \sin(2\pi b_i x) + a_i \cos(2\pi b_i x)$$

We can use

$$a_i = 1 / b_i^p$$

to specify these weights. The weights decay for different $p$ as shown below:

import matplotlib.pyplot as plt

plt.title("Spectral Weights")

for p in [0, 0.5, 1, 1.5, 2, float('inf')]:
    a_s = 1 / np.arange(1, 9) ** p  # weights a_i = 1 / b_i^p for b_i = 1..8
    plt.plot(a_s, label=f"p={p}")
plt.legend()
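As a short aside, the dot product between two such feature vectors depends only on the difference $x - x'$: since $\sin\sin + \cos\cos$ collapses to a cosine of the difference, the features define the stationary kernel $K(x, x') = \sum_i a_i^2 \cos(2\pi b_i (x - x')) / \lVert a \rVert^2$. Here is a minimal numerical check of this identity (a sketch using the FF module above; the sample points are arbitrary):

import numpy as np
import torch

net = FF(8, 1)
x, y = 0.2, 0.7  # arbitrary sample points
lhs = (net(torch.tensor([[x]])) * net(torch.tensor([[y]]))).sum()

# closed form: sin·sin + cos·cos = cos of the difference
a, b = net.a_s, net.b_s
rhs = (a ** 2 * torch.cos(2. * np.pi * (x - y) * b)).sum() / a.norm() ** 2

print(lhs.item(), rhs.item())  # the two values agree up to float precision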

The Spectrum of the Random Fourier Kernel

We can visualize the spectrum of this kernel directly. Before we start, let's remind ourselves of the basic concepts.

Definition 1: a similarity function $K: X \times X \mapsto \mathbb R$ that is continuous and symmetric is called a kernel.

Mercer's Theorem: $K$ is a kernel iff for all data $x \in X$, the Gram matrix $G$ given by

$$G_{[i, j]} = K\langle x^i, x^j\rangle$$

is positive semi-definite (PSD). This is what underlies the kernel trick. For a more precise formalism, refer to Wikipedia.
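To make this concrete, here is a quick check (a sketch; the input grid is arbitrary) that the Gram matrix of the Fourier features above is indeed PSD:

import numpy as np
import torch

net = FF(8, 1)
xs = torch.linspace(0, 1, 32)[:, None]  # an arbitrary grid of 32 inputs
feats = net(xs).double()                # [32, 16] features, in float64
gram = (feats @ feats.T).numpy()        # G[i, j] = K<x^i, x^j>

# Mercer: the Gram matrix is PSD, i.e. no negative eigenvalues
print(np.linalg.eigvalsh(gram).min() >= -1e-9)  # True, up to round-off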

The neural tangent kernel (NTK) is similar to the conjugate kernel (CK). The conjugate kernel looks at the neural network at initialization, whereas the neural tangent kernel looks at the gradient of the network around its initialization. This is related to the first-order Taylor expansion of the neural network function, and can be computed as the Gram matrix of the gradient vector at the initialization point $\theta_0$ of the network $f_{\theta_0}(x)$:

$$\mathrm{NTK}\langle x^i, x^j \rangle = \nabla_\theta f(x^i) \, \nabla^T_\theta f(x^j)$$

There are two important details:

  1. The gradient vector $\nabla_\theta f$ needs to be normalized, so that for the same data point $x^i = x^j$, the NTK kernel evaluates to 1 (a unit diagonal in the Gram matrix).

    This is the reason behind the line below

    grad.append(grad_vec / np.linalg.norm(grad_vec))
  2. The Gram matrix for one network instantiation can be quite noisy. We need to average over multiple instantiations.

    from tqdm import tqdm, trange

    def get_ntk(net, xs):
        grad = []
        # forward pass over all inputs at once
        out = net(torch.FloatTensor(xs)[:, None])
        for o in tqdm(out, desc="NTK", leave=False):
            net.zero_grad()
            # backprop each output separately to get per-example gradients
            o.backward(retain_graph=True)
            grad_vec = torch.cat([p.grad.view(-1) for p in net.parameters() if p.grad is not None]).numpy()
            # normalize so that NTK<x, x> = 1 for identical data points
            grad.append(grad_vec / np.linalg.norm(grad_vec))
            net.zero_grad()

        grad = np.stack(grad)
        # Gram matrix of the normalized per-example gradients
        gram_matrix = grad @ grad.T
        return gram_matrix
    xs = np.linspace(-0.5, 0.5, 32)

    for p in tqdm([0, 0.5, 1, 1.5, 2, float('inf')], desc="kernels"):
        ntk_kernel = 0  # average over ten network instantiations
        for i in trange(10, desc="averaging networks", leave=False):
            # MLP is the author's fully-connected helper, defined elsewhere
            net = nn.Sequential(FF(16, p), MLP(32, 1024, 4, 1))
            ntk_kernel += 0.1 * get_ntk(net, xs)

        # cross section through the middle row of the kernel
        plt.figure('cross section')
        plt.plot(ntk_kernel[16], label=f'p={p}')

        # spectrum of that cross section
        plt.figure('spectrum')
        fft = np.fft.fft(ntk_kernel[16])
        plt.plot(np.abs(np.fft.fftshift(fft)), label=f'p={p}')

    plt.figure('kernel')
    plt.imshow(ntk_kernel, cmap='inferno')
    [Figures: Kernel | Cross Section | Spectrum]

The Band-limited Fourier Features

An important feature of our implementation above is that the Fourier features are band-limited from above. For band_limit=16, if we visualize the spectrum beyond the band limit, we see it drop sharply.

This happens because we are effectively clipping the higher-order components:

plt.figure("Full A's")
for p in [0, 0.5, 1, 1.5, 2, float('inf')]:
    ff = FF(16, p)
    full_a = np.zeros(32)
    full_a[:16] = ff.a_s  # weights beyond the band limit are zero
    plt.plot(full_a, label=f"p={p}")
plt.legend()
xs = np.linspace(-0.5, 0.5, 64)

for p in tqdm([0, 0.5, 1, 1.5, 2, float('inf')], desc="kernels"):
    ntk_kernel = 0
    for i in trange(10, desc="averaging networks", leave=False):
        net = nn.Sequential(FF(16, p), MLP(32, 1024, 4, 1))
        ntk_kernel += 0.1 * get_ntk(net, xs)

    plt.figure('Full Spectrum')
    # with 64 samples, the FFT resolves frequencies past the band limit of 16
    fft = np.fft.fft(ntk_kernel[32])
    plt.plot(np.abs(np.fft.fftshift(fft)), label=f'p={p}')
plt.legend()
[Figures: Full A's | Spectrum]