FFT

DFT (離散フーリエ変換) を高速に実行するアルゴリズムのことを FFT (高速フーリエ変換) と呼ぶ. 離散フーリエ変換とフーリエ変換は似て非なるもので, この言葉はいささか不用意であり, 本来 FDFT と名付けるべきだったが今更である. 離散フーリエ変換という言葉自体, 本当は「フーリエ級数展開」のバリエーションと見なせるものなので, これを「フーリエ変換」と言ってしまうのは乱暴である.

大雑把に言うと, 時系列データに対して DFT を施すと周波数成分が計算できる. よくオーディオ機器にスペクトル表示器が備わっているが, アレをイメージするとよい.

今日は, そのような本来の目的は一旦忘れて, 純粋にアルゴリズムの話を進めたい.

定義にはいくつか流儀があるが, おおむね次のようなものとなる.

X[n] = ΣW(k * n, N) x[k], n,k = 0...N-1.

x[k] は変換元となるデータ列, X[n] は変換結果のデータ列, W(m, N) は回転子と呼ばれるもので, 1 の N 乗根のうち 1 でないものの m 乗を表している. 通常,

W(m, N) = exp(-2πi * m / N)
= cos(-2π * m / N) + i sin(-2π * m / N)

指数 (角度) のマイナス符号は, あまり深い意味はないが, 変換結果 X[n] が exp(2πi * m / N) の係数となるようにするためには必要である. (もともと i と -i は同等の立場なので, こだわる意味はまったくない. 純粋に形式だけの問題.)

サイズ N = 2 の DFT

定義は

X[0] = Σ W(k * 0, 2) x[k], k = 0, 1,
X[1] = Σ W(k * 1, 2) x[k], k = 0, 1

となるので,

X[0] = Σ W(0, 2) x[k] = x[0] + x[1],
X[1] = Σ W(k, 2) x[k] = x[0] - x[1]

ここで, W(1, 2) = exp(-2πi / 2) = exp(-πi) = cos(-π) + i sin(-π) = -1 などを使った.

サイズ 2 の DFT の計算には加減算だけが登場することが分かる.

N が 2 のべきとなっているとき, 例えば N = 8 としよう. X[n], x[k] の添え字 n, k を 2 進数で n = pqr, k = hijと表そう. pqr = 4p + 2q + r, hij = 4h + 2i + j.

DFT の定義は,

X[pqr] = ΣW(hij * pqr, 8) x[hij]

となる.

バタフライとビットリバース

回転子の変形を試みる.

W(hij * pqr, 8)
=W((h00 + ij) * pqr, 8)
=W(h00 * pqr, 8) * W(ij * pqr, 8)
=W(h00 * (pq0 + r), 8) * W(ij * pqr, 8)
=W(hpq000, 8) * W(h00 * r, 8) * W(ij * pqr, 8)
=W(h * r, 2) * W(ij * pqr, 8)

元の DFT は,

X[pqr] = Σ[ij] W(ij * pqr, 8) Σ[h] W(h * r, 2) x[hij]

となるが, Σ[h] の部分は, サイズ 2 の DFT となっている. (Σ[x] は x についての和という意味.)

さらに変形する.

W(ij * pqr, 8)
= W(ij * (pq0 + r), 8)
= W(ij * pq0, 8) * W(ij * r, 8)
= W(ij * pq, 4) * W(ij * r, 8)

なので,

X[pqr]
= Σ[ij] W(ij * pqr, 8) Σ[h] W(h * r, 2) x[hij]
= Σ[ij] W(ij * pq, 4) * W(ij * r, 8) Σ[h] W(h * r, 2) x[hij]

このように変形すると, サイズ 8 の DFT は, サイズ 2 の DFT の結果 Σ[h] に W(ij * r, 8) を掛けたものに対するサイズ 4 の DFT となっていることが分かる.

そしてこのサイズ 4 の DFT もまた同様にサイズ 2 の DFT に分解できる. 最終的には

X[pqr]
= Σ[j] W(j * p, 2) * W(j * qr, 8) * Σ[i] W(i * q, 2) * W(i * r, 4) * Σ[h] W(h * r, 2) * x[hij]

となる. 各 DFT を分けて書くと, 次のようになる.

y[rij] = Σ[h] W(h * r, 2) * x[hij],
z[qrj] = Σ[i] W(i * q, 2) * W(i *  r, 4) * y[rij],
X[pqr] = Σ[j] W(j * p, 2) * W(j * qr, 8) * z[rqj]

計算する要素の順序がかなりややこしいこの演算は「バタフライ」演算などと呼ばれる. このように添え字の 2 進展開を利用して回転子をうまくまとめて処理することで, 演算回数を劇的に減らすことができる. DFT の定義に忠実に計算すると O(N^2) の計算量となるが, 上記の方法を採用すれば O(N*log(N)) の計算量で済む.

実際に処理するときには, いわゆる「インプレイス」型の処理を行うことが多いが, それには変換元データを変換後データで置き換えていかなければならないため,

y[rij] = Σ[h] W(h * r, 2) * x[hij]
z[rqj] = Σ[i] W(i * q, 2) * W(i *  r, 4) * y[rij]
X[rqp] = Σ[j] W(j * p, 2) * W(j * qr, 8) * z[rqj]

などとする必要がある. このとき X の添え字がビット逆転してしまっていることに注意. このため, X が得られた後, いわゆる「ビットリバース」処理を行うことになる.

応用

ここで説明した時系列データの添え字の上位桁から下位桁に向かって DFT を実行していくアルゴリズムは, 「周波数間引き型 FFT」と呼ばれている. 逆に下位桁から上位桁に向かって実行すると「時間間引き型 FFT」となるが, もはや説明するまでもない.

DFT とほぼ同様の処理で逆 DFT が考えられる. DFT が時系列データを周波数成分に変換する演算なら, 逆 DFT はもちろん周波数成分を時系列データに変換する演算となる.

FFT の目的は, 前述の通り一つは周波数成分を計算することであるが, もう一つ重要な目的として畳み込み, 英語では convolution と呼ばれる計算がある.

多項式 (a[0]x^0 + a[1]x^1 + a[2]x^2 + ... + a[N-1]x^(N-1)) と (b[0]x^0 + b[1]x^1 + b[2]x^2 + ... + b[N-1]x^(N-1)) の積を計算したいとき, 各係数同士の積和が必要となるので, O(N^2) の計算量となるが, それぞれの係数の DFT を求め, 同一周波数成分同士を乗算し, 結果を逆 DFT すると積和が計算できてしまっていると言う一見不思議な性質がある. この計算量は O(N*log(N)) で済む. もちろん FFT 自体がそれほど軽い演算ではないため, 項数が多くないときにはあまり有利にはならないが, 巨大な N についてはほとんど必須と言ってよい工夫となる.

畳み込みに FFT を利用する場合, 周波数領域での要素の並び順はまったく重要ではない (どうせ最後に逆 FFT する) ため, 前述のビットリバース処理は省略することができる.

FFT は複素数体上の演算であるが, 原理上, 任意の数体上の演算として拡張できる. 特に有限体上で処理する FFT (のような処理) のことを FMT, 高速剰余変換と呼び, 多桁の乗算にはこちらの方が重要となる.

実装例

Field.java: 「体」のインタフェース.

package org.creasys.numeric;

/**
 * 「体」のインタフェース.
 *
 * @param <X> 実装クラス
 */
public interface Field<X>
  {
    /**
     * 加算する.
     *
     * @param rhs 加数
     * @return 加算結果
     */
    X add(X rhs);
    /**
     * 減算する.
     *
     * @param rhs 減数
     * @return 減算結果
     */
    X sub(X rhs);
    /**
     * 乗算する.
     *
     * @param rhs 乗数
     * @return 乗算結果
     */
    X mul(X rhs);
    /**
     * 除算する.
     *
     * @param rhs 除数
     * @return 除算結果
     */
    X div(X rhs);
  }

Complex.java: 複素数クラス.

package org.creasys.numeric;

/**
 * 複素数クラス.
 */
public class Complex implements Field<Complex>
  {
    /** 実部. */
    private double re;
    /** 虚部. */
    private double im;

    public static Complex of(double re, double im)
      {
        return new Complex(re, im);
      }

    public Complex(double re, double im)
      {
        this.re = re;
        this.im = im;
      }

    @Override
    public Complex add(Complex rhs)
      {
        return of(re + rhs.re, im + rhs.im);
      }

    @Override
    public Complex sub(Complex rhs)
      {
        return of(re - rhs.re, im - rhs.im);
      }

    @Override
    public Complex mul(Complex rhs)
      {
        return of(re * rhs.re - im * rhs.im, re * rhs.im + im * rhs.re);
      }

    public Complex mul(double rhs)
      {
        return of(re * rhs, im * rhs);
      }

    @Override
    public Complex div(Complex rhs)
      {
        // (a, b) / (c, d)
        // = (a, b) * (c, -d) / (c*c + d*d)
        double denom = 1.0 / (rhs.re * rhs.re + rhs.im * rhs.im);
        return of((re * rhs.re + im * rhs.im) * denom,
                        (im * rhs.re - re * rhs.im) * denom);
      }

    public double real()
      {
        return re;
      }

    public double imag()
      {
        return im;
      }

    public Complex conjugate()
      {
        return of(re, -im);
      }

    /**
     * 絶対値を返す.
     * <p>
     * 数学的には <code>sqrt(re<sup>2</sup> + im<sup>2</sup>)</code> で
     * あるが, 数値計算的にはこの通りに計算すると {@code re} または 
     * {@code im} の絶対値が <code>sqrt({@link Double#MAX_VALUE})</code> 
     * や <code>({@link Double#MIN_VALUE})</code> 付近になると, 結果は十
     * 分範囲内である場合でも, 演算中にオーバフローやアンダフローを発生
     * させてしまうことがあるため, 好ましくない.
     * <p>
     * そこで, 実部と虚部の絶対値の最大値を基準としてスケーリングして計
     * 算する. たとえば実部の絶対値が虚部の絶対値より大きい場合は,
     * <pre>
     * sqrt(re<sup>2</sup> + im<sup>2</sup>)
     * = |re| sqrt(1 + (im / |re|)<sup>2</sup>)
     * </pre>
     * を計算する.
     *
     * @return この複素数の絶対値
     */
    public double abs()
      {
        double r = Math.abs(re);
        double i = Math.abs(im);
        double result;
        if (r == 0.0)
            result = i;
        else if (r >= i)
          {
            i /= r;
            result = r * Math.sqrt(1.0 + i * i);
          }
        else
          {
            r /= i;
            result = i * Math.sqrt(r * r + 1.0);
          }
        return result;
      }

    /**
     * この複素数の2乗ノルムを返す.
     *
     * @return この複素数の2乗ノルム
     */
    public double norm()
      {
        return re * re + im * im;
      }

    @Override
    public String toString()
      {
        return String.format("(%6.2f, %6.2f)", re, im);
      }
  }

DFT2.java: 高速離散フーリエ変換.

package org.creasys.numeric.dft;

import org.creasys.numeric.Complex;

/**
 * 離散フーリエ変換を行うためのクラス.
 * <p>
 * データ列「{@code Complex[] x;}」に対して変換を行うには, 以下のような
 * 方法がある.
 * <ol type="A"><li><pre>
 * DFT2.dft(x, false);</pre></li>
 * <li><pre> DFT2 dfter = new DFT2();
 * dfter.dif(x, false);
 * dfter.bitReverse(x);
 * </pre></li>
 * </ol>
 *
 * @author tt &lt;tanimoto@creasys.org&gt;
 */
public class DFT2
  {
    /** DFTのサイズ. */
    private int size;

    /** 回転子の配列. */
    private Complex[] cis;

    /** インスタンスを構築する. */
    public DFT2() { size = 0; }

    /**
     * DFTのサイズを設定する.
     * <p>
     * 2のべき乗でない場合は, 例外を発生する.
     * <p>
     * 設定済みのサイズと異なるサイズを設定する場合は, 回転子配列を初期
     * 化する.
     *
     * @param n 設定するサイズ
     * @throws IllegalArgumentException サイズが不正だった場合
     */
    private void initSize(int n)
      {
        if (!(n >= 2 && (n & n - 1) == 0))
            throw new IllegalArgumentException();
        if (n != size)
          {
            size = n;
            cis = new Complex[size];
            for (int k = 0; k < n; ++k)
                cis[k] = Complex.of(
                        Math.cos(-2 * Math.PI * k / n),
                        Math.sin(-2 * Math.PI * k / n));
          }
      }

    /**
     * 離散フーリエ変換を実行する.
     * <p>
     * このメソッドは, アルゴリズムを示すために一話完結ものとして実装し
     * ている. 実行効率のため, 回転子はテーブル化するのが一般的であるが, 
     * ここでは計算によって求めている. 三角関数を毎回計算するのはあまり
     * にも効率が悪いため, 加法定理 (複素数の乗算) によって更新する手法
     * を用いている. そのため累積する誤差により, 精度はよくない.
     *
     * @param x 変換するデータ列
     * @param inverse 逆変換を行う場合 {@code true}
     */
    public static void dft(Complex[] x, boolean inverse)
      {
        int n = x.length;
        if (!(n >= 2 && (n & n - 1) == 0))
            throw new IllegalArgumentException();

        // 回転子の種.
        Complex rtor0 = Complex.of(
                Math.cos(-2 * Math.PI / n),
                Math.sin(-2 * Math.PI / n));
        if (inverse)
            rtor0 = rtor0.conjugate();

        // バタフライ演算. 周波数間引き.
        for (int ns = n / 2; ns > 0; ns /= 2)
          {
            Complex rtor = Complex.of(1.0, 0.0);
            for (int k = 0; k < ns; ++k)
              {
                for (int i = k; i < n; i += 2 * ns)
                  {
                    int j = i + ns;
                    Complex t = x[i].sub(x[j]);
                    x[i] = x[i].add(x[j]);
                    x[j] = t.mul(rtor);
                  }
                rtor = rtor.mul(rtor0);
              }
            rtor0 = rtor0.mul(rtor0);
          }

        // ビットリバース.
        for (int i = 0, j = 0; i < n - 1; ++i)
          {
            if (i < j)
              {
                Complex t = x[i];
                x[i] = x[j];
                x[j] = t;
              }
            for (int k = n >>> 1; ; k >>>= 1)
              {
                if ((j & k) == 0)
                  {
                    j |= k;
                    break;
                  }
                j &= ~k;
              }
          }

        if (inverse)
          {
            // 係数の処理.
            double c = 1.0 / n;
            for (int i = 0; i < n; ++i)
                x[i] = x[i].mul(c);
          }
      }

    /**
     * 周波数間引き型のバタフライ演算を実行する.
     * <p>
     * 逆変換を指示すると, 周波数間引き型の逆ではなく, 時間間引き型の逆
     * 変換となることに注意.
     *
     * @param x データ列
     * @param inverse 逆変換
     */
    public void dif(Complex[] x, boolean inverse)
      {
        int n = x.length;
        initSize(n);

        // バタフライ演算.
        int nk = 1;
        for (int ns = n / 2; ns > 0; ns /= 2)
          {
            int wx = 0;
            for (int k = 0; k < ns; ++k)
              {
                Complex rtor = !inverse ? cis[wx] : cis[wx].conjugate();
                for (int i = k; i < n; i += 2 * ns)
                  {
                    int j = i + ns;
                    Complex t = x[i].sub(x[j]);
                    x[i] = x[i].add(x[j]);
                    x[j] = t.mul(rtor);
                  }
                wx += nk;
              }
            nk *= 2;
          }
      }

    /**
     * 時間間引き型のバタフライ演算を実行する.
     * <p>
     * 逆変換を指示すると, 時間間引き型の逆ではなく, 周波数間引き型の逆
     * 変換となることに注意.
     *
     * @param x データ列
     * @param inverse 逆変換
     */
    public void dit(Complex[] x, boolean inverse)
      {
        int n = x.length;
        initSize(n);

        // バタフライ演算.
        int nk = n / 2;
        for (int ns = 1; ns < n; ns *= 2)
          {
            int wx = 0;
            for (int k = 0; k < ns; ++k)
              {
                Complex rtor = !inverse ? cis[wx] : cis[wx].conjugate();
                for (int i = k; i < n; i += 2 * ns)
                  {
                    int j = i + ns;
                    Complex t = rtor.mul(x[j]);
                    x[j] = x[i].sub(t);
                    x[i] = x[i].add(t);
                  }
                wx += nk;
              }
            nk /= 2;
          }
      }

    public void bitReverse(Complex[] x)
      {
        int n = x.length;
        initSize(n);

        // ビットリバース.
        for (int i = 0, j = 0; i < n - 1; ++i)
          {
            if (i < j)
              {
                Complex t = x[i];
                x[i] = x[j];
                x[j] = t;
              }
            for (int k = n >>> 1; ; k >>>= 1)
              {
                if ((j & k) == 0)
                  {
                    j |= k;
                    break;
                  }
                j &= ~k;
              }
          }
      }

    public void scaleDown(Complex[] x)
      {
        int n = x.length;
        initSize(n);

        double scale = 1.0 / n;
        for (int k = 0; k < n; ++k)
            x[k] = x[k].mul(scale);
      }
  }

DFT2Test.java: テストプログラム.

package org.creasys.numeric.dft;

import org.creasys.numeric.Complex;
import org.junit.Assert;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runners.MethodSorters;

@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class DFT2Test
  {
    @Test
    public void testMul0()
      {
        // 51782163529 * 76537543 の計算.
        int[] a = {9, 2, 5, 3, 6, 1, 2, 8, 7, 1, 5};
        int[] b = {3, 4, 5, 7, 3, 5, 6, 7};

        final int M = a.length + b.length;
        final int N = 32;
        Complex[] ca = new Complex[N];
        for (int k = 0; k < N; ++k)
            ca[k] = Complex.of(k < a.length ? a[k] : 0, 0);
        Complex[] cb = new Complex[N];
        for (int k = 0; k < N; ++k)
            cb[k] = Complex.of(k < b.length ? b[k] : 0, 0);

        DFT2.dft(ca, false);
        DFT2.dft(cb, false);
        for (int k = 0; k < N; ++k)
            ca[k] = ca[k].mul(cb[k]);
        DFT2.dft(ca, true);

        int[] result = new int[M];
        double maxError = 0;
        long carry = 0;
        for (int k = 0; k < M; ++k) {
            double x = ca[k].abs();
            long ix = Math.round(x);
            maxError = Math.max(maxError, Math.abs(x - ix));
            carry += ix;
            result[k] = (int) (carry % 10);
            carry /= 10;
        }
        for (int k = M; --k >= 0; )
            System.out.printf("%d", result[k]);
        System.out.printf("%nmaxError: %g%n", maxError);
    }

    @Test
    public void testMul1()
      {
        // 51782163529 * 76537543 の計算.
        int[] a = {9, 2, 5, 3, 6, 1, 2, 8, 7, 1, 5};
        int[] b = {3, 4, 5, 7, 3, 5, 6, 7};

        final int M = a.length + b.length;
        final int N = 32;
        Complex[] ca = new Complex[N];
        for (int k = 0; k < N; ++k)
            ca[k] = Complex.of(k < a.length ? a[k] : 0, 0);
        Complex[] cb = new Complex[N];
        for (int k = 0; k < N; ++k)
            cb[k] = Complex.of(k < b.length ? b[k] : 0, 0);

        DFT2 dft = new DFT2();
        dft.dif(ca, false);
        dft.dif(cb, false);
        for (int k = 0; k < N; ++k)
            ca[k] = ca[k].mul(cb[k]);
        dft.dit(ca, true);
        dft.scaleDown(ca);

        int[] result = new int[M];
        double maxError = 0;
        long carry = 0;
        for (int k = 0; k < M; ++k)
          {
            double x = ca[k].abs();
            long ix = Math.round(x);
            maxError = Math.max(maxError, Math.abs(x - ix));
            carry += ix;
            result[k] = (int) (carry % 10);
            carry /= 10;
          }
        for (int k = M; --k >= 0; )
            System.out.printf("%d", result[k]);
        System.out.printf("%nmaxError: %g%n", maxError);
      }

    @Test
    public void test0()
      {
        Complex[] x = new Complex[16];
        for (int k = 0; k < x.length; ++k)
            x[k] = Complex.of(Math.sin(2 * Math.PI * k / x.length), 0);

        System.out.println("Before");
        for (int k = 0; k < x.length; ++k)
            System.out.printf(k < x.length - 1 && k % 4 < 3 ? " %s" : " %s%n", x[k]);

        DFT2.dft(x, false);

        System.out.println("After");
        for (int k = 0; k < x.length; ++k)
            System.out.printf(k < x.length - 1 && k % 4 < 3 ? " %s" : " %s%n", x[k]);

        DFT2.dft(x, true);

        System.out.println("Reverse");
        for (int k = 0; k < x.length; ++k)
            System.out.printf(k < x.length - 1 && k % 4 < 3 ? " %s" : " %s%n", x[k]);
      }

    @Test
    public void test1()
      {
        Complex[] x = new Complex[16];
        for (int k = 0; k < x.length; ++k)
            x[k] = Complex.of(k == 0 ? 1 : 0, 0);

        System.out.println("Before");
        for (int k = 0; k < x.length; ++k)
            System.out.printf(k < x.length - 1 && k % 4 < 3 ? " %s" : " %s%n", x[k]);

        DFT2.dft(x, false);

        System.out.println("After");
        for (int k = 0; k < x.length; ++k)
            System.out.printf(k < x.length - 1 && k % 4 < 3 ? " %s" : " %s%n", x[k]);

        DFT2.dft(x, true);

        System.out.println("Reverse");
        for (int k = 0; k < x.length; ++k)
            System.out.printf(k < x.length - 1 && k % 4 < 3 ? " %s" : " %s%n", x[k]);
      }
  }

testMul0 は高速離散フーリエ変換を多桁乗算に応用した例. testMul1 も同様だが, こちらはバタフライ演算とスケールダウンだけを行い, ビットリバース処理を省いている分, やや効率を重視している.

test0 は, ちょうどデータ列分が周期となる正弦波を変換するテスト. test1 は, いわゆるステップ波を離散フーリエ変換するテストとなっている.