use image::{imageops::ColorMap, Rgb, RgbImage}; use std::{array, iter}; const MAX_LEVEL: u8 = 8; #[inline(always)] fn get_color_index(color: &Rgb, level: u8) -> usize { let shift = MAX_LEVEL - level; let r = (color[0] >> shift) & 1; let g = (color[1] >> shift) & 1; let b = (color[2] >> shift) & 1; ((r << 2) | (g << 1) | b) as usize } struct Octree { levels: [Vec; MAX_LEVEL as usize], } impl Octree { fn new() -> Self { let mut octree = Self { levels: Default::default(), }; octree.levels[0].push(OctreeNode::default()); octree } fn insert(&mut self, color: &Rgb) { let mut current_node_id = OctreeNodeId { level: 1, index: 0 }; while current_node_id.level < MAX_LEVEL { let child_index = get_color_index(color, current_node_id.level); if let Some(child_id) = self.get_node(current_node_id).children[child_index] { current_node_id = child_id; } else { current_node_id = self.insert_child(current_node_id, child_index); } } assert_eq!(current_node_id.level, MAX_LEVEL); self.get_node_mut(current_node_id).add_color(color); } fn get_index(&self, color: &Rgb) -> usize { let mut current_node_id = OctreeNodeId { level: 1, index: 0 }; while !self.get_node(current_node_id).is_leaf() { let child_index = get_color_index(color, current_node_id.level); current_node_id = self.get_node(current_node_id).children[child_index] .expect("Proper tree can't have dangling nodes"); } self.get_node(current_node_id).index } fn remove_leaves(&mut self, node_id: OctreeNodeId) -> usize { let mut cnt = 0; for child_index in 0..8 { if let Some(child) = self.delete_child(node_id, child_index) { self.get_node_mut(node_id).add_node(&child); cnt += 1; } } cnt } fn reduce_to(&mut self, color_count: usize) { let mut color_count_current = self.get_level(MAX_LEVEL).len(); if color_count_current > color_count { for level in (1..MAX_LEVEL).rev() { for index in 0..self.get_level(level).len() { let node_id = OctreeNodeId { level, index }; color_count_current -= self.remove_leaves(node_id).saturating_sub(1); if color_count_current <= color_count { break; } } } } } fn finalize(&mut self) -> Vec> { let mut palette = Vec::new(); for level in 1..=MAX_LEVEL { for index in 0..self.get_level(level).len() { let node_id = OctreeNodeId { level, index }; self.get_node_mut(node_id).index = palette.len(); palette.push(Rgb::from(self.get_node(node_id))); } } palette } fn insert_child(&mut self, node_id: OctreeNodeId, child_index: usize) -> OctreeNodeId { assert!(child_index < 8); assert!(node_id.level < MAX_LEVEL); let level = node_id.level + 1; let child_id = OctreeNodeId { level, index: self.get_level(level).len(), }; self.get_level_mut(level).push(OctreeNode::default()); assert!(self.get_node(node_id).children[child_index].is_none()); self.get_node_mut(node_id).children[child_index] = Some(child_id); child_id } fn delete_child(&mut self, node_id: OctreeNodeId, child_index: usize) -> Option { assert!(child_index < 8); let child_id = self.get_node_mut(node_id).children[child_index].take()?; assert!(!self.get_node(child_id).is_deleted); self.get_node_mut(child_id).is_deleted = true; Some(self.get_node(child_id).clone()) } #[inline(always)] fn get_level(&self, level: u8) -> &Vec { &self.levels[(level - 1) as usize] } #[inline(always)] fn get_level_mut(&mut self, level: u8) -> &mut Vec { &mut self.levels[(level - 1) as usize] } #[inline(always)] fn get_node(&self, node_id: OctreeNodeId) -> &OctreeNode { &self.get_level(node_id.level)[node_id.index] } #[inline(always)] fn get_node_mut(&mut self, node_id: OctreeNodeId) -> &mut OctreeNode { &mut self.get_level_mut(node_id.level)[node_id.index] } } #[derive(Debug, Default, Clone, Copy)] struct OctreeNodeId { level: u8, index: usize, } #[derive(Debug, Default, Clone)] struct OctreeNode { color: [u64; 3], count: u64, children: [Option; 8], is_deleted: bool, index: usize, } impl From<&OctreeNode> for Rgb { #[inline(always)] fn from(node: &OctreeNode) -> Self { if node.count == 0 { Rgb::from([0, 0, 0]) } else { Rgb::from(array::from_fn(|i| (node.color[i] / node.count) as u8)) } } } impl OctreeNode { #[inline(always)] fn is_leaf(&self) -> bool { !self.is_deleted && self.children.iter().all(Option::is_none) } #[inline(always)] fn add_color(&mut self, color: &Rgb) { self.count += 1; iter::zip(&mut self.color, color.0).for_each(|(a, b)| *a += b as u64) } #[inline(always)] fn add_node(&mut self, node: &OctreeNode) { self.count += node.count; iter::zip(&mut self.color, node.color).for_each(|(a, b)| *a += b) } } pub struct ColorQuantizer { octree: Octree, colors: Vec>, } impl ColorQuantizer { pub fn from(img: &RgbImage, palette_size: usize) -> Self { let palette_size = palette_size.min(256); let mut octree = Octree::new(); for pixel in img.pixels() { octree.insert(pixel); } octree.reduce_to(palette_size); let colors = octree.finalize(); assert!( colors.len() <= palette_size, "Color palette size exceeded {palette_size}" ); println!("final color palette size: {}", colors.len()); Self { octree, colors } } #[inline(always)] pub fn get_palette(&self) -> &[Rgb] { &self.colors } #[inline(always)] pub fn get_index(&self, color: &Rgb) -> usize { self.octree.get_index(color) } } impl ColorMap for ColorQuantizer { type Color = Rgb; #[inline(always)] fn index_of(&self, color: &Self::Color) -> usize { self.get_index(color) } #[inline(always)] fn map_color(&self, color: &mut Self::Color) { *color = self.get_palette()[self.get_index(color)] } }