package math.fourierTransforms.r2;

import graphics.grapher.Graph;
import utils.StopWatch;
import math.Mat1;
import math.MathUtils;
import math.fourierTransforms.DFTTest;


/**
 * Radix-2, decimation-in-time FFT operating on double-precision real and
 * imaginary arrays (FFTR2Double version 1.9).
 *
 * <p>Transforms are computed in place: {@link #forwardFFT} keeps references to
 * the caller's arrays, so {@link #getReal()} and {@link #getImaginary()} return
 * those same arrays afterwards. Instances hold mutable state (the data
 * references and the transform direction) and are not safe for concurrent use.
 *
 * @author D. Lyon
 */
public class FFTR2Double {

    /** Value returned by {@link #getMaxValue} when the array has no comparable (non-NaN) element. */
    private static final double NO_MAX = -0.99e30;

    /** 2*pi at full double precision; a float constant would degrade the double-precision transform. */
    private static final double TWO_PI = 2 * Math.PI;

    /** Number of timing runs per size in {@link #testDftVsFft}; only the last measurement is kept. */
    private static final int TIMING_REPEATS = 4;

    /** Real parts of the working data. */
    private double[] r_data = null;

    /** Imaginary parts of the working data. */
    private double[] i_data = null;

    /** Transform direction: true for a forward FFT, false only while {@link #reverseFFT} runs. */
    private boolean forward = true;

    /**
     * Creates a transform with zero-filled working arrays of length {@code n}.
     *
     * @param n number of samples
     */
    public FFTR2Double(int n) {
        r_data = new double[n];
        i_data = new double[n];
    }

    /** Creates a transform with no working arrays; they are supplied later by {@link #forwardFFT}. */
    public FFTR2Double() {
    }

    /** Swaps complex samples Z_i and Z_j; {@code i} and {@code j} are 1-based indices. */
    private void swapInt(int i, int j) {
        int ti = i - 1;
        int tj = j - 1;
        double tmp = r_data[tj];
        r_data[tj] = r_data[ti];
        r_data[ti] = tmp;
        tmp = i_data[tj];
        i_data[tj] = i_data[ti];
        i_data[ti] = tmp;
    }

    /**
     * Returns the largest non-NaN value in {@code in}.
     *
     * @param in values to scan
     * @return the maximum element, or {@code -0.99e30} if {@code in} is empty
     *         or contains only NaN
     */
    public static double getMaxValue(double[] in) {
        double max = Double.NEGATIVE_INFINITY;
        boolean found = false;
        for (double v : in) {
            // NaN never compares greater, so it is skipped (as before); the first
            // real value seeds the maximum instead of an arbitrary sentinel.
            if (!Double.isNaN(v) && (!found || v > max)) {
                max = v;
                found = true;
            }
        }
        return found ? max : NO_MAX;
    }

    /**
     * Copies the first {@code getReal().length} samples of {@code in} into the
     * real working array.
     *
     * <p>NOTE(review): despite the name, no normalization is applied; this is a
     * plain copy. Confirm the intended scaling before adding one.
     *
     * @param in source samples; must be at least as long as the working array
     * @throws NullPointerException if {@code in} is null or no working array
     *         exists yet (no-arg constructor)
     * @throws IndexOutOfBoundsException if {@code in} is shorter than the working array
     */
    public void normalizeAndTruncateInput(double[] in) {
        System.arraycopy(in, 0, r_data, 0, r_data.length);
    }

    /**
     * Reorders the first {@code n} complex samples into bit-reversed index order.
     *
     * @param n number of samples to permute; expected to be a power of two
     */
    private void bitReverse(int n) {
        int j = 1;
        for (int i = 1; i < n; i++) {
            if (i < j) {
                swapInt(i, j);
            }
            int k = n / 2;
            while (k >= 1 && k < j) {
                j -= k;
                k /= 2;
            }
            j += k;
        }
    }

    /**
     * Computes the inverse FFT of the given arrays in place.
     *
     * <p>Uses the +i twiddle sign and applies no 1/N scaling (the forward
     * transform already divides by N). The direction flag is restored even if
     * the transform throws, so later forward calls are unaffected.
     *
     * @param in_r real parts; overwritten with the result
     * @param in_i imaginary parts; overwritten with the result
     */
    public void reverseFFT(double[] in_r, double[] in_i) {
        forward = false;
        try {
            forwardFFT(in_r, in_i);
        } finally {
            forward = true;
        }
    }

    /**
     * Negates every even-indexed sample (0, 2, 4, ...) in place, i.e.
     * multiplies {@code r[i]} by (-1)^(i+1).
     *
     * @param r samples to modify
     */
    public static void centering(double[] r) {
        int s = 1;
        for (int i = 0; i < r.length; i++) {
            s = -s;
            r[i] *= s;
        }
    }

    /**
     * Computes the FFT of the given arrays in place.
     *
     * <p>In forward mode the kernel is e^{-2 pi i jk/N} and the output is
     * divided by N; in inverse mode (via {@link #reverseFFT}) the kernel sign
     * is positive and no scaling is applied. The instance keeps references to
     * {@code in_r}/{@code in_i}, so the getters return these same arrays.
     *
     * <p>Only the first 2^numBits samples take part, where numBits is
     * {@code MathUtils.log2(in_r.length)}. NOTE(review): this assumes
     * {@code MathUtils.log2} rounds down, as the truncation intent implies;
     * confirm against MathUtils.
     *
     * @param in_r real parts; overwritten with the result
     * @param in_i imaginary parts; overwritten with the result
     */
    public void forwardFFT(double[] in_r, double[] in_i) {
        int numBits = MathUtils.log2(in_r.length);
        // Truncate input data to a power of two: length = 2^numBits.
        int length = 1 << numBits;

        r_data = in_r;
        i_data = in_i;

        // Permute only the samples that are actually transformed; permuting the
        // full (non-power-of-two) array would pull in samples beyond `length`.
        bitReverse(length);

        for (int m = 1; m <= numBits; m++) {
            int localN = 1 << m; // butterfly span at this stage
            int nby2 = localN / 2;
            double theta = Math.PI / nby2;

            // Per-stage twiddle step Wj; Wjk is advanced by recursive
            // multiplication instead of a cos/sin call per butterfly group.
            double wjR = Math.cos(theta);
            double wjI = forward ? -Math.sin(theta) : Math.sin(theta);
            double wjkR = 1;
            double wjkI = 0;

            for (int j = 0; j < nby2; j++) {
                // Innermost butterfly loop: the hot spot of the transform.
                for (int k = j; k < length; k += localN) {
                    int id = k + nby2;
                    double tempR = wjkR * r_data[id] - wjkI * i_data[id];
                    double tempI = wjkR * i_data[id] + wjkI * r_data[id];

                    // Z[id] = Z[k] - C;  Z[k] = Z[k] + C
                    r_data[id] = r_data[k] - tempR;
                    i_data[id] = i_data[k] - tempI;
                    r_data[k] += tempR;
                    i_data[k] += tempI;
                }

                // Wjk <- Wjk * Wj   (eq 6.23) and (eq 6.24)
                double prevR = wjkR;
                wjkR = wjR * wjkR - wjI * wjkI;
                wjkI = wjR * wjkI + wjI * prevR;
            }
        }

        // Normalize the forward transform output by 1/length.
        if (forward) {
            for (int i = 0; i < length; i++) {
                r_data[i] /= length;
                i_data[i] /= length;
            }
        }
    }

    /** Divides every working sample (real and imaginary) by the array length, in place. */
    public void normalizeData() {
        int n = r_data.length;
        for (int k = 0; k < n; k++) {
            r_data[k] /= n;
            i_data[k] /= n;
        }
    }

    /** Returns the real working array itself (not a copy). */
    public double[] getReal() {
        return r_data;
    }

    /** Returns the imaginary working array itself (not a copy). */
    public double[] getImaginary() {
        return i_data;
    }

    /**
     * Returns a new array holding |Z_k|^2 = re^2 + im^2 for every working sample.
     *
     * @return the power spectral density, one entry per sample
     */
    public double[] getPowerSpectralDensity() {
        double[] psd = new double[r_data.length];
        for (int k = 0; k < r_data.length; k++) {
            psd[k] = r_data[k] * r_data[k] + i_data[k] * i_data[k];
        }
        return psd;
    }

    /**
     * Runs radix-2 butterfly passes in place over the first 2^floor(log2 n)
     * working samples and returns a copy of the resulting real part.
     *
     * <p>Assumes the real and imaginary working arrays are already set.
     * NOTE(review): no bit reversal is performed and the twiddle factors use
     * -sin(theta), the same sign as the forward transform, so the input is
     * presumably expected in bit-reversed order; confirm the intended semantics.
     *
     * @return a copy of the real working array after the passes
     */
    public double[] ifft() {
        int length = r_data.length;
        // Number of bits, floor(log2(length)), computed exactly with integer
        // arithmetic; a ratio of Math.log values is subject to rounding.
        int nu = length > 0 ? 31 - Integer.numberOfLeadingZeros(length) : 0;
        // Truncate input data to a power of two: n = 2^nu.
        int n = 1 << nu;

        for (int m = 1; m <= nu; m++) {
            int span = 1 << m;
            int half = span / 2;
            double theta = TWO_PI / span;

            double wr = 1.0;
            double wi = 0.0;
            double wpr = Math.cos(theta);
            double wpi = -Math.sin(theta);

            for (int j = 0; j < half; j++) {
                for (int i = j; i < n; i += span) {
                    int id = i + half;
                    double tempr = wr * r_data[id] - wi * i_data[id];
                    double tempi = wr * i_data[id] + wi * r_data[id];

                    // Z[id] = Z[i] - C(tempr, tempi);  Z[i] = Z[i] + C
                    r_data[id] = r_data[i] - tempr;
                    i_data[id] = i_data[i] - tempi;
                    r_data[i] += tempr;
                    i_data[i] += tempi;
                }
                // W = W * WP = (wr*wpr - wi*wpi) + i (wi*wpr + wr*wpi)
                double wtemp = wr;
                wr = wr * wpr - wi * wpi;
                wi = wi * wpr + wtemp * wpi;
            }
        }
        return Mat1.arrayCopy(r_data);
    }

    /** Prints {@code title} followed by one "v[i]=value" line per element. */
    public static void printArray(double[] v, String title) {
        System.out.println(title);
        for (int i = 0; i < v.length; i++) {
            System.out.println("v[" + i + "]=" + v[i]);
        }
    }

    /** Prints {@code title} followed by one "[i]=(re,im)" line per working sample. */
    public void printArrays(String title) {
        System.out.println(title);
        for (int i = 0; i < r_data.length; i++) {
            System.out.println("[" + i + "]=("
                    + r_data[i] + "," + i_data[i] + ")");
        }
    }

    /** Prints {@code title} followed by one "re[i]=value" line per working sample. */
    public void printReal(String title) {
        System.out.println(title);
        for (int i = 0; i < r_data.length; i++) {
            System.out.println("re[" + i + "]=" + r_data[i]);
        }
    }

    /** Returns the test signal {0, 1, ..., n-1}. */
    private static double[] ramp(int n) {
        double[] x = new double[n];
        for (int j = 0; j < n; j++) {
            x[j] = j;
        }
        return x;
    }

    /** Demo: forward-transforms an 8-sample ramp and prints the spectrum and its PSD. */
    public static void testPSD() {
        FFTR2Double f = new FFTR2Double();
        int n = 8;
        double[] in_r = ramp(n);
        double[] in_i = new double[n];

        f.forwardFFT(in_r, in_i);
        f.printArrays("After the FFT");
        double[] psd = f.getPowerSpectralDensity();
        FFTR2Double.printArray(psd, "The psd");
    }

    /** Runs {@code DFTTest.timeFFT} for sizes 2, 4, ..., 1024. */
    public static void timeFFT() {
        for (int i = 2; i < 2048; i *= 2) {
            DFTTest.timeFFT(i);
        }
    }

    /**
     * Times one forward plus one inverse FFT of an {@code n}-sample ramp.
     *
     * @param n number of samples
     * @return elapsed time in milliseconds as reported by the stopwatch
     */
    public static double getTimeFFT(int n) {
        FFTR2Double f = new FFTR2Double(n);
        double[] in_r = ramp(n);
        double[] in_i = new double[n];

        StopWatch t = new StopWatch();
        t.start();
        f.forwardFFT(in_r, in_i);
        f.reverseFFT(in_r, in_i);
        t.stop();
        return t.getTimeInMs();
    }

    /** Entry point: runs the round-trip demo {@link #testFFT()}. */
    public static void main(String[] args) {
        testFFT();
    }

    /**
     * Demo: forward- then inverse-transforms an 8-sample ramp and prints the
     * input, the spectrum, and the reconstructed signal side by side.
     */
    public static void testFFT() {
        System.out.println("Starting 1D FFT test...");
        FFTR2Double f = new FFTR2Double();
        int n = 8;

        double[] x1 = ramp(n);
        double[] in_r = Mat1.arrayCopy(x1);
        double[] in_i = new double[n];

        f.forwardFFT(in_r, in_i);

        // Keep copies: the inverse transform overwrites the FFT results in place.
        double[] fftResult_r = Mat1.arrayCopy(in_r);
        double[] fftResult_i = Mat1.arrayCopy(in_i);

        f.reverseFFT(in_r, in_i);

        System.out.println("j\tx1[j]\tre[j]\tim[j]\tv[j]");
        for (int i = 0; i < n; i++) {
            System.out.println(i + "\t" +
                    x1[i] + "\t" +
                    fftResult_r[i] + "\t" +
                    fftResult_i[i] + "\t" +
                    in_r[i]);
        }
    }

    /** Prints one "x, y" line per index of {@code x}; {@code y} must be at least as long. */
    public static void print(double[] x, double[] y) {
        for (int i = 0; i < x.length; i++) {
            System.out.println(x[i] + ", " + y[i]);
        }
    }

    /**
     * Graphs FFT timings, then DFT timings, for sizes 16 through 512.
     * Each size is measured several times and only the last run is kept
     * (earlier runs presumably serve as warm-up).
     */
    public static void testDftVsFft() {
        double[] na = {16, 32, 64, 128, 256, 512};
        double[] time = new double[na.length];

        for (int i = 0; i < na.length; i++) {
            for (int rep = 0; rep < TIMING_REPEATS; rep++) {
                time[i] = getTimeFFT((int) na[i]);
            }
        }
        Graph.displayGraph(na, time, "n", "ms");

        for (int i = 0; i < na.length; i++) {
            for (int rep = 0; rep < TIMING_REPEATS; rep++) {
                time[i] = DFTTest.getTimeDFT((int) na[i]);
            }
        }
        Graph.displayGraph(na, time, "n", "ms");
    }
}
