Files
fil/crates/kreuzberg/src/doc_orientation/detector.rs
Henrik Jess Nielsen b4c07d3693
All checks were successful
Deploy fil (kreuzberg) / deploy (push) Successful in 49s
Nomad changes
2026-06-01 23:40:55 +02:00

311 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Document orientation detection implementation using PP-LCNet_x1_0_doc_ori.
//!
//! Detects page-level orientation (0°, 90°, 180°, 270°) for scanned documents
//! and images. Requires the `auto-rotate` feature (ONNX Runtime).
//!
//! Used by ALL OCR backends when `auto_rotate` is enabled in `OcrConfig`.
//! More reliable than Tesseract's `DetectOrientationScript` which crashes
//! on raw images without DPI metadata.
use std::fs;
use std::path::PathBuf;
use image::RgbImage;
use ort::session::Session;
use ort::session::builder::{GraphOptimizationLevel, SessionBuilder};
use ort::value::Tensor;
use crate::Result;
use crate::error::KreuzbergError;
use super::types::OrientationResult;
/// HuggingFace repository containing the model.
const HF_REPO_ID: &str = "Kreuzberg/paddleocr-onnx-models";
const REMOTE_FILENAME: &str = "v2/classifiers/PP-LCNet_x1_0_doc_ori.onnx";
const SHA256: &str = "6b742aebce6f0f7f71f747931ac7becfc7c96c51641e14943b291eeb334e7947";
// PP-LCNet preprocessing constants.
// Input: resize short side to 256, center crop 224×224, ImageNet normalize (BGR).
const INPUT_SIZE: u32 = 224;
const RESIZE_SHORT: u32 = 256;
/// Output labels: index -> degrees.
const ORIENTATION_LABELS: [u32; 4] = [0, 90, 180, 270];
/// PP-LCNet doc_ori outputs ~45% confidence for correct class in a 4-class problem.
/// Uniform baseline is 25%. A threshold of 0.35 provides good discrimination.
pub const MIN_CONFIDENCE: f32 = 0.35;
/// Detects document page orientation using the PP-LCNet model.
///
/// Thread-safe: uses unsafe pointer cast for ONNX session (same pattern as embedding engine).
/// The model is downloaded from HuggingFace on first use and cached locally.
#[cfg_attr(alef, alef(skip))]
pub struct DocOrientationDetector {
session: once_cell::sync::OnceCell<Session>,
cache_dir: PathBuf,
acceleration: Option<crate::core::config::acceleration::AccelerationConfig>,
}
impl DocOrientationDetector {
/// Creates a new detector with the given cache directory and acceleration config.
pub(crate) fn with_acceleration(
cache_dir: PathBuf,
accel: Option<crate::core::config::acceleration::AccelerationConfig>,
) -> Self {
Self {
session: once_cell::sync::OnceCell::new(),
cache_dir,
acceleration: accel,
}
}
/// Detect document page orientation.
///
/// Returns the detected orientation (0°, 90°, 180°, 270°) and confidence.
/// Thread-safe: can be called concurrently from multiple pages.
pub(crate) fn detect(&self, image: &RgbImage) -> Result<OrientationResult> {
let session = self.get_or_init_session()?;
// Preprocess: resize short side to 256, center crop 224×224
let preprocessed = preprocess(image);
// Build input tensor: [1, 3, 224, 224]
let input_tensor = normalize(&preprocessed);
let tensor = Tensor::from_array(input_tensor).map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to create doc_ori input tensor: {e}"),
source: None,
})?;
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
// The ort crate's &mut self on Session::run is overly conservative.
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = session as *const Session as *mut Session;
(*session_ptr).run(ort::inputs!["x" => tensor])
}
.map_err(|e| KreuzbergError::Ocr {
message: format!("Doc orientation inference failed: {e}"),
source: None,
})?;
// Parse output: argmax over 4 orientation classes
let (_, output_value) = outputs.iter().next().ok_or_else(|| KreuzbergError::Ocr {
message: "No output from doc orientation model".to_string(),
source: None,
})?;
let scores: Vec<f32> = output_value
.try_extract_tensor::<f32>()
.map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to extract doc_ori output: {e}"),
source: None,
})?
.1
.to_vec();
// Softmax + argmax
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
let sum_exp: f32 = exp_scores.iter().sum();
let probabilities: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();
let (best_idx, &best_prob) = probabilities
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or((0, &0.0));
let degrees = ORIENTATION_LABELS.get(best_idx).copied().unwrap_or(0);
Ok(OrientationResult {
degrees,
confidence: best_prob,
})
}
/// Ensure the model is downloaded and return the ONNX file path.
fn ensure_model(&self) -> Result<PathBuf> {
let model_dir = self.cache_dir.join("doc-orientation");
let model_file = model_dir.join("model.onnx");
if model_file.exists() {
return Ok(model_file);
}
tracing::info!("Downloading document orientation model...");
fs::create_dir_all(&model_dir)?;
let cached_path =
crate::model_download::hf_download(HF_REPO_ID, REMOTE_FILENAME).map_err(|e| KreuzbergError::Plugin {
message: e,
plugin_name: "auto-rotate".to_string(),
})?;
crate::model_download::verify_sha256(&cached_path, SHA256, "doc_ori").map_err(|e| {
KreuzbergError::Validation {
message: e,
source: None,
}
})?;
fs::copy(&cached_path, &model_file).map_err(|e| KreuzbergError::Plugin {
message: format!("Failed to copy doc_ori model: {e}"),
plugin_name: "auto-rotate".to_string(),
})?;
tracing::info!("Document orientation model saved");
Ok(model_file)
}
/// Get or initialize the ONNX session (lazy, thread-safe via OnceCell).
fn get_or_init_session(&self) -> Result<&Session> {
self.session.get_or_try_init(|| {
let model_path = self.ensure_model()?;
crate::ort_discovery::ensure_ort_available();
let num_threads = crate::core::config::concurrency::resolve_thread_budget(None);
let builder = SessionBuilder::new()
.map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to create doc_ori session builder: {e}"),
source: None,
})?
.with_optimization_level(GraphOptimizationLevel::All)
.map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to set doc_ori optimization level: {e}"),
source: None,
})?
.with_intra_threads(num_threads)
.map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to set doc_ori thread count: {e}"),
source: None,
})?
.with_inter_threads(1)
.map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to set doc_ori inter threads: {e}"),
source: None,
})?;
let mut builder = crate::ort_discovery::apply_execution_providers(builder, self.acceleration.as_ref())
.map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to set doc_ori execution providers: {e}"),
source: None,
})?;
let session = builder.commit_from_file(&model_path).map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to load doc_ori model: {e}"),
source: None,
})?;
tracing::info!("Doc orientation model loaded");
Ok(session)
})
}
}
/// Resolve the cache directory for the auto-rotate model.
pub(crate) fn resolve_cache_dir() -> PathBuf {
crate::cache_dir::resolve_cache_dir("auto-rotate")
}
/// Detect orientation and return a corrected image if rotation is needed.
///
/// Returns `Ok(Some(rotated_bytes))` if rotation was applied,
/// `Ok(None)` if no rotation needed (0° or low confidence).
#[cfg_attr(alef, alef(skip))]
pub(crate) fn detect_and_rotate(detector: &DocOrientationDetector, image_bytes: &[u8]) -> Result<Option<Vec<u8>>> {
let img = image::load_from_memory(image_bytes)
.map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to load image for orientation detection: {e}"),
source: None,
})?
.to_rgb8();
let result = detector.detect(&img)?;
tracing::debug!(
degrees = result.degrees,
confidence = result.confidence,
"Document orientation detected"
);
if result.degrees == 0 || result.confidence < MIN_CONFIDENCE {
return Ok(None);
}
// Rotate the image back to upright (opposite direction of detected orientation).
let rotated = match result.degrees {
90 => image::imageops::rotate270(&img),
180 => image::imageops::rotate180(&img),
270 => image::imageops::rotate90(&img),
_ => return Ok(None),
};
// Encode back to PNG bytes
let mut buf = std::io::Cursor::new(Vec::new());
rotated
.write_to(&mut buf, image::ImageFormat::Png)
.map_err(|e| KreuzbergError::Ocr {
message: format!("Failed to encode rotated image: {e}"),
source: None,
})?;
tracing::info!(
degrees = result.degrees,
confidence = result.confidence,
"Auto-rotated document page"
);
Ok(Some(buf.into_inner()))
}
/// Resize short side to 256, then center crop to 224×224.
fn preprocess(image: &RgbImage) -> RgbImage {
let (w, h) = (image.width(), image.height());
// Resize: scale so short side = RESIZE_SHORT
let (new_w, new_h) = if w < h {
let scale = RESIZE_SHORT as f32 / w as f32;
(RESIZE_SHORT, (h as f32 * scale).round() as u32)
} else {
let scale = RESIZE_SHORT as f32 / h as f32;
((w as f32 * scale).round() as u32, RESIZE_SHORT)
};
let resized = image::imageops::resize(image, new_w, new_h, image::imageops::FilterType::Triangle);
// Center crop to INPUT_SIZE × INPUT_SIZE
let x_offset = (new_w.saturating_sub(INPUT_SIZE)) / 2;
let y_offset = (new_h.saturating_sub(INPUT_SIZE)) / 2;
let crop_w = INPUT_SIZE.min(new_w);
let crop_h = INPUT_SIZE.min(new_h);
image::imageops::crop_imm(&resized, x_offset, y_offset, crop_w, crop_h).to_image()
}
/// Normalize image to [1, 3, H, W] tensor with ImageNet mean/std in BGR order.
/// PP-LCNet expects BGR input: channel 0=Blue, 1=Green, 2=Red.
fn normalize(image: &RgbImage) -> ndarray::Array4<f32> {
let (w, h) = (image.width() as usize, image.height() as usize);
let mut tensor = ndarray::Array4::<f32>::zeros((1, 3, h, w));
// ImageNet mean/std for BGR order (swap R and B)
const BGR_MEAN: [f32; 3] = [0.406 * 255.0, 0.456 * 255.0, 0.485 * 255.0];
const BGR_NORM: [f32; 3] = [1.0 / (0.225 * 255.0), 1.0 / (0.224 * 255.0), 1.0 / (0.229 * 255.0)];
for y in 0..h {
for x in 0..w {
let pixel = image.get_pixel(x as u32, y as u32);
let r = pixel[0] as f32;
let g = pixel[1] as f32;
let b = pixel[2] as f32;
// BGR order: channel 0 = Blue, channel 1 = Green, channel 2 = Red
tensor[[0, 0, y, x]] = (b - BGR_MEAN[0]) * BGR_NORM[0];
tensor[[0, 1, y, x]] = (g - BGR_MEAN[1]) * BGR_NORM[1];
tensor[[0, 2, y, x]] = (r - BGR_MEAN[2]) * BGR_NORM[2];
}
}
tensor
}