package ip.gui.frames;

// To run, use:
// java ip.gui.frames.WaveletFrame

import ip.transforms.Lifting;
import utils.Print;

import java.awt.*;
import java.awt.event.ActionEvent;

public class WaveletFrame extends FFTFrame {

    private Menu waveletMenu = getMenu("Wavelet");
    private MenuItem forwardHaar_mi = addMenuItem(waveletMenu, "Forward Haar");
    private MenuItem backwardHaar_mi = addMenuItem(waveletMenu, "Backward Haar");
    private MenuItem liftingForwardHaar_mi = addMenuItem(waveletMenu, "LiftingForward Haar");
    private MenuItem liftingBackwardHaar_mi = addMenuItem(waveletMenu, "LiftingBackward Haar");
    private MenuItem demo1d_mi = addMenuItem(waveletMenu, "demo 1d");
    private MenuItem demo2d_mi = addMenuItem(waveletMenu, "demo 2d");

    private MenuItem haarCompress_mi = addMenuItem(waveletMenu, "[E-c]haarCompress");
    private MenuItem stripimage_mi = addMenuItem(waveletMenu, "stripimage");

    private MenuItem ulawEncode_mi = addMenuItem(waveletMenu, "ulaw encode");
    private MenuItem ulawDecode_mi = addMenuItem(waveletMenu, "ulaw decode");
    private MenuItem clip_mi = addMenuItem(waveletMenu, "clip");
    private MenuItem clearQuad1_mi = addMenuItem(waveletMenu, "clear quad1");
    private MenuItem clearQuad2_mi = addMenuItem(waveletMenu, "clear quad2");
    private MenuItem clearQuad3_mi = addMenuItem(waveletMenu, "clear quad3");
    private MenuItem clearLowerHalf_mi = addMenuItem(waveletMenu, "clear lower half");
    private MenuItem clearLower34_mi = addMenuItem(waveletMenu, "clear lower 3/4");

    private MenuItem stats_mi = addMenuItem(getFileMenu(), "compute stats");


    public void actionPerformed(ActionEvent e) {

        if (match(e, clearLowerHalf_mi)) {
            clearLowerHalf();
            return;
        }
        if (match(e, clearLower34_mi)) {
            clearLower34();
            return;
        }
        if (match(e, clearQuad3_mi)) {
            clearQuad3();
            return;
        }
        if (match(e, clearQuad1_mi)) {
            clearQuad1();
            return;
        }
        if (match(e, clearQuad2_mi)) {
            clearQuad2();
            return;
        }
        if (match(e, stripimage_mi)) {
            stripimage();
            return;
        }
        if (match(e, haarCompress_mi)) {
            haarCompress();
            return;
        }
        if (match(e, clip_mi)) {
            clip();
            return;
        }
        if (match(e, ulawDecode_mi)) {
            ulawDecode();
            return;
        }
        if (match(e, ulawEncode_mi)) {
            ulawEncode();
            return;
        }
        if (match(e, stats_mi)) {
            stats();
            return;
        }

        if (match(e, liftingForwardHaar_mi)) {
            liftingForwardHaar();
            return;
        }
        if (match(e, liftingBackwardHaar_mi)) {
            liftingBackwardHaar();
            return;
        }
        if (match(e, backwardHaar_mi)) {
            backwardHaar();
            return;
        }
        if (match(e, forwardHaar_mi)) {
            forwardHaar();
            return;
        }

        if (match(e, demo2d_mi)) {
            demo2d();
            return;
        }
        if (match(e, demo1d_mi)) {
            demo1d();
            return;
        }
        super.actionPerformed(e);
    }

    public void demo2d() {
        int[][] i =
                {
                    {9, 7, 5, 3},
                    {3, 5, 7, 9},
                    {2, 4, 6, 8},
                    {4, 6, 8, 10}
                };
        short x[][] = getShort(i);
        print(x);
        forwardHaar(x);
        print(x);
    }
    public final static short[][] getShort(int a[][]){
         short s[][] = new short[a.length][a[0].length];
        for (int i=0; i < a.length; i++)
            for (int j=0; j < a[0].length; j++)
                s[i][j] = (short)a[i][j];
        return s;
    }

    public void demo1d() {
        int s[][] =
                {
                    {9, 7, 5, 3},
                    {3, 5, 7, 9},
                    {2, 4, 6, 8},
                    {4, 6, 8, 10}
                };
        short x[][] = getShort(s);
        print(x);
        for (int i = 0; i < x.length; i++)
            forwardHaar2(x[i]);
        print(x);
    }

    public void print(short in[][]) {
        for (int i = 0; i < in.length; i++) {
            for (int j = 0; j < in[0].length; j++)
                Print.print(in[i][j] + "\t");
            Print.println("");
        }
        Print.println("-------------------");
    }

    public void forwardHaar() {
        fh(shortImageBean.getR());
        fh(shortImageBean.getG());
        fh(shortImageBean.getB());
        short2Image();
    }

    public void liftingForwardHaar() {
        Lifting.forwardHaar(shortImageBean.getR());
        Lifting.forwardHaar(shortImageBean.getG());
        Lifting.forwardHaar(shortImageBean.getB());
        short2Image();
    }

    public void liftingBackwardHaar() {
        Lifting.backwardHaar(shortImageBean.getR());
        Lifting.backwardHaar(shortImageBean.getG());
        Lifting.backwardHaar(shortImageBean.getB());
        short2Image();
    }

    public void fh(short in[][]) {
        forwardHaar(in);
    }

    public void backwardHaar() {
        backwardHaar(shortImageBean.getR());
        backwardHaar(shortImageBean.getG());
        backwardHaar(shortImageBean.getB());
        clip();
    }

    private static void forwardHaar(short in[][]) {
        int width = in.length;
        int height = in[0].length;
        short temp[] = new short[width];
        for (int i = 0; i < width; i++)
            forwardHaar2(in[i]);
        for (int j = 0; j < height; j++) {
            for (int i = 0; i < width; i++)
                temp[i] = in[i][j];
            forwardHaar2(temp);
            for (int i = 0; i < width; i++)
                in[i][j] = temp[i];
        }
    }


    private void backwardHaar(short in[][]) {
        int width = in.length;
        int height = in[0].length;
        short out[] = new short[width];
        for (int i = 0; i < width; i++)
            backwardHaar2(in[i]);
        for (int j = 0; j < height; j++) {
            for (int i = 0; i < width; i++)
                out[i] = in[i][j];
            backwardHaar2(out);
            for (int i = 0; i < width; i++)
                in[i][j] = out[i];
        }
    }

    private static void forwardHaar2(short in[]) {
        int n = in.length;
        int nOn2 = n / 2;
        if (n < 2) return;
        for (int i = 0; i < n; i += 2) {
            in[i + 1] -= in[i];
            in[i] += in[i + 1] / 2;
        }
        short averages[] = new short[n / 2];
        for (int i = nOn2 - 1; i >= 0; i--) {
            averages[i] = in[2 * i];
            in[i + nOn2] = in[2 * i + 1];
        }
        forwardHaar2(averages);
        for (int i = 0; i < nOn2; i++)
            in[i] = averages[i];
    }

    private void backwardHaar2(short in[]) {
        int n = in.length;
        if (n < 2) return;
        int nOn2 = n / 2;

        short averages[] = new short[nOn2];
        for (int i = 0; i < nOn2; i++)
            averages[i] = in[i];
        backwardHaar2(averages);
        for (int i = 0; i < nOn2; i++) {
            in[2 * i] = averages[i];
            in[2 * i + 1] = in[i + nOn2];
        }
        for (int i = 0; i < n; i += 2) {
            in[i] -= in[i + 1] / 2;
            in[i + 1] += in[i];
        }
    }

    public int[][] short2Int(short s[][]) {
        int a[][] = new int[getImageWidth()][getImageHeight()];
        for (int x = 0; x < getImageWidth(); x++)
            for (int y = 0; y < getImageHeight(); y++)
                a[x][y] = s[x][y];
        return a;
    }

    public short[][] int2Short(int s[][]) {
        short a[][] = new short[getImageWidth()][getImageHeight()];
        for (int x = 0; x < getImageWidth(); x++)
            for (int y = 0; y < getImageHeight(); y++)
                a[x][y] = (short) s[x][y];
        return a;
    }

    public WaveletFrame(String title) {
        super(title);
        getXformMenu().add(waveletMenu);
    }


    /**
     print statistics on the image
     */
    public void stats() {
        int max[] = {-10000, -1000, -1000};
        int min[] = {10000, 1000, 1000};
        double average[] = new double[3];
        for (int x = 0; x < getImageWidth(); x++)
            for (int y = 0; y < getImageHeight(); y++) {
                if (shortImageBean.getR()[x][y] > max[0]) max[0] = shortImageBean.getR()[x][y];
                if (shortImageBean.getG()[x][y] > max[1]) max[1] = shortImageBean.getR()[x][y];
                if (shortImageBean.getB()[x][y] > max[2]) max[2] = shortImageBean.getR()[x][y];
                if (shortImageBean.getR()[x][y] < min[0]) min[0] = shortImageBean.getR()[x][y];
                if (shortImageBean.getG()[x][y] < min[1]) min[1] = shortImageBean.getR()[x][y];
                if (shortImageBean.getB()[x][y] < min[2]) min[2] = shortImageBean.getR()[x][y];
                average[0] += shortImageBean.getR()[x][y];
                average[1] += shortImageBean.getG()[x][y];
                average[2] += shortImageBean.getB()[x][y];
            }
        int n = getImageWidth() * getImageHeight();
        average[0] = average[0] / n;
        average[1] = average[1] / n;
        average[2] = average[2] / n;
        Print.println("------ Statistics -----");
        Print.println("\tR\tG\tB\t");
        Print.println("min:" + min[0] + "\t" + min[1] + "\t" + min[2]);
        Print.println("max:" + max[0] + "\t" + max[1] + "\t" + max[2]);
        Print.println("avg:" + average[0] + "\t" + average[1] + "\t" + average[2]);
    }

    public static void main(String args[]) {
        WaveletFrame wf = new WaveletFrame("wavelet frame");
        wf.show();
    }

    public void ulawEncode() {
        for (int x = 0; x < getImageWidth(); x++) {
            shortImageBean.getR()[x] = UlawCodec.encode(shortImageBean.getR()[x]);
            shortImageBean.getG()[x] = UlawCodec.encode(shortImageBean.getG()[x]);
            shortImageBean.getB()[x] = UlawCodec.encode(shortImageBean.getB()[x]);
        }
        //add(128);
        short2Image();
    }

    public void ulawDecode() {
        for (int x = 0; x < getImageWidth(); x++) {
            shortImageBean.getR()[x] = UlawCodec.decode(shortImageBean.getR()[x]);
            shortImageBean.getG()[x] = UlawCodec.decode(shortImageBean.getG()[x]);
            shortImageBean.getB()[x] = UlawCodec.decode(shortImageBean.getB()[x]);
        }
        //add(-128);
        short2Image();
    }


    short eps = 0;

    public void haarCompress() {
        forwardHaar();
        stripimage();
        backwardHaar();
    }

    public void stripimage() {
        eps += 5;
        Print.println("Haar compress factor=" + eps);
        for (int x = 0; x < getImageWidth(); x++)
            for (int y = 0; y < getImageHeight(); y++) {
                shortImageBean.getR()[x][y] = strip(shortImageBean.getR()[x][y], eps);
                shortImageBean.getG()[x][y] = strip(shortImageBean.getG()[x][y], eps);
                shortImageBean.getB()[x][y] = strip(shortImageBean.getB()[x][y], eps);
            }
    }

    public void clearQuad1() {
        clearQuad(getImageWidth() / 2, getImageHeight() / 2, getImageWidth(), getImageHeight());
        short2Image();
    }

    public void clearQuad2() {
        clearQuad(getImageWidth() / 2, 0, getImageWidth(), getImageHeight());
        clearQuad(0, getImageHeight() / 2, getImageWidth(), getImageHeight());
        short2Image();
    }

    public void clearQuad3() {
        clearQuad(getImageWidth() / 4, 0, getImageWidth(), getImageHeight());
        clearQuad(0, getImageHeight() / 4, getImageWidth(), getImageHeight());
        short2Image();
    }

    public void clearLowerHalf() {
        clearQuad(0, getImageHeight() / 2, getImageWidth(), getImageHeight());
        short2Image();
    }

    public void clearLower34() {
        clearQuad(0, getImageHeight() / 4, getImageWidth(), getImageHeight());
        short2Image();
    }

    public void clearQuad(int x1, int y1, int x2, int y2) {
        for (int x = x1; x < x2; x++)
            for (int y = y1; y < y2; y++) {
                shortImageBean.getR()[x][y] = 0;
                shortImageBean.getG()[x][y] = 0;
                shortImageBean.getB()[x][y] = 0;
            }
    }

    public short strip(short i, short eps) {
        if (Math.abs(i) < eps) return 0;
        return i;
    }

    public void clip() {
        for (int x = 0; x < getImageWidth(); x++)
            for (int y = 0; y < getImageHeight(); y++) {
                shortImageBean.getR()[x][y] = clip(shortImageBean.getR()[x][y]);
                shortImageBean.getG()[x][y] = clip(shortImageBean.getG()[x][y]);
                shortImageBean.getB()[x][y] = clip(shortImageBean.getB()[x][y]);
            }
        short2Image();
    }

    private short clip(short i) {
        if (i < 0) return 0;
        if (i > 255) return 255;
        return i;
    }

}

/**
 * Logarithmic (u-law style) companding of short samples.
 * <p>
 * {@code encode(x) = offset + sign(x) * factor * ln(1 + |x| * mu / vmax)},
 * truncated to short. {@code decode} applies the inverse, including the sign.
 */
class UlawCodec {
    public static double mu = 255.0;
    public static double vmax = 255;
    public static double offset = vmax / 2 + 2;
    private static final double factor = 22;
    private static final double muOnVmax = mu / vmax;

    /**
     * Encodes every element in place.
     *
     * @param a the samples to encode
     * @return the same array, now holding the codes
     */
    public static short[] encode(short a[]) {
        for (int i = 0; i < a.length; i++)
            a[i] = encode(a[i]);
        return a;
    }

    /**
     * Inverse of {@link #encode(short)} (up to truncation).
     * <p>
     * A code below {@code offset} came from a negative sample, so the
     * magnitude is expanded from {@code |a|} and the sign is restored.
     * Without this, negative samples would decode to a value in (-1, 0]
     * and truncate to 0.
     *
     * @param x a code produced by encode
     * @return the decoded sample, truncated toward zero
     */
    public static short decode(short x) {
        double a = (x - offset) / factor;
        double magnitude = (Math.exp(Math.abs(a)) - 1) / muOnVmax;
        return (short) (a < 0 ? -magnitude : magnitude);
    }

    /**
     * Encodes one sample logarithmically around {@code offset}.
     *
     * @param x the sample
     * @return the code, truncated to short
     */
    public static short encode(short x) {
        return
                (short) (offset + sign(x) * factor * Math.log(1 + Math.abs(x) * muOnVmax));
    }

    /**
     * Sign function that treats 0 as positive.
     *
     * @return -1 for negative x, otherwise 1
     */
    public static short sign(short x) {
        if (x < 0) return -1;
        return 1;
    }

    /**
     * Decodes every element in place.
     *
     * @param a the codes to decode
     * @return the same array, now holding the decoded samples
     */
    public static short[] decode(short a[]) {
        for (int i = 0; i < a.length; i++)
            a[i] = decode(a[i]);
        return a;
    }


}