/*
 * Copyright (c) 2005 DocJava, Inc. All Rights Reserved.
 */
package math.fourierTransforms.pfa;

import j2d.ImageUtils;
import j2d.ShortImageBean;

import java.awt.*;
import java.awt.image.MemoryImageSource;

import math.complex.ComplexFloat1d;

public class PFAImage {

    /** Largest value of an 8-bit colour channel; reconstructed samples are clamped to [0, 255]. */
    private static final int MAX_CHANNEL_VALUE = 255;

    /** Factor handed to {@code ComplexFloat1d.logScaleMagnitude} when rendering a spectrum. */
    private static final float LOG_MAGNITUDE_SCALE = 100f;

    // Packed pixels: the source image, or the result of the latest fft()/ifft().
    private int intImage[];
    private int imageWidth, imageHeight;
    // image.length, or imageWidth * imageHeight
    private int N;
    // scale is used to scale the FFT input to prevent overflow;
    // N = imageWidth * imageHeight is used.
    private int scale;

    // When true, the spatial data is multiplied by a +/-1 checkerboard so the
    // zero frequency is shifted to the centre of the spectrum.
    private boolean fftShift;

    // Per-pixel alpha, carried through unchanged into every output pixel array.
    private short alpha[];

    // Complex working buffers, one per colour channel.
    ComplexFloat1d red = new ComplexFloat1d();
    ComplexFloat1d blue = new ComplexFloat1d();
    ComplexFloat1d green = new ComplexFloat1d();

    /**
     * Builds a transform wrapper from an AWT image, without zero-frequency shift.
     *
     * @param img the source image
     * @throws IllegalArgumentException if the decoded image has a non-positive width
     */
    public PFAImage(Image img) {
        initImage(img);
    }

    /**
     * Builds a transform wrapper from packed pixels.
     *
     * @param intImage   packed pixels in row-major order (decoded with
     *                   {@code ImageUtils.getAlpha/getRed/getGreen/getBlue});
     *                   its length must be a multiple of {@code imageWidth}
     * @param imageWidth number of pixels per row; must be positive
     * @param fftShift   if true, shift the zero frequency to the centre of the spectrum
     * @throws NullPointerException     if {@code intImage} is null
     * @throws IllegalArgumentException if {@code imageWidth} is not positive or
     *                                  does not divide {@code intImage.length}
     */
    public PFAImage(int intImage[], int imageWidth,
                    boolean fftShift) {
        initVariables(intImage, imageWidth, fftShift);
    }

    /**
     * Uses one image to filter another.
     * This only works if the two images are the same size and the kernel
     * is a mask defined in frequency space.
     *
     * @param img1   the image to filter
     * @param kernal the frequency-space mask (pixels expected to be 0 or 255)
     * @return the filtered image
     * @throws IllegalArgumentException if the two images differ in size
     */
    public static Image filter(Image img1, Image kernal) {
        PFAImage fimg1 = new PFAImage(img1);
        fimg1.fft();
        fimg1.scaleAndMask(kernal);
        fimg1.ifft();
        return fimg1.getImage();
    }

    /**
     * Wraps the current packed pixels in an AWT image.
     * Before any transform these are the source pixels. After {@link #fft()}
     * they are the log-magnitude spectrum, and after {@link #ifft()} they are
     * the reconstruction.
     *
     * @return an image backed by the current pixel array
     */
    public Image getImage() {
        Toolkit tk = Toolkit.getDefaultToolkit();
        return tk.createImage(
                new MemoryImageSource(
                        imageWidth,
                        imageHeight,
                        tk.getColorModel(),
                        intImage, 0,
                        imageWidth));
    }

    // Decodes an AWT image and initialises the channel buffers without fftShift.
    private void initImage(Image img) {
        ShortImageBean sib = new ShortImageBean(img);
        imageWidth = sib.getWidth();
        imageHeight = sib.getHeight();
        initVariables(sib.getPels(), imageWidth, false);
    }

    // TODO: the method below is untested.
    /**
     * Applies a frequency-space filter mask to all three channels.
     * The input kernel pixels are 0 or 255 and should be normalised to 0..1
     * (divided by 255). That normalisation is presumably done by
     * {@code ComplexFloat1d.scaleAndMask}; TODO confirm.
     *
     * @param kernal the mask image; must have the same width and height as this image
     * @throws IllegalArgumentException if the mask size differs from the image size
     */
    public void scaleAndMask(Image kernal) {
        ShortImageBean sib = new ShortImageBean(kernal);
        if (sib.getWidth() != imageWidth || sib.getHeight() != imageHeight) {
            throw new IllegalArgumentException("kernel is " + sib.getWidth() + "x"
                    + sib.getHeight() + " but image is " + imageWidth + "x" + imageHeight);
        }
        red.scaleAndMask(sib.getR());
        green.scaleAndMask(sib.getG());
        blue.scaleAndMask(sib.getB());
    }

    /**
     * Scales all three channel buffers by the same factor.
     *
     * @param scaleFactor multiplier handed to {@code ComplexFloat1d.scale}
     */
    //(&&&)rjd
    public void scale(float scaleFactor){
        red.scale(scaleFactor);
        green.scale(scaleFactor);
        blue.scale(scaleFactor);
    }

    /**
     * Multiplies each channel buffer by the matching channel of another image
     * (via {@code ComplexFloat1d.mult}). The two images are presumably
     * required to have the same size; TODO confirm in ComplexFloat1d.
     *
     * @param kernal the image whose channels are used as multipliers
     */
    //(&&&)rjd
    public void multiply(PFAImage kernal){
        red.mult(kernal.getRed());
        green.mult(kernal.getGreen());
        blue.mult(kernal.getBlue());
    }

    /** @return the red channel's complex working buffer (not a copy) */
    public ComplexFloat1d getRed(){
        return red;
    }

    /** @return the green channel's complex working buffer (not a copy) */
    public ComplexFloat1d getGreen(){
        return green;
    }

    /** @return the blue channel's complex working buffer (not a copy) */
    public ComplexFloat1d getBlue(){
        return blue;
    }

    // Validates the arguments, records geometry, and loads each colour channel
    // into the real part of its buffer (imaginary part zeroed).
    private void initVariables(int[] intImage, int imageWidth, boolean fftShift) {
        if (intImage == null) {
            throw new NullPointerException("intImage");
        }
        if (imageWidth <= 0 || intImage.length % imageWidth != 0) {
            throw new IllegalArgumentException("imageWidth " + imageWidth
                    + " must be positive and divide the pixel count " + intImage.length);
        }
        this.intImage = intImage;
        this.imageWidth = imageWidth;
        N = intImage.length;
        imageHeight = N / imageWidth;
        this.fftShift = fftShift;
        scale = N;

        alpha = ImageUtils.getAlpha(intImage);
        short r[] = ImageUtils.getRed(intImage);
        final int n = r.length;
        // If fftShift is true, shift the zero frequency to the center.
        red.setRe(fftReorder(r));
        green.setRe(fftReorder(ImageUtils.getGreen(intImage)));
        blue.setRe(fftReorder(ImageUtils.getBlue(intImage)));

        red.setIm(new float[n]);
        green.setIm(new float[n]);
        blue.setIm(new float[n]);
    }

    /** Runs the forward transform and stores the rendered spectrum as the current pixels. */
    public void fft() {
        intImage = getFftIntArray();
    }

    /**
     * Runs the forward 2-D transform on the three channel buffers and renders
     * the log-scaled magnitude spectrum as packed pixels.
     * The {@code PFA2d} instances are discarded, so the transform presumably
     * happens in place on the buffers. If so, calling this twice transforms twice.
     *
     * @return packed pixels of the log-magnitude spectrum
     */
    public int[] getFftIntArray() {
        new PFA2d(red, imageWidth);
        new PFA2d(green, imageWidth);
        new PFA2d(blue, imageWidth);

        float resultRed[] = red.logScaleMagnitude(LOG_MAGNITUDE_SCALE);
        float resultGreen[] = green.logScaleMagnitude(LOG_MAGNITUDE_SCALE);
        float resultBlue[] = blue.logScaleMagnitude(LOG_MAGNITUDE_SCALE);

        return ImageUtils.getArgbToInt(alpha, resultRed,
                resultGreen, resultBlue);
    }

    /** Runs the inverse transform and stores the reconstruction as the current pixels. */
    public void ifft() {
        intImage = getIfftIntArray();
    }

    /**
     * Runs the inverse 2-D transform on the three channel buffers and converts
     * the real parts back to packed pixels. The {@code IFFT2d} instances are
     * discarded, so the transform presumably happens in place.
     *
     * @return packed pixels of the reconstructed image
     */
    public int[] getIfftIntArray() {
        new IFFT2d(red, imageWidth,true);
        new IFFT2d(green, imageWidth,true);
        new IFFT2d(blue, imageWidth,true);

        return ImageUtils.getArgbToInt(alpha,
                ifftReorder(red.getRe()),
                ifftReorder(green.getRe()),
                ifftReorder(blue.getRe()));
    }

    // Reorders colour data for the FFT input:
    // 1. Convert color data from short to float.
    // 2. Divide by scale.
    // 3. If fftShift is true, apply the checkerboard sign so the zero
    //    frequency lands in the center of the matrix.
    private float[] fftReorder(short color[]) {
        float result[] = new float[N];
        for (int row = 0; row < imageHeight; row++) {
            for (int col = 0; col < imageWidth; col++) {
                int idx = row * imageWidth + col;
                float sign = fftShift ? shiftSign(row, col) : 1f;
                result[idx] = color[idx] * sign / scale;
            }
        }
        return result;
    } // End of function fftReorder().

    // Inverse of fftReorder: undoes the checkerboard sign (if any), multiplies
    // by scale, and rounds and clamps each sample to a valid 8-bit channel value.
    private short[] ifftReorder(float re[]) {
        short result[] = new short[N];
        for (int row = 0; row < imageHeight; row++) {
            for (int col = 0; col < imageWidth; col++) {
                int idx = row * imageWidth + col;
                float sign = fftShift ? shiftSign(row, col) : 1f;
                result[idx] = toChannel(re[idx] * sign * scale);
            }
        }
        return result;
    } // End of function ifftReorder().

    // Checkerboard sign shared by fftReorder/ifftReorder: -1 where (row + col)
    // is even, +1 where it is odd. It is applied once before the forward
    // transform and once after the inverse, so the two applications cancel.
    private static float shiftSign(int row, int col) {
        return ((row + col) % 2 == 0) ? -1f : 1f;
    }

    // Rounds to the nearest integer (a bare cast would truncate values like
    // 254.9999 down to 254) and clamps to the 8-bit channel range.
    private static short toChannel(float value) {
        int rounded = Math.round(value);
        return (short) Math.max(0, Math.min(MAX_CHANNEL_VALUE, rounded));
    }

    // Multiplies the array in place by a checkerboard that is +1 where
    // (row + col) is even and -1 where it is odd. Note that this is the opposite
    // sign convention from shiftSign.
    private void quadShiftSingleColor(float re []){
        int k = 0;
        float alternateSign;
        for (int i = 0; i < imageHeight; i++){
            for (int j = 0; j < imageWidth; j++) {
                alternateSign = ((i + j) % 2 == 0) ? 1 : -1;
                re[k++] *=alternateSign;
            }
        }
    }

    /**
     * Applies the checkerboard sign to the real and imaginary parts of every
     * channel. Performing this in the frequency domain causes a quadrant shift
     * after the inverse transform. This relies on {@code getRe()/getIm()}
     * returning the live arrays rather than copies (TODO confirm).
     */
    public void quadrantShift() {
        quadShiftSingleColor(red.getRe());
        quadShiftSingleColor(green.getRe());
        quadShiftSingleColor(blue.getRe());
        quadShiftSingleColor(red.getIm());
        quadShiftSingleColor(green.getIm());
        quadShiftSingleColor(blue.getIm());
    }

    /**
     * Demo: displays an image, its spectrum, and its reconstruction.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Image img = ImageUtils.getImage();
        ImageUtils.displayImage(img, "original Image");
        ShortImageBean sib = new ShortImageBean(img);
        final int width = sib.getWidth();
        PFAImage fftimage = new PFAImage(sib.getPels(),
                width, true);
        final int height = sib.getHeight();
        Image psdImage = ImageUtils.getImage(
                fftimage.getFftIntArray(), width, height);
        ImageUtils.displayImage(psdImage, "psd image");
        Image filteredImage = ImageUtils.getImage(
                fftimage.getIfftIntArray(), width, height);
        ImageUtils.displayImage(filteredImage, "filtered image");
    }

    /**
     * Returns {@code red.getMaxReal()}. Only the red channel is considered, and
     * the result is presumably the largest real component rather than a true
     * complex magnitude; TODO confirm in ComplexFloat1d.
     *
     * @return the red channel's maximum real value
     */
    public float getPeakMagnitude() {
        return red.getMaxReal();
    }

    /** Normalises each channel buffer via {@code ComplexFloat1d.normalize}. */
    public void normalize() {
        red.normalize();
        green.normalize();
        blue.normalize();
    }
} // End of class PFAImage.
