use image::{imageops::ColorMap, Rgb, RgbImage}; use std::{ array, cmp, collections::{BinaryHeap, VecDeque}, iter, }; use crate::queue::Queue; const MAX_LEVEL: u8 = 8; const MAX_NODES: usize = 1024; const MAX_COLORS: usize = 256; #[derive(Debug, Default)] struct Node { rgb: [u32; 3], count: u32, index: u8, level: u8, is_leaf: bool, parent: Option, children: [Option; 8], } impl Node { fn merge_color(&mut self, color: Rgb) { self.count += 1; iter::zip(&mut self.rgb, color.0).for_each(|(a, b)| *a += b as u32) } fn merge_node(&mut self, node: Node) { self.count += node.count; iter::zip(&mut self.rgb, node.rgb).for_each(|(a, b)| *a += b) } } #[derive(Debug)] struct Pool { nodes: [Node; MAX_NODES], ids: Queue, } impl Pool { fn new() -> Self { let mut ids = Queue::new(); let nodes = array::from_fn(|i| { ids.push(i as u32); Node::default() }); Self { nodes, ids } } fn create(&mut self) -> u32 { self.ids.pop().unwrap() } fn get(&self, id: u32) -> &Node { &self.nodes[id as usize] } fn get_mut(&mut self, id: u32) -> &mut Node { &mut self.nodes[id as usize] } fn delete(&mut self, id: u32) -> Node { self.ids.push(id); std::mem::take(&mut self.nodes[id as usize]) } } fn get_color_index(color: Rgb, level: u8) -> usize { let shift = MAX_LEVEL - level; color .0 .into_iter() .rev() .enumerate() .map(|(i, c)| (((c >> shift) & 1) << i) as usize) .fold(0, |s, c| s | c) } #[derive(Debug, Ord)] struct Reducible { node_id: u32, level: u8, count: u32, } impl Reducible { fn new(node_id: u32, node: &Node) -> Self { Self { node_id, level: node.level, count: node.count, } } } impl Eq for Reducible {} impl cmp::PartialEq for Reducible { fn eq(&self, other: &Self) -> bool { self.level.eq(&other.level) && self.count.eq(&other.count) } } impl cmp::PartialOrd for Reducible { fn partial_cmp(&self, other: &Self) -> Option { Some(match self.level.cmp(&other.level) { cmp::Ordering::Greater => cmp::Ordering::Greater, cmp::Ordering::Less => cmp::Ordering::Less, cmp::Ordering::Equal => self.count.cmp(&other.count), }) } } struct Octree { pool: Pool, root: u32, reducible: BinaryHeap, color_count: usize, leaf_count: usize, } impl Octree { fn new(color_count: usize) -> Self { let mut pool = Pool::new(); let root = pool.create(); pool.get_mut(root).is_leaf = true; Self { pool, root, reducible: BinaryHeap::new(), color_count, leaf_count: 1, } } pub fn traverse(&self, f: F) where F: FnMut(u32, &Node), { let mut f = f; let mut queue = VecDeque::new(); queue.push_back(self.root); while let Some(node_id) = queue.pop_front() { let node = self.pool.get(node_id); f(node_id, node); for child_id in node.children.iter().flatten() { queue.push_back(*child_id); } } } pub fn traverse_mut(&mut self, f: F) where F: FnMut(u32, &mut Node), { let mut f = f; let mut queue = VecDeque::new(); queue.push_back(self.root); while let Some(node_id) = queue.pop_front() { let node = self.pool.get_mut(node_id); f(node_id, node); for child_id in node.children.iter().flatten() { queue.push_back(*child_id); } } } fn insert(&mut self, color: Rgb) { let mut node_id = self.root; for level in 1..=MAX_LEVEL { let child_index = get_color_index(color, level); node_id = match self.pool.get(node_id).children[child_index] { Some(child_id) => child_id, None => { let child_id = self.pool.create(); { let child = self.pool.get_mut(child_id); child.level = level; child.parent = Some(node_id); child.is_leaf = true; self.leaf_count += 1; } { let parent = self.pool.get_mut(node_id); parent.children[child_index] = Some(child_id); if parent.is_leaf { parent.is_leaf = false; self.leaf_count -= 1; self.reducible.push(Reducible::new(node_id, &parent)); } } child_id } } } self.pool.get_mut(node_id).merge_color(color); self.reduce(); } fn get_index(&self, color: Rgb) -> usize { let mut node_id = self.root; for level in 1..=MAX_LEVEL { let node = self.pool.get(node_id); if node.is_leaf { break; } let child_index = get_color_index(color, level); node_id = match node.children[child_index] { Some(child_id) => child_id, None => { match node .children .iter() .enumerate() .filter_map(|(i, c)| c.zip(Some(i))) .map(|(c, i)| { let d = child_index ^ i; ([d >> 2, (d >> 1) & 1, d & 1].into_iter().sum::(), c) }) .min_by(|a, b| a.0.cmp(&b.0)) .map(|(_, c)| c) { Some(child_id) => child_id, None => break, } } } } self.pool.get(node_id).index as usize } fn prune_node(&mut self, node_id: u32) { self.leaf_count -= { let node = self.pool.get_mut(node_id); node.is_leaf = true; self.leaf_count += 1; std::mem::take(&mut node.children) } .into_iter() .flatten() .map(|child_id| { let child = self.pool.delete(child_id); self.pool.get_mut(node_id).merge_node(child); }) .count() } fn reduce(&mut self) { if self.leaf_count > self.color_count { while let Some(reducible) = self.reducible.pop() { self.prune_node(reducible.node_id); if let Some(parent_id) = self.pool.get(reducible.node_id).parent { self.reducible .push(Reducible::new(parent_id, self.pool.get(parent_id))); } if self.leaf_count <= self.color_count { break; } } } } fn finalize(&mut self) -> Vec> { let mut palette = Vec::new(); self.traverse_mut(|_, node| { if node.is_leaf { node.index = palette.len() as u8; palette.push(Rgb::from(array::from_fn(|i| { (node.rgb[i] / node.count) as u8 }))); } }); palette } } pub struct ColorQuantizer { octree: Octree, colors: Vec>, } impl ColorQuantizer { pub fn from(img: &RgbImage, palette_size: usize) -> Self { let palette_size = palette_size.min(MAX_COLORS); let mut octree = Octree::new(palette_size); for pixel in img.pixels() { octree.insert(*pixel); } let colors = octree.finalize(); println!("final color palette size: {}", colors.len()); Self { octree, colors } } #[inline(always)] pub fn get_palette(&self) -> &[Rgb] { &self.colors } #[inline(always)] pub fn get_index(&self, color: Rgb) -> usize { self.octree.get_index(color) } } impl ColorMap for ColorQuantizer { type Color = Rgb; #[inline(always)] fn index_of(&self, color: &Self::Color) -> usize { self.get_index(*color) } #[inline(always)] fn map_color(&self, color: &mut Self::Color) { *color = self.get_palette()[self.get_index(*color)] } }