package ip.gui.frames;

import collections.sortable.Cshort;
import ip.gui.dialog.DoubleArrayLog;
import ip.transforms.Gauss;
import ip.transforms.Kernels;
import j2d.ShortImageBean;
import math.Mat2;
import utils.StopWatch;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.util.Vector;

/**
 * Image frame that adds a "SpatialFilter" menu (convolution kernels, low-pass,
 * Gaussian, median, high-pass and unsharp filters) on top of
 * {@link ConvolutionFrame}.
 * <p>
 * Pixel data lives in the inherited {@code shortImageBean} (separate R, G and B
 * planes indexed {@code [x][y]}); the filters here write their results back into
 * that bean.
 */
public class SpatialFilterFrame extends ConvolutionFrame {
    /** Scratch frame created by {@link #makeChild()}; {@code null} until then. */
    public SpatialFilterFrame child = null;
    Menu SpatialFilterMenu = getMenu("SpatialFilter");

    private Menu lowPassMenu = getMenu("LowPass");
    private Menu gaussianMenu = getMenu("gaussian");
    private Menu medianMenu = getMenu("Median");
    private Menu highPassMenu = getMenu("Hi-pass");
    private Menu unsharpMenu = getMenu("unsharp");

    private Menu kernalMenu = getMenu("Convolution Kernal");
    private MenuItem showConvolutionKernal_mi
            = addMenuItem(kernalMenu, "show");
    private MenuItem enterConvolutionKernal33_mi
            = addMenuItem(kernalMenu, "3x3..");
    private MenuItem enterConvolutionKernal55_mi
            = addMenuItem(kernalMenu, "5x5..");
    private MenuItem enterConvolutionKernal77_mi
            = addMenuItem(kernalMenu, "7x7..");

    private MenuItem average_mi = addMenuItem(lowPassMenu, "[a]verage");
    private MenuItem lp1_mi = addMenuItem(lowPassMenu, "[E-1]lp1");
    private MenuItem lp2_mi = addMenuItem(lowPassMenu, "[E-2]lp2");
    private MenuItem lp3_mi = addMenuItem(lowPassMenu, "[E-3]lp3");
    private MenuItem mean3_mi = addMenuItem(lowPassMenu, "[E-4]mean3");
    private MenuItem mean9_mi = addMenuItem(lowPassMenu, "[E-5]mean9");

    private MenuItem gauss3_mi = addMenuItem(gaussianMenu, "[E-6]gauss 3x3");
    private MenuItem gauss7_mi = addMenuItem(gaussianMenu, "[E-7]gauss 7x7");
    private MenuItem gauss15_mi = addMenuItem(gaussianMenu, "[E-9]gauss 15x15");
    private MenuItem gauss31_mi = addMenuItem(gaussianMenu, "[E-T-G]auss 31x31");
    private MenuItem gauss31_fast_mi = addMenuItem(gaussianMenu, "Gauss 31x31 fast");
    private MenuItem gabor_mi = addMenuItem(lowPassMenu, "[T-G]abor");

    private MenuItem medianCross3x3_mi = addMenuItem(medianMenu, "[E-T-+]cross3x3");
    private MenuItem medianSquare3x3_mi = addMenuItem(medianMenu, "[E-T-s]quare 3x3");
    private MenuItem medianOctagon5x5_mi = addMenuItem(medianMenu, "[E-T-o]catgon 5x5");
    private MenuItem medianSquare5x5_mi = addMenuItem(medianMenu, "[E-T-S]quare 5x5");
    private MenuItem medianDiamond7x7_mi = addMenuItem(medianMenu, "[E-T-D]iamond 7x7");
    private MenuItem medianCross7x7_mi = addMenuItem(medianMenu, "[E-T-C]ross 7x7");
    private MenuItem outlierEstimate_mi = addMenuItem(medianMenu, "[E-T-O]utlier estimate");
    private MenuItem saltAndPepper100_mi = addMenuItem(medianMenu, "[E-T-1]saltAndPepper100");
    private MenuItem saltAndPepper1000_mi = addMenuItem(medianMenu, "[E-T-2]saltAndPepper1000");
    private MenuItem saltAndPepper2000_mi = addMenuItem(medianMenu, "[E-T-3]saltAndPepper2000");
    private MenuItem saltAndPepper4000_mi = addMenuItem(medianMenu, "[E-T-4]saltAndPepper4000");

    private MenuItem hp1_mi = addMenuItem(highPassMenu, "[T-1]hp1");
    private MenuItem hp2_mi = addMenuItem(highPassMenu, "[T-2]hp2");
    private MenuItem hp3_mi = addMenuItem(highPassMenu, "[T-3]hp3");
    private MenuItem hp4_mi = addMenuItem(highPassMenu, "[T-4]hp4");
    private MenuItem hp5_mi = addMenuItem(highPassMenu, "[T-5]hp5");

    MenuItem shadowMask_mi = addMenuItem(highPassMenu, "[T-6]shadowMask");

    private MenuItem usp1_mi = addMenuItem(unsharpMenu, "[T-7]usp1");

    private MenuItem subtractChild_mi = addMenuItem(getFileMenu(), "[T-8]subtract child");

    private MenuItem short2Image_mi = addMenuItem(getFileMenu(), "[T-9]short2Image");
    private MenuItem clip_mi = addMenuItem(getFileMenu(), "[T-0]clip");


    /** When true, {@link #median(int[])} skips sorting windows that Mat2 reports as outlier-free. */
    private boolean computeOutlier = true;
    /** Number of windows that required a full sort since the last reset. */
    private int numberOfOutliers = 0;


    /**
     * Creates a new "child" frame with this frame's dimensions and gives it a
     * deep copy of this frame's R, G and B planes. The child serves as scratch
     * space for {@link #usp1()} and {@link #subtractChild()}.
     */
    public void makeChild() {
        child = new SpatialFilterFrame("child");
        child.setImageWidth(getImageWidth());
        child.setImageHeight(getImageHeight());
        // The red copy must go into the child's bean, not back into this frame.
        child.shortImageBean.setR(Mat2.copyArray(shortImageBean.getR()));
        child.setG(Mat2.copyArray(shortImageBean.getG()));
        child.setB(Mat2.copyArray(shortImageBean.getB()));
    }

    /**
     * Subtracts the child frame's image from this frame's image and refreshes
     * the display. Does nothing (beyond a console message) if no child exists.
     */
    public void subtractChild() {
        if (child == null) {
            System.out.println("subtractChild: no child image; call makeChild first");
            return;
        }
        subtract(child.shortImageBean);
        short2Image();
    }

    /**
     * Subtracts {@code sibB} from this frame's image in place via
     * {@code ShortImageBean.subtract}.
     *
     * @param sibB the image to subtract
     */
    public void subtract(ShortImageBean sibB) {
        ShortImageBean.subtract(shortImageBean, sibB);
    }

    /**
     * Toggles outlier-based median short-circuiting, prints the new setting and
     * the count accumulated so far, then resets the count.
     */
    public void outlierEstimate() {
        setComputeOutlier(!isComputeOutlier());

        System.out.println(
                "computeOutlier = " + isComputeOutlier());
        System.out.println(
                "numberOfOutliers = " + getNumberOfOutliers());
        setNumberOfOutliers(0);
    }

    /**
     * Builds the frame and nests the filter sub-menus under a "SpatialFilter"
     * menu appended to the inherited menu bar.
     *
     * @param title window title
     */
    public SpatialFilterFrame(String title) {
        super(title);
        MenuBar mb = getMenuBar();
        SpatialFilterMenu.add(kernalMenu);
        lowPassMenu.add(gaussianMenu);
        SpatialFilterMenu.add(lowPassMenu);
        SpatialFilterMenu.add(medianMenu);
        highPassMenu.add(unsharpMenu);
        SpatialFilterMenu.add(highPassMenu);
        mb.add(SpatialFilterMenu);
        setMenuBar(mb);
    }


    /** Clips this frame's image via {@code ShortImageBean.clip}. */
    public void clip() {
        ShortImageBean.clip(shortImageBean);
    }

    /** Opens a kernel-entry dialog with a 0x0 grid. */
    public void enterConvolutionKernal() {
        new DoubleArrayLog(
                this, "Convolution kernal", null, 0, 0, 6);
    }

    /**
     * Opens a kernel-entry dialog with an {@code r} x {@code c} grid.
     *
     * @param r number of rows
     * @param c number of columns
     */
    public void enterConvolutionKernal(int r, int c) {
        new DoubleArrayLog(
                this, "Convolution kernal", null, r, c, 6);
    }

    /**
     * Dispatches this frame's menu items to their filters; anything not
     * recognised here is forwarded to the superclass.
     */
    @Override
    public void actionPerformed(ActionEvent e) {

        if (match(e, gabor_mi)) {
            gabor7();
            return;
        }
        if (match(e, enterConvolutionKernal33_mi)) {
            enterConvolutionKernal(3, 3);
            return;
        }

        if (match(e, enterConvolutionKernal55_mi)) {
            enterConvolutionKernal(5, 5);
            return;
        }

        if (match(e, enterConvolutionKernal77_mi)) {
            enterConvolutionKernal(7, 7);
            return;
        }
        if (match(e, showConvolutionKernal_mi)) {
            showConvolutionKernal();
            return;
        }
        if (match(e, clip_mi)) {
            clip();
            return;
        }
        if (match(e, short2Image_mi)) {
            short2Image();
            return;
        }
        if (match(e, subtractChild_mi)) {
            subtractChild();
            return;
        }
        if (match(e, usp1_mi)) {
            usp1();
            return;
        }
        if (match(e, outlierEstimate_mi)) {
            outlierEstimate();
            return;
        }
        if (match(e, medianCross7x7_mi)) {
            medianCross7x7();
            return;
        }
        if (match(e, medianCross3x3_mi)) {
            medianCross3x3();
            return;
        }
        if (match(e, medianSquare3x3_mi)) {
            medianSquare3x3();
            return;
        }
        if (match(e, medianOctagon5x5_mi)) {
            medianOctagon5x5();
            return;
        }
        if (match(e, medianSquare5x5_mi)) {
            medianSquare5x5();
            return;
        }
        if (match(e, medianDiamond7x7_mi)) {
            medianDiamond7x7();
            return;
        }
        if (match(e, mean9_mi)) {
            mean9();
            return;
        }
        if (match(e, mean3_mi)) {
            mean3();
            return;
        }
        if (match(e, saltAndPepper100_mi)) {
            saltAndPepper(100);
            return;
        }
        if (match(e, saltAndPepper1000_mi)) {
            saltAndPepper(1000);
            return;
        }
        if (match(e, saltAndPepper2000_mi)) {
            saltAndPepper(2000);
            return;
        }
        if (match(e, saltAndPepper4000_mi)) {
            saltAndPepper(4000);
            return;
        }
        if (match(e, gauss3_mi)) {
            gauss3();
            return;
        }
        if (match(e, gauss7_mi)) {
            gauss7();
            return;
        }
        if (match(e, gauss15_mi)) {
            gauss15();
            return;
        }
        if (match(e, gauss31_mi)) {
            gauss31();
            return;
        }
        if (match(e, gauss31_fast_mi)) {
            gauss31Fast();
            return;
        }

        if (match(e, lp1_mi)) {
            lp1();
            return;
        }
        if (match(e, lp2_mi)) {
            lp2();
            return;
        }
        if (match(e, lp3_mi)) {
            lp3();
            return;
        }
        if (match(e, hp1_mi)) {
            hp1();
            return;
        }
        if (match(e, hp2_mi)) {
            hp2();
            return;
        }
        if (match(e, hp3_mi)) {
            hp3();
            return;
        }
        if (match(e, hp4_mi)) {
            hp4();
            return;
        }
        if (match(e, hp5_mi)) {
            hp5();
            return;
        }
        if (match(e, average_mi)) {
            average();
            return;
        }

        super.actionPerformed(e);

    }

    /**
     * Adds salt-and-pepper noise to the image and refreshes the display.
     *
     * @param n amount passed to {@code ShortImageBean.saltAndPepper}
     */
    public void saltAndPepper(int n) {
        shortImageBean.saltAndPepper(n);
        short2Image();
    }


    /** Convolves with the 3x3 averaging kernel from {@code Kernels}. */
    public void average() {
        convolve(Kernels.getAverage3x3());
    }


    /** Convolves with high-pass kernel 1 from {@code Kernels}. */
    public void hp1() {
        convolve(Kernels.getHp1());
    }

    /** Convolves with high-pass kernel 2 from {@code Kernels}. */
    public void hp2() {
        convolve(Kernels.getHp2());
    }

    /** Convolves with high-pass kernel 3 from {@code Kernels}. */
    public void hp3() {
        convolve(Kernels.getHp3());
    }

    /** Convolves with high-pass kernel 4 from {@code Kernels}. */
    public void hp4() {
        convolve(Kernels.getHp4());
    }

    /** Convolves with high-pass kernel 5 from {@code Kernels}. */
    public void hp5() {
        convolve(Kernels.getHp5());
    }


    /**
     * Unsharp mask: copies the image into a child frame, blurs the child with
     * the 3x3 Gaussian, subtracts the blurred child from this image and
     * refreshes the display.
     */
    public void usp1() {
        makeChild();
        child.gauss3();
        // Subtract the blurred copy (previously subtracted the image from itself).
        subtract(child.shortImageBean);
        short2Image();
    }

    /** Convolves with low-pass kernel 1 from {@code Kernels}. */
    public void lp1() {
        convolve(Kernels.getLp1());
    }

    /** Convolves with low-pass kernel 2 from {@code Kernels}. */
    public void lp2() {
        convolve(Kernels.getLp2());
    }

    /** Convolves with low-pass kernel 3 from {@code Kernels}. */
    public void lp3() {
        convolve(Kernels.getLp3());
    }

    /** Convolves with the 7x7 Gabor kernel from {@code Kernels}. */
    public void gabor7() {
        convolve(Kernels.getGabor7());
    }


    /** Convolves with the {@code Kernels.getMean9()} kernel. */
    public void mean9() {
        convolve(Kernels.getMean9());
    }

    /** Convolves with the {@code Kernels.getMean3()} kernel (weights sum to about 1.0000000074505806). */
    public void mean3() {
        convolve(Kernels.getMean3());
    }

    /** Convolves with the 3x3 Gaussian kernel from {@code Gauss}. */
    public void gauss3() {
        convolve(Gauss.getGauss3());
    }

    /**
     * 1/f weighting: {@code 1 / sqrt((x-xc)^2 + (y-yc)^2 + 1)}. The {@code +1}
     * keeps the value finite (equal to 1) at the centre. This is the weighting
     * referred to as equation (9.1) in the original text.
     *
     * @return weight in (0, 1]
     */
    public static double oneOnF(
            double x, double y,
            double xc, double yc) {
        double dx = x - xc;
        double dy = y - yc;
        double eps = 1;
        return 1 / Math.sqrt(dx * dx + dy * dy + eps);
    }

    /**
     * Builds an {@code M} x {@code N} kernel of {@link #oneOnF} weights centred
     * at {@code (M/2, N/2)}.
     *
     * @param M number of rows (x extent)
     * @param N number of columns (y extent)
     * @return the kernel
     */
    public static float[][] getOneOnFKernel(
            int M, int N) {
        float k[][] = new float[M][N];
        int xc = M / 2;
        int yc = N / 2;
        for (int x = 0; x < k.length; x++)
            for (int y = 0; y < k[0].length; y++)
                k[x][y] = (float) oneOnF(x, y, xc, yc);
        return k;
    }

    /**
     * Scales every R, G and B sample by its {@link #oneOnF} weight relative to
     * the image centre (truncating to short), then refreshes the display.
     */
    public void multOneOnF() {
        int width = getImageWidth();
        int height = getImageHeight();
        int xc = width / 2;
        int yc = height / 2;
        short[][] r = shortImageBean.getR();
        short[][] g = shortImageBean.getG();
        short[][] b = shortImageBean.getB();
        for (int x = 0; x < width; x++)
            for (int y = 0; y < height; y++) {
                double f = oneOnF(x, y, xc, yc);
                r[x][y] = (short) (r[x][y] * f);
                g[x][y] = (short) (g[x][y] * f);
                b[x][y] = (short) (b[x][y] * f);
            }
        short2Image();
    }

    /** Prints a 9x9 {@link #getOneOnFKernel} kernel. */
    public void printOneOnFKernel() {
        float k[][] = getOneOnFKernel(9, 9);
        Mat2.printKernel(k, "OneOnFKernel" + k.length);
    }

    /** Convolves with the 7x7 Gaussian kernel from {@code Gauss}. */
    public void gauss7() {
        convolve(Gauss.getGauss7());
    }

    /** Convolves with the 15x15 Gaussian kernel from {@code Gauss}. */
    public void gauss15() {
        convolve(Gauss.getGauss15());
    }

    /** Convolves (using {@code convolveFast}) with a 31x31 Gaussian, sigma = 2. */
    public void gauss31Fast() {
        double sigma = 2.0;
        float k[][] = Gauss.getGaussKernel(31, 31, sigma);
        convolveFast(k);
    }

    /** Convolves with a 31x31 Gaussian, sigma = 2. */
    public void gauss31() {
        double sigma = 2.0;
        float k[][] = Gauss.getGaussKernel(31, 31, sigma);
        convolve(k);
    }


    /**
     * Builds a {@code size} x {@code size} median footprint with every entry
     * set, i.e. a full square window.
     */
    private static short[][] squareKernel(int size) {
        short k[][] = new short[size][size];
        for (int i = 0; i < size; i++)
            for (int j = 0; j < size; j++)
                k[i][j] = 1;
        return k;
    }

    /** Median filter over a full 3x3 square (previously used the cross footprint by mistake). */
    public void medianSquare3x3() {
        median(squareKernel(3));
    }

    /** Median filter over a full 5x5 square (previously used the cross footprint by mistake). */
    public void medianSquare5x5() {
        median(squareKernel(5));
    }

    /** Median filter over the 5x5 octagon footprint from {@code Kernels}. */
    public void medianOctagon5x5() {
        median(Kernels.getMedianOctagon5x5());
    }

    /** Median filter over the diamond footprint from {@code Kernels}. */
    public void medianDiamond7x7() {
        median(Kernels.getMedianDiamond());
    }

    /** Median filter over the 7x7 cross footprint from {@code Kernels}. */
    public void medianCross7x7() {
        median(Kernels.getMedianCross7x7());
    }


    /** Median filter over the 7x7 square footprint from {@code Kernels}. */
    public void medianSquare7x7() {
        median(Kernels.getMedianSquare7x7());
    }

    /** Median filter over a 3x3 "plus" (cross) footprint. */
    public void medianCross3x3() {
        short k[][] = {
            {0, 1, 0},
            {1, 1, 1},
            {0, 1, 0}
        };
        median(k);
    }

    /**
     * Median-filters the R, G and B planes using the non-zero entries of
     * {@code k} as the window footprint, prints the elapsed time and refreshes
     * the display.
     *
     * @param k footprint; non-zero entries select window samples
     */
    public void median(short k[] []) {
        Mat2.printMedian(k, "color median");
        StopWatch t = new StopWatch();
        t.start();
        shortImageBean.setR(median(shortImageBean.getR(), k));
        setG(median(shortImageBean.getG(), k));
        setB(median(shortImageBean.getB(), k));
        t.print("Median filter time");
        short2Image();
    }

    /**
     * Gathers the samples of {@code f} under the non-zero entries of {@code k}
     * centred at {@code (x, y)} into {@code window} and returns their median.
     *
     * @param edge when true, coordinates are mapped through {@code cx}/{@code cy}
     *             so windows that overhang the border stay in range; when false
     *             the caller guarantees the whole window is inside the image
     */
    private short medianAt(short f[][], short k[][], int x, int y,
                           int window[], boolean edge) {
        int uc = k.length / 2;
        int vc = k[0].length / 2;
        int loc = 0;
        for (int v = -vc; v <= vc; v++)
            for (int u = -uc; u <= uc; u++)
                if (k[u + uc][v + vc] != 0)
                    window[loc++] = edge
                            ? f[cx(x - u)][cy(y - v)]
                            : f[x - u][y - v];
        return (short) median(window);
    }

    /**
     * Fills the rows {@code y < vc} of {@code h} (every column, including the
     * last) with edge-aware medians of {@code f}.
     */
    public void medianBottom(
            short f[][], short k[][], short h[][]) {
        System.out.println("k=" + k.length);
        int vc = k[0].length / 2;
        int window[] = new int[Mat2.numberOfNonZeros(k)];
        int yEnd = Math.min(vc, getImageHeight());
        for (int x = 0; x < getImageWidth(); x++)
            for (int y = 0; y < yEnd; y++)
                h[x][y] = medianAt(f, k, x, y, window, true);
    }

    /**
     * Fills the columns {@code x < uc} of {@code h}, for rows not handled by
     * the bottom/top passes, with edge-aware medians of {@code f}.
     */
    public void medianLeft(
            short f[][], short k[][], short h[][]) {
        int uc = k.length / 2;
        int vc = k[0].length / 2;
        int window[] = new int[Mat2.numberOfNonZeros(k)];
        int xEnd = Math.min(uc, getImageWidth());
        for (int x = 0; x < xEnd; x++)
            for (int y = vc; y < getImageHeight() - vc; y++)
                h[x][y] = medianAt(f, k, x, y, window, true);
    }

    /**
     * Fills the right columns ({@code x >= width - uc}) and the top rows
     * ({@code y >= height - vc}) of {@code h} with edge-aware medians of
     * {@code f}. The last column and last row are included (they were
     * previously skipped and left at 0).
     */
    public void medianRightAndTop(
            short f[][], short k[][], short h[][]) {
        int uc = k.length / 2;
        int vc = k[0].length / 2;
        int width = getImageWidth();
        int height = getImageHeight();
        int window[] = new int[Mat2.numberOfNonZeros(k)];

        // right edge
        for (int x = Math.max(width - uc, 0); x < width; x++)
            for (int y = vc; y < height - vc; y++)
                h[x][y] = medianAt(f, k, x, y, window, true);

        // top edge
        for (int x = 0; x < width; x++)
            for (int y = Math.max(height - vc, 0); y < height; y++)
                h[x][y] = medianAt(f, k, x, y, window, true);
    }

    /**
     * Median-filters one plane: the interior with direct indexing, then the
     * four borders with edge-aware coordinates.
     *
     * @param f input plane, indexed {@code [x][y]}; not modified
     * @param k footprint; non-zero entries select window samples
     * @return a new filtered plane of size width x height
     */
    public short[][] median(short f[][], short k[][]) {
        short h[][] = medianNoEdge(f, k);
        medianBottom(f, k, h);
        medianLeft(f, k, h);
        medianRightAndTop(f, k, h);
        return h;
    }

    /**
     * Median-filters only the pixels whose whole window lies inside the image.
     * Border pixels of the result are left at 0. The input plane is no longer
     * overwritten while scanning, so every window sees unfiltered data.
     *
     * @param f input plane, indexed {@code [x][y]}; not modified
     * @param k footprint; non-zero entries select window samples
     * @return a new width x height plane
     */
    public short[][] medianNoEdge(short f[][], short k[][]) {
        int uc = k.length / 2;
        int vc = k[0].length / 2;
        short h[][] = new short[getImageWidth()][getImageHeight()];
        int window[] = new int[Mat2.numberOfNonZeros(k)];

        for (int x = uc; x < getImageWidth() - uc; x++)
            for (int y = vc; y < getImageHeight() - vc; y++)
                h[x][y] = medianAt(f, k, x, y, window, false);
        return h;
    }

    /** Prints the median of a fixed 10-element sample to stdout. */
    public void testMedian() {
        int a[] = {1, 2, 3, 5, 4, 3, 2, 5, 6, 7};
        System.out.println("The median =" + median(a));
    }

    /** Runs {@code Mat2.testOutlier()}. */
    public static void main(String args[]) {
        Mat2.testOutlier();
    }

    /**
     * Returns the median of {@code a}, sorting it in place when needed.
     * <p>
     * If outlier detection is on and {@code Mat2.outlierHere(a)} is false, the
     * unsorted middle sample {@code a[a.length / 2]} is returned unchanged.
     * For the symmetric footprints gathered by this class, that is the window's
     * centre pixel. NOTE(review): assumes outlierHere tests that centre sample.
     * Every call that reaches the sort increments {@link #getNumberOfOutliers()}.
     * <p>
     * For an odd length the true middle element is returned. For an even length
     * the two middle elements are averaged and rounded to nearest.
     *
     * @param a window samples; reordered when sorted
     * @return the median
     * @throws IllegalArgumentException if {@code a} is empty
     */
    public int median(int a[]) {
        if (a.length == 0)
            throw new IllegalArgumentException("median of an empty window");
        int center = a.length / 2;
        if (isComputeOutlier()) {
            if (!Mat2.outlierHere(a)) {
                return a[center];
            }
        }
        setNumberOfOutliers(getNumberOfOutliers() + 1);
        Mat2.quickSort(a);

        if ((a.length & 1) == 1)
            return a[center];
        return (int) Math.round((a[center - 1] + a[center]) / 2.0);
    }

    /**
     * The slow way to do a median filter: it boxes each sample into a
     * {@code Cshort} held in a {@code Vector} and lets {@code Mat2.median}
     * sort the objects. It was measured at 18 s for a 128x128 grayscale image
     * (MRJ 2.0 on a PowerMac 8100/100), versus about 0.5 s for the primitive
     * quicksort path in {@link #medianNoEdge}, roughly 36 times slower. It is
     * kept for comparison. Border pixels of the result are left at 0.
     *
     * @param f input plane, indexed {@code [x][y]}
     * @param k footprint; non-zero entries select window samples
     * @return a new width x height plane
     */
    public short[][] medianSlow(short f[][], short k[][]) {
        int uc = k.length / 2;
        int vc = k[0].length / 2;
        short h[][] = new short[getImageWidth()][getImageHeight()];
        Vector window = new Vector();
        int windowLength = Mat2.numberOfNonZeros(k);
        for (int i = 0; i < windowLength; i++)
            window.addElement(new Cshort(0));

        for (int x = uc; x < getImageWidth() - uc; x++) {
            for (int y = vc; y < getImageHeight() - vc; y++) {
                int loc = 0;
                for (int v = -vc; v <= vc; v++)
                    for (int u = -uc; u <= uc; u++)
                        if (k[u + uc][v + vc] != 0) {
                            Cshort cs = (Cshort) window.elementAt(loc++);
                            cs.setValue((f[x - u][y - v]));
                        }
                h[x][y] = Mat2.median(window);
            }
        }
        return h;
    }

    /** @return whether {@link #median(int[])} may skip sorting outlier-free windows */
    public boolean isComputeOutlier() {
        return computeOutlier;
    }

    /** @param computeOutlier whether {@link #median(int[])} may skip sorting outlier-free windows */
    public void setComputeOutlier(boolean computeOutlier) {
        this.computeOutlier = computeOutlier;
    }

    /** @return number of windows that required a full sort since the last reset */
    public int getNumberOfOutliers() {
        return numberOfOutliers;
    }

    /** @param numberOfOutliers new value for the sorted-window counter */
    public void setNumberOfOutliers(int numberOfOutliers) {
        this.numberOfOutliers = numberOfOutliers;
    }
}