Nomad changes
All checks were successful
Deploy fil (kreuzberg) / deploy (push) Successful in 49s

This commit is contained in:
Henrik Jess Nielsen
2026-06-01 23:40:55 +02:00
parent 72b1a0a6ed
commit b4c07d3693
5723 changed files with 1130655 additions and 0 deletions

View File

@@ -0,0 +1,139 @@
use crate::{
base_net::BaseNet,
constants::{IMAGENET_MEAN_VALUES, IMAGENET_NORM_VALUES},
ocr_error::OcrError,
ocr_result::Angle,
ocr_utils::OcrUtils,
};
use ort::{
inputs,
session::{Session, SessionOutputs},
value::Tensor,
};
// PP-LCNet_x1_0_textline_ori preprocessing (ImageNet normalization).
// Input: resize to 160×80 (W×H), normalize with ImageNet mean/std.
// Formula in substract_mean_normalize: (pixel - MEAN) * NORM
// For ImageNet: (pixel/255 - mean) / std = (pixel - mean*255) * (1/(std*255))
// V2 PP-LCNet angle classifier expects [3, 80, 160] input (NCHW).
const ANGLE_DST_WIDTH: u32 = 160;
const ANGLE_DST_HEIGHT: u32 = 80;
const ANGLE_COLS: usize = 2;
#[derive(Debug)]
pub struct AngleNet {
session: Option<Session>,
input_names: Vec<String>,
}
impl BaseNet for AngleNet {
fn new() -> Self {
Self {
session: None,
input_names: Vec::new(),
}
}
fn set_input_names(&mut self, input_names: Vec<String>) {
self.input_names = input_names;
}
fn set_session(&mut self, session: Option<Session>) {
self.session = session;
}
}
impl AngleNet {
pub fn get_angles(
&self,
part_imgs: &[image::RgbImage],
do_angle: bool,
most_angle: bool,
cls_thresh: f32,
) -> Result<Vec<Angle>, OcrError> {
// Pre-allocate — we know exact count upfront.
let mut angles = Vec::with_capacity(part_imgs.len());
if do_angle {
for img in part_imgs {
let angle = self.get_angle(img, cls_thresh)?;
angles.push(angle);
}
} else {
angles.extend(part_imgs.iter().map(|_| Angle::default()));
}
if do_angle && most_angle {
let sum: i32 = angles.iter().map(|x| x.index).sum();
let half_percent = angles.len() as f32 / 2.0;
let most_angle_index = if (sum as f32) < half_percent { 0 } else { 1 };
for angle in angles.iter_mut() {
angle.index = most_angle_index;
}
}
Ok(angles)
}
fn get_angle(&self, img_src: &image::RgbImage, cls_thresh: f32) -> Result<Angle, OcrError> {
let Some(session) = &self.session else {
return Err(OcrError::SessionNotInitialized);
};
let angle_img = image::imageops::resize(
img_src,
ANGLE_DST_WIDTH,
ANGLE_DST_HEIGHT,
image::imageops::FilterType::Triangle,
);
let input_tensors =
OcrUtils::substract_mean_normalize(&angle_img, &IMAGENET_MEAN_VALUES, &IMAGENET_NORM_VALUES);
let input_tensors = Tensor::from_array(input_tensors)?;
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = session as *const Session as *mut Session;
(*session_ptr).run(inputs![self.input_names[0].as_str() => input_tensors])?
};
let mut angle = Self::score_to_angle(&outputs, ANGLE_COLS)?;
// Only apply rotation if confidence exceeds threshold (matches PaddleOCR's cls_thresh=0.9)
if angle.score < cls_thresh {
angle.index = 0; // Keep original orientation when confidence is low
}
Ok(angle)
}
fn score_to_angle(output_tensor: &SessionOutputs, angle_cols: usize) -> Result<Angle, OcrError> {
let (_, red_data) = output_tensor.iter().next().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output tensors found in angle classification session output",
))
})?;
let src_data: Vec<f32> = red_data.try_extract_tensor::<f32>()?.1.to_vec();
let mut angle = Angle::default();
let mut max_value = f32::MIN;
let mut angle_index = 0;
for (i, value) in src_data.iter().take(angle_cols).enumerate() {
if *value > max_value {
max_value = *value;
angle_index = i as i32;
}
}
angle.index = angle_index;
angle.score = max_value;
Ok(angle)
}
}

View File

@@ -0,0 +1,78 @@
use ort::session::{
Session,
builder::{GraphOptimizationLevel, SessionBuilder},
};
use crate::ocr_error::OcrError;
pub trait BaseNet {
fn new() -> Self;
fn get_session_builder(
&self,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<SessionBuilder, OcrError> {
let builder = Session::builder()?;
let builder = match builder_fn {
Some(custom) => custom(builder)?,
None => builder
.with_optimization_level(GraphOptimizationLevel::All)
.map_err(|e| OcrError::Ort(ort::Error::new(e.message())))?
.with_intra_threads(num_thread)
.map_err(|e| OcrError::Ort(ort::Error::new(e.message())))?
.with_inter_threads(1)
.map_err(|e| OcrError::Ort(ort::Error::new(e.message())))?,
};
Ok(builder)
}
fn set_input_names(&mut self, input_names: Vec<String>);
fn set_session(&mut self, session: Option<Session>);
fn init(&mut self, session: Session) {
let input_names: Vec<String> = session.inputs().iter().map(|input| input.name().to_string()).collect();
self.set_input_names(input_names);
self.set_session(Some(session));
}
fn init_model(
&mut self,
path: &str,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
// Wrap ORT session creation in catch_unwind to prevent mutex poisoning
// on platforms where ORT initialization can panic (notably Windows).
let path_owned = path.to_string();
let session = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut builder = self.get_session_builder(num_thread, builder_fn)?;
builder.commit_from_file(&path_owned).map_err(OcrError::from)
}))
.map_err(|_| OcrError::Ort(ort::Error::new("ORT session initialization panicked")))??;
self.init(session);
Ok(())
}
fn init_model_from_memory(
&mut self,
model_bytes: &[u8],
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
// Wrap ORT session creation in catch_unwind to prevent mutex poisoning
// on platforms where ORT initialization can panic (notably Windows).
let session = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut builder = self.get_session_builder(num_thread, builder_fn)?;
builder.commit_from_memory(model_bytes).map_err(OcrError::from)
}))
.map_err(|_| OcrError::Ort(ort::Error::new("ORT session initialization panicked")))??;
self.init(session);
Ok(())
}
}

View File

@@ -0,0 +1,33 @@
//! Shared normalization constants for PaddleOCR preprocessing.
//!
//! Two normalization schemes are used:
//!
//! - **ImageNet** (`IMAGENET_MEAN_VALUES` / `IMAGENET_NORM_VALUES`): used by the text
//! detection network (`DbNet`) and the angle classifier (`AngleNet`).
//! Formula: `(pixel - mean * 255) * (1 / (std * 255))`.
//!
//! - **CRNN** (`CRNN_MEAN_VALUES` / `CRNN_NORM_VALUES`): used by the text recognition
//! network (`CrnnNet`).
//! Formula: `(pixel - 127.5) * (1 / 127.5)`.
/// ImageNet channel means (R, G, B), pre-multiplied by 255.
///
/// Derived from `[0.485, 0.456, 0.406]` (per-channel ImageNet means).
/// Used by `DbNet` (text detection) and `AngleNet` (angle classification).
pub(crate) const IMAGENET_MEAN_VALUES: [f32; 3] = [0.485 * 255.0, 0.456 * 255.0, 0.406 * 255.0];
/// ImageNet channel normalization factors (R, G, B), equal to `1 / (std * 255)`.
///
/// Derived from `[0.229, 0.224, 0.225]` (per-channel ImageNet standard deviations).
/// Used by `DbNet` (text detection) and `AngleNet` (angle classification).
pub(crate) const IMAGENET_NORM_VALUES: [f32; 3] = [1.0 / (0.229 * 255.0), 1.0 / (0.224 * 255.0), 1.0 / (0.225 * 255.0)];
/// CRNN channel means (R, G, B): `127.5` for all channels.
///
/// Used by `CrnnNet` (text recognition).
pub(crate) const CRNN_MEAN_VALUES: [f32; 3] = [127.5, 127.5, 127.5];
/// CRNN channel normalization factors (R, G, B): `1 / 127.5` for all channels.
///
/// Used by `CrnnNet` (text recognition).
pub(crate) const CRNN_NORM_VALUES: [f32; 3] = [1.0 / 127.5, 1.0 / 127.5, 1.0 / 127.5];

View File

@@ -0,0 +1,393 @@
use ndarray::Array4;
use ort::session::Session;
use ort::value::Tensor;
use ort::{inputs, session::builder::SessionBuilder};
use std::collections::HashMap;
use crate::{
base_net::BaseNet,
constants::{CRNN_MEAN_VALUES, CRNN_NORM_VALUES},
ocr_error::OcrError,
ocr_result::TextLine,
ocr_utils::OcrUtils,
};
const CRNN_DST_HEIGHT: u32 = 48;
#[derive(Debug)]
pub struct CrnnNet {
session: Option<Session>,
keys: Vec<String>,
input_names: Vec<String>,
}
impl BaseNet for CrnnNet {
fn new() -> Self {
Self {
session: None,
keys: Vec::new(),
input_names: Vec::new(),
}
}
fn set_input_names(&mut self, input_names: Vec<String>) {
self.input_names = input_names;
}
fn set_session(&mut self, session: Option<Session>) {
self.session = session;
}
}
impl CrnnNet {
pub fn init_model(
&mut self,
path: &str,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
BaseNet::init_model(self, path, num_thread, builder_fn)?;
self.keys = self.get_keys()?;
Ok(())
}
pub fn init_model_dict_file(
&mut self,
path: &str,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
dict_file_path: &str,
) -> Result<(), OcrError> {
BaseNet::init_model(self, path, num_thread, builder_fn)?;
self.read_keys_from_file(dict_file_path)?;
Ok(())
}
pub fn init_model_from_memory(
&mut self,
model_bytes: &[u8],
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
BaseNet::init_model_from_memory(self, model_bytes, num_thread, builder_fn)?;
self.keys = self.get_keys()?;
Ok(())
}
fn get_keys(&mut self) -> Result<Vec<String>, OcrError> {
let session = self.session.as_ref().ok_or(OcrError::SessionNotInitialized)?;
let metadata = session.metadata()?;
let model_charater_list = metadata.custom("character").ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::NotFound,
"crnn_net character not found in metadata",
))
})?;
// PP-OCRv5 model metadata already includes the CTC blank token ("#") at
// index 0 and the space token (" ") at the end. Do NOT prepend/append
// extra tokens — doing so shifts every character index by one and
// produces garbled output.
let keys: Vec<String> = model_charater_list.split('\n').map(|s: &str| s.to_string()).collect();
Ok(keys)
}
fn read_keys_from_file(&mut self, path: &str) -> Result<(), OcrError> {
let content = std::fs::read_to_string(path)?;
// PP-OCRv5 dict files already include the CTC blank token ("#") at
// index 0 and the space token (" ") at the end. Do NOT prepend/append
// extra tokens — doing so shifts every character index by one and
// produces garbled output.
let keys: Vec<String> = content.split('\n').map(|s| s.to_string()).collect();
self.keys = keys;
Ok(())
}
pub fn get_text_lines(
&self,
part_imgs: &[image::RgbImage],
angle_rollback_records: &HashMap<usize, image::RgbImage>,
angle_rollback_threshold: f32,
batch_size: u32,
) -> Result<Vec<TextLine>, OcrError> {
if part_imgs.is_empty() {
return Ok(Vec::new());
}
// Batch recognition: sort by aspect ratio, batch, pad to max width
let mut text_lines = self.get_text_lines_batched(part_imgs, batch_size)?;
// Angle rollback: re-recognize individual images that scored poorly
for (index, text_line) in text_lines.iter_mut().enumerate() {
if (text_line.text_score.is_nan() || text_line.text_score < angle_rollback_threshold)
&& let Some(angle_rollback_record) = angle_rollback_records.get(&index)
{
*text_line = self.get_text_line(angle_rollback_record)?;
}
}
Ok(text_lines)
}
/// Batch recognition: sort crops by width, group into batches, pad to max width,
/// run single ONNX inference per batch. Matches PaddleOCR/RapidOCR batching strategy.
fn get_text_lines_batched(
&self,
part_imgs: &[image::RgbImage],
batch_size: u32,
) -> Result<Vec<TextLine>, OcrError> {
let session = self.session.as_ref().ok_or(OcrError::SessionNotInitialized)?;
let batch_size = (batch_size as usize).max(1);
// Compute target widths and sort indices by aspect ratio (width/height)
let mut indexed_widths: Vec<(usize, u32)> = part_imgs
.iter()
.enumerate()
.map(|(i, img)| {
let scale = CRNN_DST_HEIGHT as f32 / img.height().max(1) as f32;
let dst_width = (img.width() as f32 * scale).ceil() as u32;
(i, dst_width.max(1))
})
.collect();
indexed_widths.sort_by_key(|&(_, w)| w);
let mut results: Vec<(usize, TextLine)> = Vec::with_capacity(part_imgs.len());
// Process in batches
for chunk in indexed_widths.chunks(batch_size) {
if chunk.len() == 1 {
// Single image — use existing path (no padding overhead)
let (orig_idx, _) = chunk[0];
let text_line = self.get_text_line(&part_imgs[orig_idx])?;
results.push((orig_idx, text_line));
continue;
}
// Find max width in this batch
let max_width = chunk.iter().map(|&(_, w)| w).max().unwrap_or(1);
// Build batch tensor [N, 3, 48, max_width] with zero-padding
let n = chunk.len();
let mut batch_data = Array4::<f32>::zeros((n, 3, CRNN_DST_HEIGHT as usize, max_width as usize));
for (batch_idx, &(orig_idx, dst_width)) in chunk.iter().enumerate() {
let img = &part_imgs[orig_idx];
let resized =
image::imageops::resize(img, dst_width, CRNN_DST_HEIGHT, image::imageops::FilterType::Triangle);
// Normalize and fill into batch tensor (zero-padded on right).
// Use raw slice access instead of per-pixel get_pixel() to
// eliminate millions of bounds checks in the hot loop.
let cols = resized.width() as usize;
let rows = resized.height() as usize;
let raw = resized.as_raw();
assert_eq!(raw.len(), rows * cols * 3, "unexpected image buffer size");
let adjusted = [
CRNN_MEAN_VALUES[0] * CRNN_NORM_VALUES[0],
CRNN_MEAN_VALUES[1] * CRNN_NORM_VALUES[1],
CRNN_MEAN_VALUES[2] * CRNN_NORM_VALUES[2],
];
for r in 0..rows {
for c in 0..cols {
let base = r * cols * 3 + c * 3;
for ch in 0..3 {
batch_data[[batch_idx, ch, r, c]] =
raw[base + ch] as f32 * CRNN_NORM_VALUES[ch] - adjusted[ch];
}
}
}
// Remaining columns stay zero (padding)
}
let input_tensor = Tensor::from_array(batch_data)?;
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = session as *const Session as *mut Session;
(*session_ptr).run(inputs![self.input_names[0].as_str() => input_tensor])?
};
let (_, output_value) = outputs.iter().next().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output tensors found in batched CRNN session output",
))
})?;
let (shape, flat_data) = output_value.try_extract_tensor::<f32>()?;
// Shape: [batch, timesteps, num_classes]
let batch_dim = *shape.first().unwrap_or(&1) as usize;
let timesteps = *shape.get(1).unwrap_or(&0) as usize;
let num_classes = *shape.get(2).unwrap_or(&0) as usize;
for (batch_idx, item) in chunk.iter().enumerate().take(batch_dim.min(n)) {
let offset = batch_idx * timesteps * num_classes;
let slice = &flat_data[offset..offset + timesteps * num_classes];
let text_line = Self::score_to_text_line(slice, timesteps, num_classes, &self.keys)?;
results.push((item.0, text_line));
}
}
// Reorder results back to original index order
results.sort_by_key(|&(idx, _)| idx);
Ok(results.into_iter().map(|(_, tl)| tl).collect())
}
fn get_text_line(&self, img_src: &image::RgbImage) -> Result<TextLine, OcrError> {
let Some(session) = &self.session else {
return Err(OcrError::SessionNotInitialized);
};
let scale = CRNN_DST_HEIGHT as f32 / img_src.height() as f32;
let dst_width = (img_src.width() as f32 * scale).ceil() as u32;
let src_resize = image::imageops::resize(
img_src,
dst_width,
CRNN_DST_HEIGHT,
image::imageops::FilterType::Triangle,
);
let input_tensors = OcrUtils::substract_mean_normalize(&src_resize, &CRNN_MEAN_VALUES, &CRNN_NORM_VALUES);
let input_tensors = Tensor::from_array(input_tensors)?;
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = session as *const Session as *mut Session;
(*session_ptr).run(inputs![self.input_names[0].as_str() => input_tensors])?
};
let (_, red_data) = outputs.iter().next().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output tensors found in CRNN session output",
))
})?;
let (shape, src_data) = red_data.try_extract_tensor::<f32>()?;
let dimensions = shape;
let height = *dimensions.get(1).ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"CRNN output tensor missing height dimension (index 1)",
))
})? as usize;
let width = *dimensions.get(2).ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"CRNN output tensor missing width dimension (index 2)",
))
})? as usize;
let src_data: Vec<f32> = src_data.to_vec();
Self::score_to_text_line(&src_data, height, width, &self.keys)
}
fn score_to_text_line(
output_data: &[f32],
height: usize,
width: usize,
keys: &[String],
) -> Result<TextLine, OcrError> {
let mut text_line = TextLine::default();
let mut last_index = 0;
let mut text_score_sum = 0.0;
let mut text_score_count = 0;
for i in 0..height {
let start = i * width;
let stop = (i + 1) * width;
let slice = &output_data[start..stop.min(output_data.len())];
let (max_index, max_value) =
slice
.iter()
.enumerate()
.fold((0, f32::MIN), |(max_idx, max_val), (idx, &val)| {
if val > max_val { (idx, val) } else { (max_idx, max_val) }
});
if max_index > 0 && max_index < keys.len() && !(i > 0 && max_index == last_index) {
text_line.text.push_str(&keys[max_index]);
text_score_sum += max_value;
text_score_count += 1;
}
last_index = max_index;
}
// Avoid division by zero: handle case where no characters were found
text_line.text_score = if text_score_count > 0 {
text_score_sum / text_score_count as f32
} else {
0.0
};
Ok(text_line)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_score_to_text_line_skips_blank_index() {
// keys[0] = "#" (CTC blank), keys[1] = "a", keys[2] = "b"
let keys = vec!["#".to_string(), "a".to_string(), "b".to_string()];
// 3 timesteps, 3 classes each. Simulate: blank, "a", "b"
let output = vec![
1.0, 0.0, 0.0, // timestep 0: max at index 0 (blank) -> skip
0.0, 0.9, 0.1, // timestep 1: max at index 1 ("a")
0.0, 0.1, 0.8, // timestep 2: max at index 2 ("b")
];
let result = CrnnNet::score_to_text_line(&output, 3, 3, &keys).unwrap();
assert_eq!(result.text, "ab");
}
#[test]
fn test_score_to_text_line_deduplicates_consecutive() {
let keys = vec!["#".to_string(), "h".to_string(), "i".to_string()];
// 4 timesteps: "h", "h", "i", "i" -> should deduplicate to "hi"
let output = vec![
0.0, 0.9, 0.0, // "h"
0.0, 0.8, 0.0, // "h" again (same index, skip)
0.0, 0.0, 0.9, // "i"
0.0, 0.0, 0.8, // "i" again (same index, skip)
];
let result = CrnnNet::score_to_text_line(&output, 4, 3, &keys).unwrap();
assert_eq!(result.text, "hi");
}
#[test]
fn test_read_keys_from_file_preserves_dict_layout() {
let dir = std::env::temp_dir().join("kreuzberg_test_dict");
std::fs::create_dir_all(&dir).unwrap();
let dict_path = dir.join("test_dict.txt");
// PP-OCRv5 dict files already include "#" (blank) at start and " " at end.
std::fs::write(&dict_path, "#\na\nb\nc\n ").unwrap();
let mut net = CrnnNet::new();
net.read_keys_from_file(dict_path.to_str().unwrap()).unwrap();
// Dict is loaded as-is: ["#", "a", "b", "c", " "]
assert_eq!(net.keys[0], "#");
assert_eq!(net.keys[1], "a");
assert_eq!(net.keys[2], "b");
assert_eq!(net.keys[3], "c");
assert_eq!(net.keys[net.keys.len() - 1], " ");
std::fs::remove_dir_all(&dir).ok();
}
}

View File

@@ -0,0 +1,421 @@
use crate::{
base_net::BaseNet,
constants::{IMAGENET_MEAN_VALUES, IMAGENET_NORM_VALUES},
ocr_error::OcrError,
ocr_result::{self, TextBox},
ocr_utils::OcrUtils,
scale_param::ScaleParam,
};
use geo_clipper::{Clipper, EndType, JoinType};
use geo_types::{Coord, LineString, Polygon};
use ort::{inputs, session::SessionOutputs};
use ort::{session::Session, value::Tensor};
use std::cmp::Ordering;
#[derive(Debug)]
pub struct DbNet {
session: Option<Session>,
input_names: Vec<String>,
}
impl BaseNet for DbNet {
fn new() -> Self {
Self {
session: None,
input_names: Vec::new(),
}
}
fn set_input_names(&mut self, input_names: Vec<String>) {
self.input_names = input_names;
}
fn set_session(&mut self, session: Option<Session>) {
self.session = session;
}
}
impl DbNet {
pub fn get_text_boxes(
&self,
img_src: &image::RgbImage,
scale: &ScaleParam,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
thresh: f32,
) -> Result<Vec<TextBox>, OcrError> {
let Some(session) = &self.session else {
return Err(OcrError::SessionNotInitialized);
};
let src_resize = image::imageops::resize(
img_src,
scale.dst_width,
scale.dst_height,
image::imageops::FilterType::Triangle,
);
let input_tensors =
OcrUtils::substract_mean_normalize(&src_resize, &IMAGENET_MEAN_VALUES, &IMAGENET_NORM_VALUES);
let tensor = Tensor::from_array(input_tensors)?;
// SAFETY: ONNX Runtime's C API (OrtRun) is thread-safe for concurrent inference
// on the same session. The ort crate's `&mut self` requirement is overly
// conservative. This matches the pattern used in kreuzberg's embedding engine.
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = session as *const Session as *mut Session;
(*session_ptr).run(inputs![self.input_names[0].as_str() => tensor])?
};
let text_boxes = Self::get_text_boxes_core(
&outputs,
src_resize.height(),
src_resize.width(),
&ScaleParam::new(
scale.src_width,
scale.src_height,
scale.dst_width,
scale.dst_height,
scale.scale_width,
scale.scale_height,
),
box_score_thresh,
box_thresh,
un_clip_ratio,
thresh,
)?;
Ok(text_boxes)
}
fn get_text_boxes_core(
output_tensor: &SessionOutputs,
rows: u32,
cols: u32,
s: &ScaleParam,
box_score_thresh: f32,
_box_thresh: f32,
un_clip_ratio: f32,
thresh: f32,
) -> Result<Vec<TextBox>, OcrError> {
let max_side_thresh = 3.0;
let (_, red_data) = output_tensor.iter().next().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output tensors found in session output",
))
})?;
let pred_data: Vec<f32> = red_data.try_extract_tensor::<f32>()?.1.to_vec();
let cbuf_data: Vec<u8> = pred_data.iter().map(|pixel| (pixel * 255.0) as u8).collect();
let pred_img: image::ImageBuffer<image::Luma<f32>, Vec<f32>> =
image::ImageBuffer::from_vec(cols, rows, pred_data).ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Failed to create image buffer from predictions: {} x {} dimensions may be invalid",
cols, rows
),
))
})?;
let cbuf_img = image::GrayImage::from_vec(cols, rows, cbuf_data).ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Failed to create grayscale image buffer: {} x {} dimensions may be invalid",
cols, rows
),
))
})?;
let threshold_img = imageproc::contrast::threshold(
&cbuf_img,
(thresh * 255.0) as u8,
imageproc::contrast::ThresholdType::Binary,
);
// RapidOCR and PaddleOCR reference do NOT apply dilation before contour extraction.
// Dilation merges adjacent text regions, causing word concatenation.
let img_contours: Vec<imageproc::contours::Contour<i32>> = imageproc::contours::find_contours(&threshold_img);
// Pre-allocate based on contour count to avoid repeated reallocations.
let mut rs_boxes = Vec::with_capacity(img_contours.len());
for contour in img_contours {
if contour.points.len() <= 2 {
continue;
}
let mut max_side = 0.0;
let min_box = Self::get_mini_box(&contour.points, &mut max_side)?;
if max_side < max_side_thresh {
continue;
}
let score = Self::get_score(&contour, &pred_img)?;
if score < box_score_thresh {
continue;
}
let clip_box = Self::unclip(&min_box, un_clip_ratio)?;
if clip_box.is_empty() {
continue;
}
let mut clip_contour = Vec::new();
for point in &clip_box {
clip_contour.push(*point);
}
let mut max_side_clip = 0.0;
let clip_min_box = Self::get_mini_box(&clip_contour, &mut max_side_clip)?;
if max_side_clip < max_side_thresh + 2.0 {
continue;
}
let mut final_points = Vec::new();
for item in clip_min_box {
let x = (item.x / s.scale_width) as u32;
let ptx = x.min(s.src_width);
let y = (item.y / s.scale_height) as u32;
let pty = y.min(s.src_height);
final_points.push(ocr_result::Point { x: ptx, y: pty });
}
let text_box = TextBox {
score,
points: final_points,
};
rs_boxes.push(text_box);
}
Ok(rs_boxes)
}
fn get_mini_box(
contour_points: &[imageproc::point::Point<i32>],
min_edge_size: &mut f32,
) -> Result<Vec<imageproc::point::Point<f32>>, OcrError> {
let rect = imageproc::geometry::min_area_rect(contour_points);
let mut rect_points: Vec<imageproc::point::Point<f32>> = rect
.iter()
.map(|p| imageproc::point::Point::new(p.x as f32, p.y as f32))
.collect();
// Direct multiplication instead of .powi(2) — avoids function call overhead.
let dx_w = rect_points[0].x - rect_points[1].x;
let dy_w = rect_points[0].y - rect_points[1].y;
let width = (dx_w * dx_w + dy_w * dy_w).sqrt();
let dx_h = rect_points[1].x - rect_points[2].x;
let dy_h = rect_points[1].y - rect_points[2].y;
let height = (dx_h * dx_h + dy_h * dy_h).sqrt();
*min_edge_size = width.min(height);
rect_points.sort_by(|a, b| {
if a.x > b.x {
return Ordering::Greater;
}
if a.x == b.x {
return Ordering::Equal;
}
Ordering::Less
});
let mut box_points = Vec::new();
let index_1;
let index_4;
if rect_points[1].y > rect_points[0].y {
index_1 = 0;
index_4 = 1;
} else {
index_1 = 1;
index_4 = 0;
}
let index_2;
let index_3;
if rect_points[3].y > rect_points[2].y {
index_2 = 2;
index_3 = 3;
} else {
index_2 = 3;
index_3 = 2;
}
box_points.push(rect_points[index_1]);
box_points.push(rect_points[index_2]);
box_points.push(rect_points[index_3]);
box_points.push(rect_points[index_4]);
Ok(box_points)
}
fn get_score(
contour: &imageproc::contours::Contour<i32>,
f_map_mat: &image::ImageBuffer<image::Luma<f32>, Vec<f32>>,
) -> Result<f32, OcrError> {
// Initialize boundary values
let mut xmin = i32::MAX;
let mut xmax = i32::MIN;
let mut ymin = i32::MAX;
let mut ymax = i32::MIN;
// Find contour bounding box
for point in contour.points.iter() {
let x = point.x;
let y = point.y;
if x < xmin {
xmin = x;
}
if x > xmax {
xmax = x;
}
if y < ymin {
ymin = y;
}
if y > ymax {
ymax = y;
}
}
let width = f_map_mat.width() as i32;
let height = f_map_mat.height() as i32;
xmin = xmin.max(0).min(width - 1);
xmax = xmax.max(0).min(width - 1);
ymin = ymin.max(0).min(height - 1);
ymax = ymax.max(0).min(height - 1);
let roi_width = xmax - xmin + 1;
let roi_height = ymax - ymin + 1;
if roi_width <= 0 || roi_height <= 0 {
return Ok(0.0);
}
let mut mask = image::GrayImage::new(roi_width as u32, roi_height as u32);
let mut pts = Vec::<imageproc::point::Point<i32>>::new();
for point in contour.points.iter() {
pts.push(imageproc::point::Point::new(point.x - xmin, point.y - ymin));
}
imageproc::drawing::draw_polygon_mut(&mut mask, pts.as_slice(), image::Luma([255]));
let cropped_img =
image::imageops::crop_imm(f_map_mat, xmin as u32, ymin as u32, roi_width as u32, roi_height as u32)
.to_image();
let mean = OcrUtils::calculate_mean_with_mask(&cropped_img, &mask);
Ok(mean)
}
fn unclip(
box_points: &[imageproc::point::Point<f32>],
unclip_ratio: f32,
) -> Result<Vec<imageproc::point::Point<i32>>, OcrError> {
// Direct multiplication instead of .powi(2) — avoids function call overhead.
let dx_w = box_points[0].x - box_points[1].x;
let dy_w = box_points[0].y - box_points[1].y;
let clip_rect_width = (dx_w * dx_w + dy_w * dy_w).sqrt();
let dx_h = box_points[1].x - box_points[2].x;
let dy_h = box_points[1].y - box_points[2].y;
let clip_rect_height = (dx_h * dx_h + dy_h * dy_h).sqrt();
if clip_rect_height < 1.001 && clip_rect_width < 1.001 {
return Ok(Vec::new());
}
let mut the_cliper_pts = Vec::new();
for pt in box_points {
let a1 = Coord {
x: pt.x as f64,
y: pt.y as f64,
};
the_cliper_pts.push(a1);
}
let area = Self::signed_polygon_area(box_points).abs();
let length = Self::length_of_points(box_points);
let distance = area * unclip_ratio / length as f32;
let co = Polygon::new(LineString::new(the_cliper_pts), vec![]);
let solution = co
.offset(distance as f64, JoinType::Round(2.0), EndType::ClosedPolygon, 1.0)
.0;
if solution.is_empty() {
return Ok(Vec::new());
}
let first_polygon = solution.first().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Polygon solution list was empty after offset operation",
))
})?;
let ret_pts: Vec<_> = first_polygon
.exterior()
.points()
.map(|ip| imageproc::point::Point::new(ip.x() as i32, ip.y() as i32))
.collect();
Ok(ret_pts)
}
fn signed_polygon_area(points: &[imageproc::point::Point<f32>]) -> f32 {
let num_points = points.len();
let mut pts = Vec::with_capacity(num_points + 1);
pts.extend_from_slice(points);
pts.push(points[0]);
let mut area = 0.0;
for i in 0..num_points {
area += (pts[i + 1].x - pts[i].x) * (pts[i + 1].y + pts[i].y) / 2.0;
}
area
}
fn length_of_points(box_points: &[imageproc::point::Point<f32>]) -> f64 {
if box_points.is_empty() {
return 0.0;
}
let mut length = 0.0;
let mut x0 = box_points[0].x as f64;
let mut y0 = box_points[0].y as f64;
for pt in &box_points[1..] {
let x1 = pt.x as f64;
let y1 = pt.y as f64;
let dx = x1 - x0;
let dy = y1 - y0;
length += (dx * dx + dy * dy).sqrt();
x0 = x1;
y0 = y1;
}
// Closing segment back to first point
let dx = box_points[0].x as f64 - x0;
let dy = box_points[0].y as f64 - y0;
length += (dx * dx + dy * dy).sqrt();
length
}
}

View File

@@ -0,0 +1,32 @@
//! # kreuzberg-paddle-ocr
//!
//! PaddleOCR via ONNX Runtime for Kreuzberg - high-performance text detection and recognition.
//!
//! This crate is vendored from [paddle-ocr-rs](https://github.com/mg-chao/paddle-ocr-rs)
//! by mg-chao, with modifications for Kreuzberg integration.
//!
//! ## ONNX Runtime Requirement
//!
//! Requires **ONNX Runtime 1.24+** at runtime.
//!
//! ## Original License
//!
//! The original paddle-ocr-rs is licensed under Apache-2.0.
//! This vendored version is relicensed to MIT with the original author's copyright retained.
#![allow(clippy::too_many_arguments)]
pub mod angle_net;
pub mod base_net;
pub(crate) mod constants;
pub mod crnn_net;
pub mod db_net;
pub mod ocr_error;
pub mod ocr_lite;
pub mod ocr_result;
pub mod ocr_utils;
pub mod scale_param;
pub use ocr_error::OcrError;
pub use ocr_lite::OcrLite;
pub use ocr_result::{Angle, OcrResult, Point, TextBlock, TextBox, TextLine};

View File

@@ -0,0 +1,13 @@
use thiserror::Error;
#[derive(Error, Debug)]
pub enum OcrError {
#[error("Ort error: {0}")]
Ort(#[from] ort::Error),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Image error: {0}")]
ImageError(#[from] image::ImageError),
#[error("Session not initialized")]
SessionNotInitialized,
}

View File

@@ -0,0 +1,447 @@
use std::collections::HashMap;
use image::ImageBuffer;
use ort::session::builder::SessionBuilder;
use crate::{
angle_net::AngleNet,
base_net::BaseNet,
crnn_net::CrnnNet,
db_net::DbNet,
ocr_error::OcrError,
ocr_result::{OcrResult, Point, TextBlock, TextBox},
ocr_utils::OcrUtils,
scale_param::ScaleParam,
};
#[derive(Debug)]
pub struct OcrLite {
db_net: DbNet,
angle_net: AngleNet,
crnn_net: CrnnNet,
}
// SAFETY: OcrLite inference methods (&self) use unsafe pointer casts to call
// ort Session::run, which is thread-safe at the ONNX Runtime C API level.
// After initialization (&mut self), no mutable state is accessed during inference.
unsafe impl Send for OcrLite {}
unsafe impl Sync for OcrLite {}
impl Default for OcrLite {
fn default() -> Self {
Self::new()
}
}
impl OcrLite {
pub fn new() -> Self {
Self {
db_net: DbNet::new(),
angle_net: AngleNet::new(),
crnn_net: CrnnNet::new(),
}
}
pub fn init_models(
&mut self,
det_path: &str,
cls_path: &str,
rec_path: &str,
num_thread: usize,
) -> Result<(), OcrError> {
self.db_net.init_model(det_path, num_thread, None)?;
self.angle_net.init_model(cls_path, num_thread, None)?;
self.crnn_net.init_model(rec_path, num_thread, None)?;
Ok(())
}
pub fn init_models_with_dict(
&mut self,
det_path: &str,
cls_path: &str,
rec_path: &str,
dict_path: &str,
num_thread: usize,
) -> Result<(), OcrError> {
self.db_net.init_model(det_path, num_thread, None)?;
self.angle_net.init_model(cls_path, num_thread, None)?;
self.crnn_net
.init_model_dict_file(rec_path, num_thread, None, dict_path)?;
Ok(())
}
pub fn init_models_custom(
&mut self,
det_path: &str,
cls_path: &str,
rec_path: &str,
builder_fn: fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>,
) -> Result<(), OcrError> {
self.db_net.init_model(det_path, 0, Some(builder_fn))?;
self.angle_net.init_model(cls_path, 0, Some(builder_fn))?;
self.crnn_net.init_model(rec_path, 0, Some(builder_fn))?;
Ok(())
}
/// Initialize models with dictionary file and custom session builder.
///
/// Combines `init_models_with_dict` and `init_models_custom`: loads the
/// dictionary for the recognition model while applying a custom ORT
/// session builder (e.g. for GPU execution providers).
pub fn init_models_with_dict_custom(
&mut self,
det_path: &str,
cls_path: &str,
rec_path: &str,
dict_path: &str,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
self.db_net.init_model(det_path, num_thread, builder_fn)?;
self.angle_net.init_model(cls_path, num_thread, builder_fn)?;
self.crnn_net
.init_model_dict_file(rec_path, num_thread, builder_fn, dict_path)?;
Ok(())
}
pub fn init_models_from_memory(
&mut self,
det_bytes: &[u8],
cls_bytes: &[u8],
rec_bytes: &[u8],
num_thread: usize,
) -> Result<(), OcrError> {
self.db_net.init_model_from_memory(det_bytes, num_thread, None)?;
self.angle_net.init_model_from_memory(cls_bytes, num_thread, None)?;
self.crnn_net.init_model_from_memory(rec_bytes, num_thread, None)?;
Ok(())
}
pub fn init_models_from_memory_custom(
&mut self,
det_bytes: &[u8],
cls_bytes: &[u8],
rec_bytes: &[u8],
builder_fn: fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>,
) -> Result<(), OcrError> {
self.db_net.init_model_from_memory(det_bytes, 0, Some(builder_fn))?;
self.angle_net.init_model_from_memory(cls_bytes, 0, Some(builder_fn))?;
self.crnn_net.init_model_from_memory(rec_bytes, 0, Some(builder_fn))?;
Ok(())
}
fn detect_base(
&self,
img_src: &image::RgbImage,
padding: u32,
max_side_len: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
angle_rollback: bool,
angle_rollback_threshold: f32,
cls_thresh: f32,
thresh: f32,
) -> Result<OcrResult, OcrError> {
let origin_max_side = img_src.width().max(img_src.height());
let mut resize;
if max_side_len == 0 || max_side_len > origin_max_side {
resize = origin_max_side;
} else {
resize = max_side_len;
}
resize += 2 * padding;
// Cow avoids cloning the image when padding=0 (the common case).
let padding_src = OcrUtils::make_padding(img_src, padding)?;
let scale = ScaleParam::get_scale_param(&padding_src, resize);
self.detect_once(
&padding_src,
&scale,
padding,
box_score_thresh,
box_thresh,
un_clip_ratio,
do_angle,
most_angle,
angle_rollback,
angle_rollback_threshold,
cls_thresh,
thresh,
)
}
/// Detect text in image
///
/// # Arguments
///
/// - `img_src` - Input image
/// - `padding` - Padding width added during image transformation (improves detection)
/// - `max_side_len` - Maximum side length after transformation (larger images will be scaled down)
/// - `box_score_thresh` - Score threshold for text region detection
/// - `box_thresh` - Box threshold
/// - `un_clip_ratio` - Unclip ratio
/// - `do_angle` - Whether to perform angle detection
/// - `most_angle` - Use most common angle for all text regions
const DEFAULT_CLS_THRESH: f32 = 0.9;
const DEFAULT_THRESH: f32 = 0.3;
const DEFAULT_REC_BATCH_SIZE: u32 = 6;
pub fn detect(
&self,
img_src: &image::RgbImage,
padding: u32,
max_side_len: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
) -> Result<OcrResult, OcrError> {
self.detect_base(
img_src,
padding,
max_side_len,
box_score_thresh,
box_thresh,
un_clip_ratio,
do_angle,
most_angle,
false,
0.0,
Self::DEFAULT_CLS_THRESH,
Self::DEFAULT_THRESH,
)
}
/// Detect text with angle rollback support
///
/// When `do_angle` is true, if the image was angle-corrected but recognition
/// result is poor, the angle correction will be reverted.
///
/// # Arguments
///
/// - `img_src` - Input image
/// - `padding` - Padding width added during image transformation
/// - `max_side_len` - Maximum side length after transformation
/// - `box_score_thresh` - Score threshold for text region detection
/// - `box_thresh` - Box threshold
/// - `un_clip_ratio` - Unclip ratio
/// - `do_angle` - Whether to perform angle detection
/// - `most_angle` - Use most common angle
/// - `angle_rollback_threshold` - If text score is below this value (or NaN), angle correction is reverted
pub fn detect_angle_rollback(
&self,
img_src: &image::RgbImage,
padding: u32,
max_side_len: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
angle_rollback_threshold: f32,
) -> Result<OcrResult, OcrError> {
self.detect_base(
img_src,
padding,
max_side_len,
box_score_thresh,
box_thresh,
un_clip_ratio,
do_angle,
most_angle,
true,
angle_rollback_threshold,
Self::DEFAULT_CLS_THRESH,
Self::DEFAULT_THRESH,
)
}
pub fn detect_from_path(
&self,
img_path: &str,
padding: u32,
max_side_len: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
) -> Result<OcrResult, OcrError> {
let img_src = image::open(img_path)?.to_rgb8();
self.detect(
&img_src,
padding,
max_side_len,
box_score_thresh,
box_thresh,
un_clip_ratio,
do_angle,
most_angle,
)
}
/// Sort text boxes in reading order: top-to-bottom, left-to-right.
///
/// Sorts by top-left Y coordinate first, then by top-left X coordinate within
/// the same Y. Matches PaddleOCR Python's `sorted_boxes` primary ordering.
fn sort_text_boxes(text_boxes: &mut [TextBox]) {
text_boxes.sort_by(|a, b| {
let ay = a.points.first().map_or(0, |p| p.y);
let ax = a.points.first().map_or(0, |p| p.x);
let by = b.points.first().map_or(0, |p| p.y);
let bx = b.points.first().map_or(0, |p| p.x);
(ay, ax).cmp(&(by, bx))
});
}
fn detect_once(
&self,
img_src: &image::RgbImage,
scale: &ScaleParam,
padding: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
angle_rollback: bool,
angle_rollback_threshold: f32,
cls_thresh: f32,
thresh: f32,
) -> Result<OcrResult, OcrError> {
let mut text_boxes =
self.db_net
.get_text_boxes(img_src, scale, box_score_thresh, box_thresh, un_clip_ratio, thresh)?;
// Sort boxes in reading order (top-to-bottom, left-to-right)
Self::sort_text_boxes(&mut text_boxes);
let part_images = OcrUtils::get_part_images(img_src, &text_boxes);
let angles = self
.angle_net
.get_angles(&part_images, do_angle, most_angle, cls_thresh)?;
let mut rotated_images: Vec<image::RgbImage> = Vec::with_capacity(part_images.len());
// Angle correction rollback
let mut angle_rollback_records = HashMap::<usize, ImageBuffer<image::Rgb<u8>, Vec<u8>>>::new();
for (index, (angle, mut part_image)) in angles.iter().zip(part_images).enumerate() {
if angle.index == 1 {
if angle_rollback {
// Keep original copy
angle_rollback_records.insert(index, part_image.clone());
}
OcrUtils::mat_rotate_clock_wise_180(&mut part_image);
}
rotated_images.push(part_image);
}
let text_lines = self.crnn_net.get_text_lines(
&rotated_images,
&angle_rollback_records,
angle_rollback_threshold,
Self::DEFAULT_REC_BATCH_SIZE,
)?;
let mut text_blocks = Vec::with_capacity(text_lines.len());
for (i, text_line) in text_lines.into_iter().enumerate() {
text_blocks.push(TextBlock {
box_points: text_boxes[i]
.points
.iter()
.map(|p| Point {
x: ((p.x as f32) - padding as f32) as u32,
y: ((p.y as f32) - padding as f32) as u32,
})
.collect(),
box_score: text_boxes[i].score,
angle_index: angles[i].index,
angle_score: angles[i].score,
text: text_line.text,
text_score: text_line.text_score,
});
}
Ok(OcrResult { text_blocks })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ocr_result::TextBox;
fn make_box(x: u32, y: u32) -> TextBox {
TextBox {
points: vec![
Point { x, y },
Point { x: x + 100, y },
Point { x: x + 100, y: y + 20 },
Point { x, y: y + 20 },
],
score: 0.9,
}
}
#[test]
fn test_sort_text_boxes_top_to_bottom() {
let mut boxes = vec![make_box(10, 100), make_box(10, 50), make_box(10, 10)];
OcrLite::sort_text_boxes(&mut boxes);
assert_eq!(boxes[0].points[0].y, 10);
assert_eq!(boxes[1].points[0].y, 50);
assert_eq!(boxes[2].points[0].y, 100);
}
#[test]
fn test_sort_text_boxes_same_line_left_to_right() {
// Boxes with the same Y are sorted left-to-right by X
let mut boxes = vec![make_box(200, 10), make_box(100, 10), make_box(50, 10)];
OcrLite::sort_text_boxes(&mut boxes);
assert_eq!(boxes[0].points[0].x, 50);
assert_eq!(boxes[1].points[0].x, 100);
assert_eq!(boxes[2].points[0].x, 200);
}
#[test]
fn test_sort_text_boxes_multi_line() {
// Boxes sorted strictly by (y, x): y=50/x=50, y=50/x=300, y=100/x=100, y=100/x=200
let mut boxes = vec![
make_box(300, 50), // line 1, right
make_box(100, 100), // line 2, left
make_box(50, 50), // line 1, left (same y=50)
make_box(200, 100), // line 2, right (same y=100)
];
OcrLite::sort_text_boxes(&mut boxes);
// Line 1 (y=50): left first, then right
assert_eq!(boxes[0].points[0].x, 50);
assert_eq!(boxes[1].points[0].x, 300);
// Line 2 (y=100): left first, then right
assert_eq!(boxes[2].points[0].x, 100);
assert_eq!(boxes[3].points[0].x, 200);
}
#[test]
fn test_sort_text_boxes_empty() {
let mut boxes: Vec<TextBox> = vec![];
OcrLite::sort_text_boxes(&mut boxes);
assert!(boxes.is_empty());
}
#[test]
fn test_sort_text_boxes_single() {
let mut boxes = vec![make_box(10, 20)];
OcrLite::sort_text_boxes(&mut boxes);
assert_eq!(boxes.len(), 1);
}
}

View File

@@ -0,0 +1,105 @@
use std::fmt::{self, Write};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Point {
pub x: u32,
pub y: u32,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TextBox {
pub points: Vec<Point>,
pub score: f32,
}
impl fmt::Display for TextBox {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// SAFETY: We must have at least 4 points in a valid TextBox
// This is enforced at the OCR processing level, but we check bounds here for safety
if self.points.len() < 4 {
return write!(
f,
"TextBox [score({}), points_count({})]",
self.score,
self.points.len()
);
}
write!(
f,
"TextBox [score({}), [x: {}, y: {}], [x: {}, y: {}], [x: {}, y: {}], [x: {}, y: {}]]",
self.score,
self.points[0].x,
self.points[0].y,
self.points[1].x,
self.points[1].y,
self.points[2].x,
self.points[2].y,
self.points[3].x,
self.points[3].y,
)
}
}
#[derive(Debug, Default)]
pub struct Angle {
pub index: i32,
pub score: f32,
}
impl fmt::Display for Angle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let header = if self.index >= 0 { "Angle" } else { "AngleDisabled" };
write!(f, "{}[Index({}), Score({})]", header, self.index, self.score)
}
}
#[derive(Debug, Default)]
pub struct TextLine {
pub text: String,
pub text_score: f32,
}
impl fmt::Display for TextLine {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "TextLine[Text({}),TextScore({})]", self.text, self.text_score)
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TextBlock {
pub box_points: Vec<Point>,
pub box_score: f32,
pub angle_index: i32,
pub angle_score: f32,
pub text: String,
pub text_score: f32,
}
#[derive(Serialize, Deserialize)]
pub struct OcrResult {
pub text_blocks: Vec<TextBlock>,
}
impl fmt::Display for OcrResult {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut str_builder = String::with_capacity(0);
for text_block in &self.text_blocks {
write!(
str_builder,
"TextBlock[BoxPointsLen({}), BoxScore({}), AngleIndex({}), AngleScore({}), Text({}), TextScore({})]",
text_block.box_points.len(),
text_block.box_score,
text_block.angle_index,
text_block.angle_score,
text_block.text,
text_block.text_score
)?;
}
f.write_str(&str_builder)
}
}

View File

@@ -0,0 +1,206 @@
use std::borrow::Cow;
use crate::{
ocr_error::OcrError,
ocr_result::{Point, TextBox},
};
use image::imageops;
use imageproc::geometric_transformations::{Interpolation, Projection};
use ndarray::{Array, Array4};
pub struct OcrUtils;
impl OcrUtils {
/// Normalize image pixels and transpose from HWC (row-major RGB) to CHW tensor format.
///
/// Formula per pixel: `output[ch] = pixel[ch] * norm[ch] - mean[ch] * norm[ch]`
///
/// This is a hot path called once per page. Key optimizations:
/// - Pre-computes `mean * norm` constants (avoids repeated multiply)
/// - Writes each channel plane contiguously via `as_slice_mut()`, enabling
/// LLVM auto-vectorization (NEON on ARM64, SSE/AVX on x86-64). The previous
/// approach used `tensor[[0, ch, r, c]]` which scattered writes across planes
/// and prevented any vectorization.
pub fn substract_mean_normalize(img_src: &image::RgbImage, mean_vals: &[f32], norm_vals: &[f32]) -> Array4<f32> {
let cols = img_src.width() as usize;
let rows = img_src.height() as usize;
let pixel_count = rows * cols;
let mut input_tensor = Array::zeros((1, 3, rows, cols));
let adjusted = [
mean_vals[0] * norm_vals[0],
mean_vals[1] * norm_vals[1],
mean_vals[2] * norm_vals[2],
];
let raw = img_src.as_raw();
// Write each channel plane as a contiguous slice. ndarray stores (1,3,H,W)
// in C-contiguous (row-major) order, so plane [0,ch] is a contiguous H*W block.
// This enables LLVM to auto-vectorize the inner loop (4-8 f32 ops per cycle).
for ch in 0..3 {
let norm = norm_vals[ch];
let adj = adjusted[ch];
let plane = input_tensor
.slice_mut(ndarray::s![0, ch, .., ..])
.into_shape_with_order(pixel_count)
.expect("contiguous plane slice");
let plane_slice = plane.into_slice().expect("contiguous memory");
for (i, out) in plane_slice.iter_mut().enumerate() {
// raw is HWC: pixel i has R at raw[i*3], G at raw[i*3+1], B at raw[i*3+2]
*out = raw[i * 3 + ch] as f32 * norm - adj;
}
}
input_tensor
}
/// Add white padding around the image, or borrow it unchanged when padding=0.
/// Returns Cow to avoid cloning the image in the common no-padding case.
pub fn make_padding<'a>(img_src: &'a image::RgbImage, padding: u32) -> Result<Cow<'a, image::RgbImage>, OcrError> {
if padding == 0 {
return Ok(Cow::Borrowed(img_src));
}
let width = img_src.width();
let height = img_src.height();
let mut padding_src = image::RgbImage::new(width + 2 * padding, height + 2 * padding);
imageproc::drawing::draw_filled_rect_mut(
&mut padding_src,
imageproc::rect::Rect::at(0, 0).of_size(width + 2 * padding, height + 2 * padding),
image::Rgb([255, 255, 255]),
);
image::imageops::replace(&mut padding_src, img_src, padding as i64, padding as i64);
Ok(Cow::Owned(padding_src))
}
pub fn get_part_images(img_src: &image::RgbImage, text_boxes: &[TextBox]) -> Vec<image::RgbImage> {
text_boxes
.iter()
.map(|text_box| Self::get_rotate_crop_image(img_src, &text_box.points))
.collect()
}
pub fn get_rotate_crop_image(img_src: &image::RgbImage, box_points: &[Point]) -> image::RgbImage {
let mut points = box_points.to_vec();
// Calculate bounding box
let (min_x, min_y, max_x, max_y) = points.iter().fold(
(u32::MAX, u32::MAX, 0u32, 0u32),
|(min_x, min_y, max_x, max_y), point| {
(
min_x.min(point.x),
min_y.min(point.y),
max_x.max(point.x),
max_y.max(point.y),
)
},
);
// Crop image
let img_crop = imageops::crop_imm(img_src, min_x, min_y, max_x - min_x, max_y - min_y).to_image();
for point in &mut points {
point.x = point.x.saturating_sub(min_x);
point.y = point.y.saturating_sub(min_y);
}
// Ensure we have enough points for transformation
if points.len() < 4 {
// Fallback: return the cropped image as-is if we don't have 4 points
return img_crop;
}
// Direct multiplication instead of .pow(2) — avoids integer power function overhead.
let dx_w = (points[0].x as i32 - points[1].x as i32) as f32;
let dy_w = (points[0].y as i32 - points[1].y as i32) as f32;
let img_crop_width = (dx_w * dx_w + dy_w * dy_w).sqrt() as u32;
let dx_h = (points[0].x as i32 - points[3].x as i32) as f32;
let dy_h = (points[0].y as i32 - points[3].y as i32) as f32;
let img_crop_height = (dx_h * dx_h + dy_h * dy_h).sqrt() as u32;
// Ensure dimensions are valid (non-zero)
if img_crop_width == 0 || img_crop_height == 0 {
return img_crop;
}
let src_points = [
(points[0].x as f32, points[0].y as f32),
(points[1].x as f32, points[1].y as f32),
(points[2].x as f32, points[2].y as f32),
(points[3].x as f32, points[3].y as f32),
];
let dst_points = [
(0.0, 0.0),
(img_crop_width as f32, 0.0),
(img_crop_width as f32, img_crop_height as f32),
(0.0, img_crop_height as f32),
];
let projection = match Projection::from_control_points(src_points, dst_points) {
Some(proj) => proj,
None => {
// If projection cannot be created, return the cropped image as fallback
return img_crop;
}
};
let mut part_img = image::RgbImage::new(img_crop_width, img_crop_height);
imageproc::geometric_transformations::warp_into(
&img_crop,
&projection,
Interpolation::Nearest,
image::Rgb([255, 255, 255]),
&mut part_img,
);
// Rotate image if needed
if part_img.height() >= part_img.width() * 3 / 2 {
let mut rotated = image::RgbImage::new(part_img.height(), part_img.width());
for (x, y, pixel) in part_img.enumerate_pixels() {
rotated.put_pixel(y, part_img.width() - 1 - x, *pixel);
}
rotated
} else {
part_img
}
}
pub fn mat_rotate_clock_wise_180(src: &mut image::RgbImage) {
imageops::rotate180_in_place(src);
}
/// Compute mean of f32 image values where mask > 0.
///
/// Uses raw slice access instead of per-pixel get_pixel() for better
/// cache behavior and to enable auto-vectorization of the reduction.
pub fn calculate_mean_with_mask(
img: &image::ImageBuffer<image::Luma<f32>, Vec<f32>>,
mask: &image::ImageBuffer<image::Luma<u8>, Vec<u8>>,
) -> f32 {
assert_eq!(img.width(), mask.width());
assert_eq!(img.height(), mask.height());
let img_raw = img.as_raw();
let mask_raw = mask.as_raw();
let mut sum: f32 = 0.0;
let mut count: u32 = 0;
for (px, &m) in img_raw.iter().zip(mask_raw.iter()) {
if m > 0 {
sum += *px;
count += 1;
}
}
if count == 0 { 0.0 } else { sum / count as f32 }
}
}

View File

@@ -0,0 +1,69 @@
#[derive(Debug)]
pub struct ScaleParam {
pub src_width: u32,
pub src_height: u32,
pub dst_width: u32,
pub dst_height: u32,
pub scale_width: f32,
pub scale_height: f32,
}
impl ScaleParam {
pub fn new(
src_width: u32,
src_height: u32,
dst_width: u32,
dst_height: u32,
scale_width: f32,
scale_height: f32,
) -> Self {
Self {
src_width,
src_height,
dst_width,
dst_height,
scale_width,
scale_height,
}
}
pub fn get_scale_param(src: &image::RgbImage, target_size: u32) -> Self {
let src_width = src.width();
let src_height = src.height();
let mut dst_width;
let mut dst_height;
let ratio: f32 = if src_width > src_height {
target_size as f32 / src_width as f32
} else {
target_size as f32 / src_height as f32
};
dst_width = (src_width as f32 * ratio) as u32;
dst_height = (src_height as f32 * ratio) as u32;
if dst_width % 32 != 0 {
dst_width = (dst_width / 32) * 32;
dst_width = dst_width.max(32);
}
if dst_height % 32 != 0 {
dst_height = (dst_height / 32) * 32;
dst_height = dst_height.max(32);
}
let scale_width = dst_width as f32 / src_width as f32;
let scale_height = dst_height as f32 / src_height as f32;
Self::new(src_width, src_height, dst_width, dst_height, scale_width, scale_height)
}
}
impl std::fmt::Display for ScaleParam {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"src_width:{},src_height:{},dst_width:{},dst_height:{},scale_width:{},scale_height:{}",
self.src_width, self.src_height, self.dst_width, self.dst_height, self.scale_width, self.scale_height
)
}
}