package j2d.color;

/*
 * Copyright (c) 2005 DocJava, Inc. All Rights Reserved.
 */

import futils.Futil;
import gui.In;

import javax.media.jai.*;
import javax.media.jai.iterator.RandomIter;
import javax.media.jai.iterator.RandomIterFactory;
import java.awt.*;
import java.awt.image.*;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.PrintWriter;

/**
 * The application in this class performs a simplified version of the
 * K-Means algorithm on an image, classifying (clustering) its pixels. It
 * also shows how one can get pixel values from any position in an image.
 * <p>
 * The cluster-center initialization method, the input image and the number
 * of clusters are chosen interactively. A per-iteration report is written to
 * standard output, and the clustered result is saved as an indexed
 * (colormapped) TIFF image whose palette holds the final cluster centers.
 */
public class Kmeans {

    private static int nClusters; // The number of clusters for the algorithm.
    // The cluster centers as (band 0, band 1, band 2) triples. They are doubles
    // because the same array is reused as an accumulator while the centers are
    // recomputed, and because an empty cluster's center becomes NaN (0 / 0).
    private static double[][] clusterCenters;
    private static int[] pixelCounterForCluster; // Pixel count for each cluster.
    private static final int maxIterations = 50; // Will iterate at most 50 times.
    // Will iterate until the absolute difference between the RMS values of two
    // consecutive iterations is smaller than or equal to this value.
    private static final double minDiffRMS = 0.01;

    /**
     * The application entry point. Prompts for the initialization method, the
     * input image and the number of clusters, runs K-Means, prints a report to
     * standard output, saves the clustered image as TIFF and exits the JVM.
     *
     * @param args the command line arguments (unused; everything is prompted).
     * @throws IOException if writing the report fails.
     */
    public static void main(String[] args) throws IOException {
        String[] options = {
            "[R]andom",
            " [S]paced ",
            "random[P]ixels)"
        };
        String initMethod = (String)
                In.multiPrompt(options,
                        "select the init message",
                        "init message dialog");
        // The report is written to standard output.
        BufferedWriter report = new BufferedWriter(
                new PrintWriter(System.out));
        try {
            // Open the image chosen by the user.
            PlanarImage pi = JAI.create("fileload",
                    Futil.getReadFile("select an image").toString());
            nClusters = In.getInt("select the number of clusters", 1, 256);
            // If the source image is colormapped, convert it to 3-band RGB.
            pi = toRGB(pi);

            int width = pi.getWidth();
            int height = pi.getHeight();
            int nbands = pi.getSampleModel().getNumBands();
            // The distance code always reads three bands, so never allocate
            // fewer than three slots (readPixel fills them for 1/2-band images).
            int[] pixel = new int[Math.max(nbands, 3)];
            RandomIter iterator = RandomIterFactory.create(pi, null);
            // Cluster index of every pixel; short is plenty for <= 256 clusters.
            short[][] clusteredImage = new short[width][height];
            clusterCenters = new double[nClusters][3];
            pixelCounterForCluster = new int[nClusters];

            initClusterCenters(parseInitCode(initMethod), iterator,
                    width, height, nbands, pixel);
            report.write("Initial cluster information:\n");
            writeClusterInfo(report, false);
            report.write(
                    "-------------------------------------------------\n");

            runKmeans(report, iterator, width, height, nbands, pixel,
                    clusteredImage);

            TiledImage tiledImage =
                    createIndexedImage(clusteredImage, width, height);
            String outputFile =
                    Futil.getWriteFile("select an output file") + "";
            JAI.create("filestore", tiledImage, outputFile, "TIFF");
        } finally {
            // Flush the report even when clustering or saving fails. Closing it
            // also closes System.out, as the original code did.
            report.close();
        }
        System.exit(0);
    }

    /**
     * Extracts the one-letter initialization code from the option chosen by the
     * user. The options are spelled "[R]andom", " [S]paced " and
     * "random[P]ixels)", so the code is the character right after '['. If there
     * is no bracket, the first non-blank character is used. A null or blank
     * choice defaults to 'P' (random pixels).
     *
     * @param choice the option text returned by the prompt, possibly null.
     * @return 'R', 'S', 'P' or whatever upper-cased letter was found.
     */
    private static char parseInitCode(String choice) {
        if (choice == null) {
            return 'P';
        }
        int bracket = choice.indexOf('[');
        if (bracket >= 0 && bracket + 1 < choice.length()) {
            return Character.toUpperCase(choice.charAt(bracket + 1));
        }
        String trimmed = choice.trim();
        return trimmed.length() == 0
                ? 'P'
                : Character.toUpperCase(trimmed.charAt(0));
    }

    /**
     * Converts a colormapped image into a 3-band RGB image through a lookup
     * table built from its IndexColorModel. Other images are returned as-is.
     *
     * @param pi the source image.
     * @return an image whose samples are color values rather than indices.
     */
    private static PlanarImage toRGB(PlanarImage pi) {
        if (!(pi.getColorModel() instanceof IndexColorModel)) {
            return pi;
        }
        IndexColorModel icm = (IndexColorModel) pi.getColorModel();
        int mapSize = icm.getMapSize();
        byte[][] lutData = new byte[3][mapSize];
        icm.getReds(lutData[0]);
        icm.getGreens(lutData[1]);
        icm.getBlues(lutData[2]);
        LookupTableJAI lut = new LookupTableJAI(lutData);
        return JAI.create("lookup", pi, lut);
    }

    /**
     * Reads the pixel at (x, y) into {@code pixel}. That array must have at
     * least max(nbands, 3) entries. Images with fewer than three bands have
     * band 0 copied into slots 1 and 2, so the 3-band distance and the
     * colormap treat them as gray.
     */
    private static void readPixel(RandomIter iterator, int x, int y,
                                  int[] pixel, int nbands) {
        iterator.getPixel(x, y, pixel);
        if (nbands < 3) {
            pixel[1] = pixel[0];
            pixel[2] = pixel[0];
        }
    }

    /**
     * Initializes every cluster center.
     * <ul>
     * <li>'R': random values in [0, 255].</li>
     * <li>'S': evenly spaced gray levels.</li>
     * <li>Anything else: the values of a randomly chosen image pixel.</li>
     * </ul>
     */
    private static void initClusterCenters(char initCode, RandomIter iterator,
                                           int width, int height, int nbands,
                                           int[] pixel) {
        for (int cluster = 0; cluster < nClusters; cluster++) {
            double[] center = clusterCenters[cluster];
            switch (initCode) {
                case 'R': {
                    // (int) of [0, 256) yields 0..255 inclusive.
                    center[0] = (int) (Math.random() * 256.);
                    center[1] = (int) (Math.random() * 256.);
                    center[2] = (int) (Math.random() * 256.);
                    break;
                }
                case 'S': {
                    double level = (int) (255.0 * cluster / nClusters);
                    center[0] = level;
                    center[1] = level;
                    center[2] = level;
                    break;
                }
                default: { // 'P': seed with randomly sampled pixels.
                    int randomW = (int) (Math.random() * width);
                    int randomH = (int) (Math.random() * height);
                    readPixel(iterator, randomW, randomH, pixel, nbands);
                    // Element-wise copy: pixel is int[], center is double[].
                    center[0] = pixel[0];
                    center[1] = pixel[1];
                    center[2] = pixel[2];
                    break;
                }
            }
        }
    }

    /**
     * Writes one report line per cluster, in the form
     * "Cluster N[ count:C] center: (a,b,c)".
     *
     * @param withCounts whether to include the per-cluster pixel count.
     */
    private static void writeClusterInfo(BufferedWriter report,
                                         boolean withCounts)
            throws IOException {
        for (int cluster = 0; cluster < nClusters; cluster++) {
            report.write("Cluster " + (cluster + 1));
            if (withCounts) {
                report.write(" count:" + pixelCounterForCluster[cluster]);
            }
            report.write(" center: (" +
                    clusterCenters[cluster][0] +
                    "," +
                    clusterCenters[cluster][1] +
                    "," +
                    clusterCenters[cluster][2] + ")\n");
        }
    }

    /**
     * The main K-Means loop: assign pixels, recompute centers, report. It stops
     * after {@code maxIterations} iterations, or once the RMS error changes by
     * at most {@code minDiffRMS} between consecutive iterations.
     */
    private static void runKmeans(BufferedWriter report, RandomIter iterator,
                                  int width, int height, int nbands,
                                  int[] pixel, short[][] clusteredImage)
            throws IOException {
        double lastRMS = 0;
        int iterations = 0;
        boolean continueClustering = true;
        while (continueClustering) {
            report.write("Start iteration " + (iterations + 1) + "\n");
            assignPixels(iterator, width, height, nbands, pixel,
                    clusteredImage);
            updateCenters(iterator, width, height, nbands, pixel,
                    clusteredImage);
            report.write("Cluster information:\n");
            writeClusterInfo(report, true);

            double thisRMS = computeRMS(iterator, width, height, nbands,
                    pixel, clusteredImage);
            // No previous RMS exists on the first iteration, so never stop
            // there. Afterwards compare magnitudes: RMS normally decreases,
            // and a signed difference would be negative and stop too early.
            double diffRMS = (iterations == 0)
                    ? Double.MAX_VALUE
                    : Math.abs(thisRMS - lastRMS);
            report.write("RMS:" +
                    thisRMS +
                    " lRMS:" +
                    lastRMS +
                    " dRMS:" +
                    diffRMS +
                    " stop@:" + minDiffRMS);
            report.write(
                    "\n-------------------------------------------------\n");
            lastRMS = thisRMS;

            iterations++;
            if (iterations >= maxIterations || diffRMS <= minDiffRMS) {
                continueClustering = false;
            }
        }
    }

    /** Assignment step: labels every pixel with its nearest cluster. */
    private static void assignPixels(RandomIter iterator, int width,
                                     int height, int nbands, int[] pixel,
                                     short[][] clusteredImage) {
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                readPixel(iterator, w, h, pixel, nbands);
                clusteredImage[w][h] = nearestCluster(pixel);
            }
        }
    }

    /**
     * Update step: moves every center to the mean of its assigned pixels. A
     * cluster with no pixels ends up with NaN coordinates (0 / 0), which
     * nearestCluster then ignores.
     */
    private static void updateCenters(RandomIter iterator, int width,
                                      int height, int nbands, int[] pixel,
                                      short[][] clusteredImage) {
        // Reuse the centers array as accumulators.
        for (int cluster = 0; cluster < nClusters; cluster++) {
            clusterCenters[cluster][0] = 0;
            clusterCenters[cluster][1] = 0;
            clusterCenters[cluster][2] = 0;
            pixelCounterForCluster[cluster] = 0;
        }
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                short pixelsClass = clusteredImage[w][h];
                pixelCounterForCluster[pixelsClass]++;
                readPixel(iterator, w, h, pixel, nbands);
                clusterCenters[pixelsClass][0] += pixel[0];
                clusterCenters[pixelsClass][1] += pixel[1];
                clusterCenters[pixelsClass][2] += pixel[2];
            }
        }
        for (int cluster = 0; cluster < nClusters; cluster++) {
            double count = pixelCounterForCluster[cluster];
            clusterCenters[cluster][0] /= count;
            clusterCenters[cluster][1] /= count;
            clusterCenters[cluster][2] /= count;
        }
    }

    /**
     * Root-mean-square distance from each pixel to the center of its cluster:
     * sqrt(sum of squared distances / number of pixels).
     */
    private static double computeRMS(RandomIter iterator, int width,
                                     int height, int nbands, int[] pixel,
                                     short[][] clusteredImage) {
        double sumSquared = 0;
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                readPixel(iterator, w, h, pixel, nbands);
                sumSquared += calcSimpleDistance(pixel,
                        clusterCenters[clusteredImage[w][h]]);
            }
        }
        return Math.sqrt(sumSquared / ((double) width * height));
    }

    /**
     * Builds a single-band, byte-valued TIFF-ready image whose samples are
     * cluster indices, with an IndexColorModel mapping each index to its
     * cluster center color.
     */
    private static TiledImage createIndexedImage(short[][] clusteredImage,
                                                 int width, int height) {
        // Flatten row by row. Indices (< 256) fit in an unsigned byte, which
        // matches the TYPE_BYTE sample model below.
        byte[] data = new byte[width * height];
        int pos = 0;
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                data[pos++] = (byte) clusteredImage[w][h];
            }
        }
        // The original author reported that only a 256-entry colormap worked
        // here; entries beyond nClusters stay black.
        int nColors = 256;
        byte[] reds = new byte[nColors];
        byte[] greens = new byte[nColors];
        byte[] blues = new byte[nColors];
        for (int cluster = 0; cluster < nClusters; cluster++) {
            // Math.round maps NaN (empty cluster) to 0.
            reds[cluster] = (byte) Math.round(clusterCenters[cluster][0]);
            greens[cluster] = (byte) Math.round(clusterCenters[cluster][1]);
            blues[cluster] = (byte) Math.round(clusterCenters[cluster][2]);
        }
        SampleModel sampleModel =
                RasterFactory.createBandedSampleModel(
                        DataBuffer.TYPE_BYTE,
                        width,
                        height,
                        1); // only one band
        ColorModel colorModel = new IndexColorModel(8,
                nColors,
                reds,
                greens,
                blues);
        TiledImage tiledImage = new TiledImage(0, 0, width, height, 0, 0,
                sampleModel,
                colorModel);
        DataBufferByte dbuffer = new DataBufferByte(data, width * height);
        Raster raster = RasterFactory.createWritableRaster(sampleModel,
                dbuffer,
                new Point(0, 0));
        tiledImage.setData(raster);
        return tiledImage;
    }

    /**
     * Returns the index of the cluster center closest to {@code pixel}, using
     * the minimum squared Euclidean distance. Clusters whose center is NaN
     * (they received no pixels in a previous iteration) are skipped.
     *
     * @param pixel the pixel values (at least three entries).
     * @return the nearest cluster index, or -1 if every center is NaN.
     */
    private static short nearestCluster(int[] pixel) {
        double smallestSoFar = Double.MAX_VALUE;
        short whichClass = -1;
        for (short cluster = 0; cluster < nClusters; cluster++) {
            double thisDistance = calcSimpleDistance(pixel,
                    clusterCenters[cluster]);
            if (!Double.isNaN(thisDistance) && thisDistance < smallestSoFar) {
                smallestSoFar = thisDistance;
                whichClass = cluster;
            }
        }
        return whichClass;
    }

    /**
     * Squared Euclidean distance between the first three entries of a pixel
     * and a cluster center. The square root is omitted: comparisons do not
     * need it, and the RMS computation sums squared distances directly.
     *
     * @param pixel1 the array with one pixel's values.
     * @param pixel2 the array with the other pixel's (or center's) values.
     * @return the squared Euclidean distance (NaN if pixel2 contains NaN).
     */
    private static double calcSimpleDistance(int[] pixel1,
                                             double[] pixel2) {
        double d0 = pixel1[0] - pixel2[0];
        double d1 = pixel1[1] - pixel2[1];
        double d2 = pixel1[2] - pixel2[2];
        return d0 * d0 + d1 * d1 + d2 * d2;
    }

}
