package ip.color;

/**
 * Octree color quantizer.
 *
 * <p>Pixels from three parallel channel planes ({@code r[x][y]}, {@code g[x][y]},
 * {@code b[x][y]}) are inserted into an octree whose branch at depth {@code d} is selected by
 * bit {@code MAXDEPTH - d} of each channel. Whenever the number of leaves exceeds {@code k},
 * a node with at least two children is collapsed into a leaf. Each remaining leaf becomes one
 * palette entry holding the mean color of the pixels that reached it, and pixels are then
 * remapped to the palette.
 *
 * <p>Only bits 7..0 of each channel are used for routing, so channel values are presumably
 * 8-bit (0..255) — TODO confirm against callers.
 *
 * <p>Instances hold mutable per-run state and perform no synchronization.
 */
public class Octree {
    private static final int MAXDEPTH = 7;
    /** Number of live nodes: incremented on creation, decremented by {@link #killTree}. */
    private int numNodes = 0;
    /** Peak value ever reached by {@link #numNodes}. */
    private int maxNodes = 0;
    /** Current leaf count; deepest reducible depth; depth at which new nodes become leaves. */
    private int size, level, leafLevel;
    /** Palette; entries are filled in by {@link #initVGAPalette}. */
    private RGB[] colorLut;
    private Node tree = null;
    /** Per-depth lists (linked through {@code Node.nextReduceable}) of nodes with 2+ children. */
    private final Node[] reduceList = new Node[MAXDEPTH + 1];
    /** Maximum palette size. */
    private int k;
    /** Channel planes passed to {@link #octreeQuantization}, indexed {@code [x][y]}. */
    private short[][] r, g, b;

    public Octree() {
    }

    /**
     * Quantizes an image in place to at most {@code ki} colors.
     *
     * <p>Builds a fresh octree and palette from the given channel planes, then rewrites every
     * pixel with its palette color. The arrays are retained and exposed by {@link #getR()},
     * {@link #getG()} and {@link #getB()}.
     *
     * @param ra red channel, indexed {@code [x][y]}; modified in place
     * @param ga green channel, same dimensions as {@code ra}; modified in place
     * @param ba blue channel, same dimensions as {@code ra}; modified in place
     * @param ki maximum number of palette colors; must be at least 1
     * @throws IllegalArgumentException if {@code ki < 1}
     */
    public void octreeQuantization(short[][] ra,
                                   short[][] ga,
                                   short[][] ba,
                                   int ki) {
        if (ki < 1) {
            // With k <= 0 the leaf count can never be reduced to k, and reduceTree would walk
            // off the start of reduceList.
            throw new IllegalArgumentException("k must be at least 1, got " + ki);
        }
        r = ra;
        g = ga;
        b = ba;
        k = ki;

        setColor();
        reMap(r, g, b);
    }

    /**
     * Replaces each pixel with the palette color of the octree leaf it maps to.
     *
     * <p>A color whose path is missing from the tree (for example a pixel from an image that
     * was not used to build it) falls back to the nearest palette entry.
     *
     * @param r red channel, indexed {@code [x][y]}; modified in place
     * @param g green channel, same dimensions as {@code r}; modified in place
     * @param b blue channel, same dimensions as {@code r}; modified in place
     */
    public void reMap(short[][] r, short[][] g, short[][] b) {
        // findColor only reads the probe, so a single instance can be reused for every pixel.
        final RGB probe = new RGB();
        for (int x = 0; x < r.length; x++) {
            for (int y = 0; y < r[0].length; y++) {
                probe.r = r[x][y];
                probe.g = g[x][y];
                probe.b = b[x][y];
                final RGB mapped = colorLut[findColor(tree, probe)];
                r[x][y] = mapped.r;
                g[x][y] = mapped.g;
                b[x][y] = mapped.b;
            }
        }
    }

    /**
     * Use the octree color reduction algorithm on an image sequence, so that every image in
     * the sequence shares a consistent color map. After all images have been added, go back
     * and remap each image with {@link #reMap}. - DL
     *
     * <p>{@link #octreeQuantization} must have been called first; it sets the palette size
     * and allocates the palette.
     *
     * @param r red channel, indexed {@code [x][y]}
     * @param g green channel, same dimensions as {@code r}
     * @param b blue channel, same dimensions as {@code r}
     * @throws IllegalStateException if no palette has been allocated yet
     */
    public void addImagesSeen(short[][] r, short[][] g, short[][] b) {
        if (colorLut == null) {
            throw new IllegalStateException(
                    "octreeQuantization must be called before addImagesSeen");
        }
        insertPixels(r, g, b);
        rebuildPalette();
    }

    /**
     * Builds a fresh octree and palette from the channel planes stored by
     * {@link #octreeQuantization}, discarding any tree left over from a previous run.
     */
    public void setColor() {
        colorLut = new RGB[k];
        // Release the previous tree (keeps numNodes accurate) and forget its reduce
        // candidates; stale reduceList entries would otherwise be "reduced" instead of
        // nodes of the new tree, corrupting the leaf count.
        tree = killTree(tree);
        for (int depth = 0; depth < reduceList.length; depth++) {
            reduceList[depth] = null;
        }
        size = 0;
        level = MAXDEPTH;
        insertPixels(r, g, b);
        rebuildPalette();
    }

    /**
     * Inserts every pixel of the given planes into the tree, reducing it whenever it holds
     * more than {@code k} leaves.
     */
    private void insertPixels(short[][] r, short[][] g, short[][] b) {
        leafLevel = level + 1;
        final int height = (r.length == 0) ? 0 : r[0].length;
        final RGB color = new RGB();
        // Insertion order decides which nodes get reduced, so y stays the outer loop.
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < r.length; x++) {
                color.r = r[x][y];
                color.g = g[x][y];
                color.b = b[x][y];
                tree = insertNode(tree, color, 0);
                if (size > k) {
                    reduceTree();
                }
            }
        }
    }

    /** Assigns palette indices from 0 to the current leaves and clears unused slots. */
    private void rebuildPalette() {
        final int[] index = {0};
        initVGAPalette(tree, index);
        // The leaf count can shrink between calls (addImagesSeen); drop entries that no
        // longer belong to any leaf so findNearestEntry cannot return them.
        for (int i = index[0]; i < colorLut.length; i++) {
            colorLut[i] = null;
        }
    }

    /** Returns the octant index at {@code depth}: bit (MAXDEPTH - depth) of r, g, b as bits 2..0. */
    private static int branchIndex(RGB color, int depth) {
        final int shift = MAXDEPTH - depth;
        return ((color.r >> shift) & 1) << 2
                | ((color.g >> shift) & 1) << 1
                | ((color.b >> shift) & 1);
    }

    /**
     * Returns the palette index for {@code color} by descending from {@code tree} to a leaf.
     * If the path leaves the tree before a leaf is reached, falls back to
     * {@link #findNearestEntry(RGB)}.
     *
     * @param tree  node to start the descent from (normally the root)
     * @param color color to look up
     * @return an index into the palette
     */
    public int findColor(Node tree, RGB color) {
        if (tree.leaf) {
            return tree.colorIndex;
        }
        final Node next = tree.link[branchIndex(color, tree.level)];
        return (next != null) ? findColor(next, color) : findNearestEntry(color);
    }

    /**
     * Searches the color lookup table for the entry closest to {@code color} in squared RGB
     * distance. Unallocated ({@code null}) slots are skipped.
     *
     * @param color the color to match
     * @return the index of the closest entry, or 0 if the table has no allocated entries
     */
    public int findNearestEntry(RGB color) {
        int error = Integer.MAX_VALUE;
        int bestIndex = 0;
        for (int i = 0; i < colorLut.length; i++) {
            final RGB candidate = colorLut[i];
            if (candidate == null) {
                // Slots beyond the current leaf count are never allocated.
                continue;
            }
            final int e = candidate.getError(color);
            if (e < error) {
                error = e;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Adds {@code color} to the subtree rooted at {@code node}, creating nodes as needed.
     *
     * <p>Every node on the path accumulates the pixel count and channel sums. A node that
     * gains its second child is pushed onto the reduce list for its depth.
     *
     * @param node  subtree root, or {@code null} to create one
     * @param color pixel color to insert
     * @param depth depth of {@code node}
     * @return the (possibly newly created) subtree root
     */
    public Node insertNode(Node node, RGB color, int depth) {
        if (node == null) {
            node = new Node();
            numNodes++;
            maxNodes = Math.max(maxNodes, numNodes);
            node.level = depth;
            node.leaf = depth >= leafLevel;
            if (node.leaf) {
                size++;
            }
        }
        node.colorCount++;
        node.RGBSum.r += color.r;
        node.RGBSum.g += color.g;
        node.RGBSum.b += color.b;
        if (!node.leaf && depth < leafLevel) {
            final int branch = branchIndex(color, depth);
            if (node.link[branch] == null) {
                node.child++;
                if (node.child == 2) {
                    node.nextReduceable = reduceList[depth];
                    reduceList[depth] = node;
                }
            }
            node.link[branch] = insertNode(node.link[branch], color, depth + 1);
        }
        return node;
    }

    /**
     * Recursively releases a subtree, decrementing {@link #numNodes} once per node.
     *
     * @param tree subtree root; may be {@code null}
     * @return always {@code null}, so callers can write {@code link[i] = killTree(link[i])}
     */
    public Node killTree(Node tree) {
        if (tree == null) {
            return null;
        }
        for (int i = 0; i < tree.link.length; i++) {
            tree.link[i] = killTree(tree.link[i]);
        }
        numNodes--;
        return null;
    }

    /**
     * Collapses one node into a leaf. The node is taken from the deepest non-empty reduce
     * list at or above the current {@code level}. Its children are released, and {@code size}
     * is lowered by {@code child - 1} because the node's children are replaced by one leaf.
     */
    public void reduceTree() {
        int reduceLevel = level;
        while (reduceList[reduceLevel] == null) {
            reduceLevel--;
        }
        final Node node = reduceList[reduceLevel];
        reduceList[reduceLevel] = node.nextReduceable;
        node.leaf = true;
        size = size - node.child + 1;
        for (int i = 0; i < node.link.length; i++) {
            node.link[i] = killTree(node.link[i]);
        }
        final int depth = node.level;
        if (depth < level) {
            level = depth;
            leafLevel = level + 1;
        }
    }

    /**
     * Walks the tree and gives every leaf (or node at {@code leafLevel}) a palette entry.
     * Each entry is set to the mean color of the pixels that reached that node. Visited
     * nodes are marked as leaves and remember their palette index.
     *
     * @param tree  subtree root; {@code null} is ignored
     * @param index single-element counter holding the next palette slot; advanced per entry
     */
    public void initVGAPalette(Node tree, int[] index) {
        if (tree == null) {
            return;
        }
        if (tree.leaf || tree.level == leafLevel) {
            final int slot = index[0];
            RGB entry = colorLut[slot];
            if (entry == null) {
                entry = new RGB();
                colorLut[slot] = entry;
            }
            entry.r = (short) (tree.RGBSum.r / tree.colorCount);
            entry.g = (short) (tree.RGBSum.g / tree.colorCount);
            entry.b = (short) (tree.RGBSum.b / tree.colorCount);
            tree.colorIndex = slot;
            index[0] = slot + 1;
            tree.leaf = true;
        } else {
            for (Node child : tree.link) {
                initVGAPalette(child, index);
            }
        }
    }

    public static int getMAXDEPTH() {
        return MAXDEPTH;
    }

    public int getNumNodes() {
        return numNodes;
    }

    public void setNumNodes(int numNodes) {
        this.numNodes = numNodes;
    }

    public int getMaxNodes() {
        return maxNodes;
    }

    public void setMaxNodes(int maxNodes) {
        this.maxNodes = maxNodes;
    }

    public int getSize() {
        return size;
    }

    public void setSize(int size) {
        this.size = size;
    }

    public int getLevel() {
        return level;
    }

    public void setLevel(int level) {
        this.level = level;
    }

    public int getLeafLevel() {
        return leafLevel;
    }

    public void setLeafLevel(int leafLevel) {
        this.leafLevel = leafLevel;
    }

    public Node getTree() {
        return tree;
    }

    public int getK() {
        return k;
    }

    public short[][] getR() {
        return r;
    }

    public short[][] getG() {
        return g;
    }

    public short[][] getB() {
        return b;
    }

    public RGB[] getColorLut() {
        return colorLut;
    }

    public Node[] getReduceList() {
        return reduceList;
    }
}

/** Running per-channel totals of color values, held as longs. */
class ColorSum {
    /** Accumulated red total. */
    public long r;
    /** Accumulated green total. */
    public long g;
    /** Accumulated blue total. */
    public long b;

    /** Creates an accumulator with every channel total set to zero. */
    public ColorSum() {
        this.r = 0L;
        this.g = 0L;
        this.b = 0L;
    }
}

class RGB {
    public short r, g, b;
    /**
     *
     * @param rgb
     * @return  the square of the error in RGB space.
     */
    public int getError(RGB rgb) {
        if (rgb == null) return 0;
        int dr = r - rgb.r;
        int dg = g - rgb.g;
        int db = b - rgb.b;
        return dr*dr+dg*dg+db*db;
    }
}

/** One octree node; fields rely on Java default initialization (false / 0 / null). */
class Node {
    /** Whether descent stops at this node. */
    boolean leaf;
    /** Depth of this node in the tree. */
    int level;
    /** Palette index assigned to this node once it is a leaf. */
    int colorIndex;
    /** Number of child octants created under this node. */
    int child;
    /** Number of pixels that have passed through this node. */
    long colorCount;
    /** Per-channel sums of the pixels that have passed through this node. */
    ColorSum RGBSum = new ColorSum();
    /** Next node in the same-depth reduce list. */
    Node nextReduceable;
    /** The eight child octants. */
    Node[] link = new Node[8];

    public Node() {
        // Nothing to do: a fresh Node[8] already holds only null references.
    }
}