/*
 * Copyright (c) 2005 Your Corporation. All Rights Reserved.
 */
package j2d.transforms;

import futils.Futil;
import gui.ClosableJFrame;
import gui.dialogs.ScrollingAwtImagePanel;
import j2d.ImageUtils;

import javax.media.jai.*;
import javax.media.jai.operator.DFTDescriptor;
import javax.media.jai.widget.ScrollingImagePanel;
import java.awt.*;
import java.awt.image.DataBuffer;
import java.awt.image.RenderedImage;
import java.awt.image.SampleModel;
import java.awt.image.renderable.ParameterBlock;

/**
 * Demonstrations of JAI discrete-Fourier-transform operations: forward and
 * inverse DFT, complex multiplication, magnitude/phase extraction, and a
 * helper that reformats a spectrum for display.
 */
public class DftExampleDep extends ClosableJFrame {

    /** Entry point: runs the {@link #testDft()} round-trip demo. */
    public static void main(String[] args) {
        testDft();
    }

    /**
     * Alternative demo: asks the user for two image files and shows the inverse
     * DFT of the product of their DFTs (see {@link #DftExampleDep(String, String)}).
     */
    private static void testDft2() {
        new DftExampleDep(
                Futil.getReadFile("select an image").getAbsolutePath(),
                Futil.getReadFile("select an image").getAbsolutePath());
    }

    /**
     * Loads two images, multiplies their DFTs pixel-wise (complex product),
     * inverse-transforms the product and shows the result in this frame.
     *
     * NOTE(review): {@link #DFTImage} and {@link #inverseDFTImage} both use
     * {@code DFTDescriptor.SCALING_NONE}, so the result is not normalised by the
     * image dimensions.
     *
     * @param fileName1 path of the first image (loaded with the JAI "fileload" op)
     * @param fileName2 path of the second image
     */
    DftExampleDep(String fileName1, String fileName2) {
        // Load the source images.
        PlanarImage src = JAI.create("fileload", fileName1);
        PlanarImage orig = JAI.create("fileload", fileName2);
        RenderedOp origDft = DFTImage(orig);
        RenderedOp gaussDft = DFTImage(src);
        RenderedOp conv = multComplex(origDft, gaussDft);

        PlanarImage idftimg = inverseDFTImage(conv);
        display(idftimg);
    }

    /**
     * Inverse-transforms {@code src} with the JAI "idft" op (default parameters)
     * and formats the result to byte data, using {@code src}'s sample model and
     * color model as the requested layout. Currently unused within this class.
     *
     * @param src the image to inverse-transform
     * @return the byte-formatted inverse transform
     */
    private PlanarImage getIdft(PlanarImage src) {
        PlanarImage idft = JAI.create("idft", src);

        // Request an output layout matching the source's sample/color model.
        ImageLayout il = new ImageLayout();
        il.setSampleModel(src.getSampleModel());
        il.setColorModel(src.getColorModel());
        RenderingHints rh = new RenderingHints(JAI.KEY_IMAGE_LAYOUT, il);
        // Convert the data to byte.
        ParameterBlock pb = new ParameterBlock();
        pb.addSource(idft);
        pb.add(DataBuffer.TYPE_BYTE);
        PlanarImage dst = JAI.create("format", pb, rh);
        return dst;
    }

    /**
     * Adds a scrolling panel showing {@code dst} to this frame, then packs and
     * shows the frame.
     *
     * @param dst the image to show; the viewport is sized to the full image
     */
    private void display(PlanarImage dst) {
        addComponent(new ScrollingAwtImagePanel(dst, dst.getWidth(), dst.getHeight()));
        pack();
        setVisible(true);
    }

    /**
     * Reformats a magnitude-squared DFT image produced by JAI so that it can be
     * displayed. The tones are rescaled over the full grayscale, and the layout
     * is changed so that the DC value sits in the centre of the image.
     * <p>
     * Tone scaling first overwrites the DC value with the second-largest value
     * (the true DC value is lost). It then maps the resulting minimum/maximum
     * linearly onto 0/255.
     * <p>
     * The layout is changed by swapping the four quadrants: quadrant 1 is
     * swapped with 3, and 2 with 4. The middle row and column are duplicated
     * (row -fc = row +fc), so the output is one pixel larger than the input in
     * each direction.
     *
     * @param inputImg magnitude-squared DFT image to be processed (band 0 is used;
     *                 samples are assumed to be non-negative)
     * @return a TiledImage with tone values between 0 and 255 and the DC value in
     *         the centre
     * @see "Numerical Recipes in C, chapter 12 and figure 12.5.1, for the image layout."
     */
    public TiledImage viewDFT(PlanarImage inputImg) {
        // ---- Tone scaling ----
        // PlanarImage.getData() returns a *copy* of the pixels, so the DC edits
        // must be made on a writable image that "extrema" and "rescale" read.
        // Writing through TiledImage.setSample also avoids casting the raw
        // DataBuffer to one particular DataBufferFloat class.
        TiledImage work = new TiledImage(inputImg.getMinX(), inputImg.getMinY(),
                inputImg.getWidth(), inputImg.getHeight(),
                inputImg.getTileGridXOffset(), inputImg.getTileGridYOffset(),
                inputImg.getSampleModel(), inputImg.getColorModel());
        work.set(inputImg);
        int dcX = work.getMinX();
        int dcY = work.getMinY();

        // Zero the DC term so the maximum found below is the largest non-DC value
        // (relies on non-negative samples, as in a magnitude-squared spectrum).
        work.setSample(dcX, dcY, 0, 0.0f);
        RenderedImage ext = JAI.create("extrema", work);
        double[][] extrema = (double[][]) ext.getProperty("extrema");
        double max = extrema[1][0];
        double min = extrema[0][0];
        // Overwrite DC with that second maximum so it no longer swamps the scale.
        work.setSample(dcX, dcY, 0, (float) max);

        // Linear map [min, max] -> [0, 255]. A flat image (range == 0) would
        // otherwise yield an infinite scale factor.
        double range = max - min;
        double scale = (range == 0.0) ? 0.0 : 0xFF / range;
        double offset = -min * scale;
        RenderedImage rescaled = JAI.create("rescale", work,
                new double[]{scale},
                new double[]{offset});
        RenderedImage display = JAI.create("format", rescaled, DataBuffer.TYPE_BYTE);

        // ---- Quadrant swap: move the DC value to the centre ----
        int width = display.getWidth();
        int height = display.getHeight();
        int rangeX = width / 2;    // half width (x index range of one quadrant)
        int rangeY = height / 2;   // half height (y index range of one quadrant)
        int dcOriginX = width / 2;  // x coordinate of DC in the output
        int dcOriginY = height / 2; // y coordinate of DC in the output
        float shiftXneg = -rangeX;  // moves the negative-frequency block in x
        float shiftYneg = -rangeY;  // moves the negative-frequency block in y
        float shiftXpos = rangeX;   // moves the positive-frequency block in x
        float shiftYpos = rangeY;   // moves the positive-frequency block in y

        // One extra row and column so the middle row/column appears on both sides.
        TiledImage outImg = new TiledImage(0, 0, width + 1, height + 1, 0, 0,
                display.getSampleModel(), display.getColorModel());
        // Top-left input quadrant (plus middle row/column) -> bottom-right (quadrant 4).
        pasteTranslated(outImg, display, shiftXpos, shiftYpos,
                new Rectangle(dcOriginX, dcOriginY, rangeX + 1, rangeY + 1));
        // Bottom-right input quadrant -> top-left (quadrant 2).
        pasteTranslated(outImg, display, shiftXneg, shiftYneg,
                new Rectangle(0, 0, rangeX, rangeY));
        // Top-right input quadrant (plus middle row) -> bottom-left (quadrant 3).
        pasteTranslated(outImg, display, shiftXneg, shiftYpos,
                new Rectangle(0, dcOriginY, rangeX, rangeY + 1));
        // Bottom-left input quadrant (plus middle column) -> top-right (quadrant 1).
        pasteTranslated(outImg, display, shiftXpos, shiftYneg,
                new Rectangle(dcOriginX, 0, rangeX + 1, rangeY));
        return outImg;
    }

    /**
     * Translates {@code src} by ({@code dx}, {@code dy}) with the JAI "translate"
     * op and copies the translated pixels into {@code dst}, restricted to
     * {@code region}.
     */
    private static void pasteTranslated(TiledImage dst, RenderedImage src,
                                        float dx, float dy, Rectangle region) {
        ParameterBlock pb = new ParameterBlock();
        pb.addSource(src);
        pb.add(dx);
        pb.add(dy);
        // getData() rather than getTile(0, 0): a multi-tile source would
        // otherwise be copied only partially.
        dst.setData(JAI.create("translate", pb).getData(), new ROIShape(region));
    }

    /**
     * Forward real-to-complex DFT of {@code image}, with no scaling
     * ({@code DFTDescriptor.SCALING_NONE}).
     */
    public RenderedOp DFTImage(PlanarImage image) {
        ParameterBlockJAI pb = new ParameterBlockJAI("dft");
        pb.addSource(image);
        pb.setParameter("scalingType", DFTDescriptor.SCALING_NONE);
        pb.setParameter("dataNature", DFTDescriptor.REAL_TO_COMPLEX);
        return JAI.create("dft", pb);
    }

    /**
     * Inverse complex-to-real DFT of {@code image}, with no scaling
     * ({@code DFTDescriptor.SCALING_NONE}). Combined with {@link #DFTImage}, a
     * round trip is therefore not normalised by the image dimensions.
     */
    public PlanarImage inverseDFTImage(RenderedOp image) {
        ParameterBlockJAI pb = new ParameterBlockJAI("idft");
        pb.addSource(image);
        pb.setParameter("scalingType", DFTDescriptor.SCALING_NONE);
        pb.setParameter("dataNature", DFTDescriptor.COMPLEX_TO_REAL);
        return JAI.create("idft", pb);
    }

    /** Pixel-wise complex product of two complex images ("multiplycomplex" op). */
    public RenderedOp multComplex(RenderedOp img1, RenderedOp img2) {
        ParameterBlockJAI pb = new ParameterBlockJAI("multiplycomplex");
        pb.addSource(img1);
        pb.addSource(img2);
        return JAI.create("multiplycomplex", pb);
    }

    /** Magnitude of a complex image (JAI "magnitude" op). */
    public RenderedOp magnitudeImage(RenderedOp image) {
        ParameterBlockJAI pb = new ParameterBlockJAI("magnitude");
        pb.addSource(image);
        return JAI.create("magnitude", pb);
    }

    /** Phase of a complex image (JAI "phase" op). */
    public RenderedOp phaseImage(RenderedOp image) {
        ParameterBlockJAI pb = new ParameterBlockJAI("phase");
        pb.addSource(image);
        return JAI.create("phase", pb);
    }

    /** Builds a complex image from magnitude and phase images ("polartocomplex" op). */
    public RenderedOp polarToComplexImage(RenderedOp mag,
                                          RenderedOp phase) {
        ParameterBlockJAI pb = new ParameterBlockJAI("polartocomplex");
        pb.addSource(mag);
        pb.addSource(phase);
        return JAI.create("polartocomplex", pb);
    }

    /**
     * Round-trip demo. It loads an image and keeps only its central third in
     * each direction, on a black background. It then runs a double-precision
     * forward DFT followed by an inverse DFT, and prints the per-band maximum
     * absolute deviation between band 0 of the result and the source. Finally,
     * the source and band 0 of the inverse transform are shown side by side.
     */
    public static void testDft() {
        ClosableJFrame cf = new ClosableJFrame();
        Container c = cf.getContentPane();

        RenderedImage image = ImageUtils.getPlanarImage(ImageUtils.getImage());
        // Derive every size from the loaded image; a fixed demo size does not
        // fit arbitrary user images.
        int width = image.getWidth();
        int height = image.getHeight();

        // Blank TiledImage with the image's layout; copy in only the central
        // region, so image data sits on a black background.
        TiledImage source = new TiledImage(image.getMinX(),
                image.getMinY(),
                width,
                height,
                image.getTileGridXOffset(),
                image.getTileGridYOffset(),
                image.getSampleModel(),
                image.getColorModel());
        ROI roi = new ROIShape(new Rectangle(
                image.getMinX() + width / 3, image.getMinY() + height / 3,
                width / 3, height / 3));
        source.set(image, roi);

        // Request a double-precision "dft" output (the default is float).
        SampleModel sm =
                RasterFactory.createComponentSampleModel(source.getSampleModel(),
                        DataBuffer.TYPE_DOUBLE,
                        width, height, 1);
        ImageLayout il = new ImageLayout();
        il.setSampleModel(sm);
        RenderingHints hints = new RenderingHints(JAI.KEY_IMAGE_LAYOUT, il);

        // Forward then inverse DFT (default op parameters); keep band 0.
        ParameterBlock pb = new ParameterBlock().addSource(source);
        RenderedImage dft = JAI.create("dft", pb, hints);
        RenderedImage idft = JAI.create("idft", dft);
        RenderedImage band1 = JAI.create("bandselect", idft, new int[]{0});

        // |band1 - source|, then its extrema.
        RenderedImage diff = JAI.create("subtract", band1, source);
        RenderedImage abs = JAI.create("absolute", diff);
        RenderedImage ext = JAI.create("extrema", abs);

        // Print the maximum absolute deviation per band.
        double[] absdev = (double[]) ext.getProperty("maximum");
        for (int b = 0; b < absdev.length; b++) {
            System.out.println("Band " + b + " absolute deviation = " + absdev[b]);
        }

        // NOTE(review): the x255 rescale was sized for the unity-valued test image
        // used originally; with 8-bit image data it likely saturates to a mask.
        RenderedImage rescale = JAI.create("rescale", band1,
                new double[]{255.0},
                new double[]{0.0});
        RenderedImage display = JAI.create("format", rescale,
                DataBuffer.TYPE_BYTE);

        // Show the source and band 0 of the inverse DFT side by side.
        c.setLayout(new GridLayout(1, 2));
        c.add(new ScrollingImagePanel(JAI.create("rescale", source,
                new double[]{255.0},
                new double[]{0.0}),
                width, height));
        c.add(new ScrollingImagePanel(display, width, height));
        cf.pack();
        cf.setVisible(true);
    }

}