This commit is contained in:
139
crates/kreuzberg-paddle-ocr/src/angle_net.rs
Normal file
139
crates/kreuzberg-paddle-ocr/src/angle_net.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use crate::{
|
||||
base_net::BaseNet,
|
||||
constants::{IMAGENET_MEAN_VALUES, IMAGENET_NORM_VALUES},
|
||||
ocr_error::OcrError,
|
||||
ocr_result::Angle,
|
||||
ocr_utils::OcrUtils,
|
||||
};
|
||||
|
||||
use ort::{
|
||||
inputs,
|
||||
session::{Session, SessionOutputs},
|
||||
value::Tensor,
|
||||
};
|
||||
|
||||
// PP-LCNet_x1_0_textline_ori preprocessing (ImageNet normalization).
|
||||
// Input: resize to 160×80 (W×H), normalize with ImageNet mean/std.
|
||||
// Formula in substract_mean_normalize: (pixel - MEAN) * NORM
|
||||
// For ImageNet: (pixel/255 - mean) / std = (pixel - mean*255) * (1/(std*255))
|
||||
// V2 PP-LCNet angle classifier expects [3, 80, 160] input (NCHW).
|
||||
const ANGLE_DST_WIDTH: u32 = 160;
|
||||
const ANGLE_DST_HEIGHT: u32 = 80;
|
||||
const ANGLE_COLS: usize = 2;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AngleNet {
|
||||
session: Option<Session>,
|
||||
input_names: Vec<String>,
|
||||
}
|
||||
|
||||
impl BaseNet for AngleNet {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
session: None,
|
||||
input_names: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_input_names(&mut self, input_names: Vec<String>) {
|
||||
self.input_names = input_names;
|
||||
}
|
||||
|
||||
fn set_session(&mut self, session: Option<Session>) {
|
||||
self.session = session;
|
||||
}
|
||||
}
|
||||
|
||||
impl AngleNet {
|
||||
pub fn get_angles(
|
||||
&self,
|
||||
part_imgs: &[image::RgbImage],
|
||||
do_angle: bool,
|
||||
most_angle: bool,
|
||||
cls_thresh: f32,
|
||||
) -> Result<Vec<Angle>, OcrError> {
|
||||
// Pre-allocate — we know exact count upfront.
|
||||
let mut angles = Vec::with_capacity(part_imgs.len());
|
||||
|
||||
if do_angle {
|
||||
for img in part_imgs {
|
||||
let angle = self.get_angle(img, cls_thresh)?;
|
||||
angles.push(angle);
|
||||
}
|
||||
} else {
|
||||
angles.extend(part_imgs.iter().map(|_| Angle::default()));
|
||||
}
|
||||
|
||||
if do_angle && most_angle {
|
||||
let sum: i32 = angles.iter().map(|x| x.index).sum();
|
||||
let half_percent = angles.len() as f32 / 2.0;
|
||||
let most_angle_index = if (sum as f32) < half_percent { 0 } else { 1 };
|
||||
|
||||
for angle in angles.iter_mut() {
|
||||
angle.index = most_angle_index;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(angles)
|
||||
}
|
||||
|
||||
fn get_angle(&self, img_src: &image::RgbImage, cls_thresh: f32) -> Result<Angle, OcrError> {
|
||||
let Some(session) = &self.session else {
|
||||
return Err(OcrError::SessionNotInitialized);
|
||||
};
|
||||
|
||||
let angle_img = image::imageops::resize(
|
||||
img_src,
|
||||
ANGLE_DST_WIDTH,
|
||||
ANGLE_DST_HEIGHT,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
|
||||
let input_tensors =
|
||||
OcrUtils::substract_mean_normalize(&angle_img, &IMAGENET_MEAN_VALUES, &IMAGENET_NORM_VALUES);
|
||||
|
||||
let input_tensors = Tensor::from_array(input_tensors)?;
|
||||
|
||||
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
|
||||
#[allow(unsafe_code)]
|
||||
let outputs = unsafe {
|
||||
let session_ptr = session as *const Session as *mut Session;
|
||||
(*session_ptr).run(inputs![self.input_names[0].as_str() => input_tensors])?
|
||||
};
|
||||
|
||||
let mut angle = Self::score_to_angle(&outputs, ANGLE_COLS)?;
|
||||
|
||||
// Only apply rotation if confidence exceeds threshold (matches PaddleOCR's cls_thresh=0.9)
|
||||
if angle.score < cls_thresh {
|
||||
angle.index = 0; // Keep original orientation when confidence is low
|
||||
}
|
||||
|
||||
Ok(angle)
|
||||
}
|
||||
|
||||
fn score_to_angle(output_tensor: &SessionOutputs, angle_cols: usize) -> Result<Angle, OcrError> {
|
||||
let (_, red_data) = output_tensor.iter().next().ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"No output tensors found in angle classification session output",
|
||||
))
|
||||
})?;
|
||||
|
||||
let src_data: Vec<f32> = red_data.try_extract_tensor::<f32>()?.1.to_vec();
|
||||
|
||||
let mut angle = Angle::default();
|
||||
let mut max_value = f32::MIN;
|
||||
let mut angle_index = 0;
|
||||
|
||||
for (i, value) in src_data.iter().take(angle_cols).enumerate() {
|
||||
if *value > max_value {
|
||||
max_value = *value;
|
||||
angle_index = i as i32;
|
||||
}
|
||||
}
|
||||
|
||||
angle.index = angle_index;
|
||||
angle.score = max_value;
|
||||
Ok(angle)
|
||||
}
|
||||
}
|
||||
78
crates/kreuzberg-paddle-ocr/src/base_net.rs
Normal file
78
crates/kreuzberg-paddle-ocr/src/base_net.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use ort::session::{
|
||||
Session,
|
||||
builder::{GraphOptimizationLevel, SessionBuilder},
|
||||
};
|
||||
|
||||
use crate::ocr_error::OcrError;
|
||||
|
||||
pub trait BaseNet {
|
||||
fn new() -> Self;
|
||||
|
||||
fn get_session_builder(
|
||||
&self,
|
||||
num_thread: usize,
|
||||
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
|
||||
) -> Result<SessionBuilder, OcrError> {
|
||||
let builder = Session::builder()?;
|
||||
let builder = match builder_fn {
|
||||
Some(custom) => custom(builder)?,
|
||||
None => builder
|
||||
.with_optimization_level(GraphOptimizationLevel::All)
|
||||
.map_err(|e| OcrError::Ort(ort::Error::new(e.message())))?
|
||||
.with_intra_threads(num_thread)
|
||||
.map_err(|e| OcrError::Ort(ort::Error::new(e.message())))?
|
||||
.with_inter_threads(1)
|
||||
.map_err(|e| OcrError::Ort(ort::Error::new(e.message())))?,
|
||||
};
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
|
||||
fn set_input_names(&mut self, input_names: Vec<String>);
|
||||
fn set_session(&mut self, session: Option<Session>);
|
||||
|
||||
fn init(&mut self, session: Session) {
|
||||
let input_names: Vec<String> = session.inputs().iter().map(|input| input.name().to_string()).collect();
|
||||
|
||||
self.set_input_names(input_names);
|
||||
self.set_session(Some(session));
|
||||
}
|
||||
|
||||
fn init_model(
|
||||
&mut self,
|
||||
path: &str,
|
||||
num_thread: usize,
|
||||
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
|
||||
) -> Result<(), OcrError> {
|
||||
// Wrap ORT session creation in catch_unwind to prevent mutex poisoning
|
||||
// on platforms where ORT initialization can panic (notably Windows).
|
||||
let path_owned = path.to_string();
|
||||
let session = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let mut builder = self.get_session_builder(num_thread, builder_fn)?;
|
||||
builder.commit_from_file(&path_owned).map_err(OcrError::from)
|
||||
}))
|
||||
.map_err(|_| OcrError::Ort(ort::Error::new("ORT session initialization panicked")))??;
|
||||
self.init(session);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn init_model_from_memory(
|
||||
&mut self,
|
||||
model_bytes: &[u8],
|
||||
num_thread: usize,
|
||||
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
|
||||
) -> Result<(), OcrError> {
|
||||
// Wrap ORT session creation in catch_unwind to prevent mutex poisoning
|
||||
// on platforms where ORT initialization can panic (notably Windows).
|
||||
let session = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let mut builder = self.get_session_builder(num_thread, builder_fn)?;
|
||||
builder.commit_from_memory(model_bytes).map_err(OcrError::from)
|
||||
}))
|
||||
.map_err(|_| OcrError::Ort(ort::Error::new("ORT session initialization panicked")))??;
|
||||
|
||||
self.init(session);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
33
crates/kreuzberg-paddle-ocr/src/constants.rs
Normal file
33
crates/kreuzberg-paddle-ocr/src/constants.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! Shared normalization constants for PaddleOCR preprocessing.
|
||||
//!
|
||||
//! Two normalization schemes are used:
|
||||
//!
|
||||
//! - **ImageNet** (`IMAGENET_MEAN_VALUES` / `IMAGENET_NORM_VALUES`): used by the text
|
||||
//! detection network (`DbNet`) and the angle classifier (`AngleNet`).
|
||||
//! Formula: `(pixel - mean * 255) * (1 / (std * 255))`.
|
||||
//!
|
||||
//! - **CRNN** (`CRNN_MEAN_VALUES` / `CRNN_NORM_VALUES`): used by the text recognition
|
||||
//! network (`CrnnNet`).
|
||||
//! Formula: `(pixel - 127.5) * (1 / 127.5)`.
|
||||
|
||||
/// ImageNet channel means (R, G, B), pre-multiplied by 255.
|
||||
///
|
||||
/// Derived from `[0.485, 0.456, 0.406]` (per-channel ImageNet means).
|
||||
/// Used by `DbNet` (text detection) and `AngleNet` (angle classification).
|
||||
pub(crate) const IMAGENET_MEAN_VALUES: [f32; 3] = [0.485 * 255.0, 0.456 * 255.0, 0.406 * 255.0];
|
||||
|
||||
/// ImageNet channel normalization factors (R, G, B), equal to `1 / (std * 255)`.
|
||||
///
|
||||
/// Derived from `[0.229, 0.224, 0.225]` (per-channel ImageNet standard deviations).
|
||||
/// Used by `DbNet` (text detection) and `AngleNet` (angle classification).
|
||||
pub(crate) const IMAGENET_NORM_VALUES: [f32; 3] = [1.0 / (0.229 * 255.0), 1.0 / (0.224 * 255.0), 1.0 / (0.225 * 255.0)];
|
||||
|
||||
/// CRNN channel means (R, G, B): `127.5` for all channels.
|
||||
///
|
||||
/// Used by `CrnnNet` (text recognition).
|
||||
pub(crate) const CRNN_MEAN_VALUES: [f32; 3] = [127.5, 127.5, 127.5];
|
||||
|
||||
/// CRNN channel normalization factors (R, G, B): `1 / 127.5` for all channels.
|
||||
///
|
||||
/// Used by `CrnnNet` (text recognition).
|
||||
pub(crate) const CRNN_NORM_VALUES: [f32; 3] = [1.0 / 127.5, 1.0 / 127.5, 1.0 / 127.5];
|
||||
393
crates/kreuzberg-paddle-ocr/src/crnn_net.rs
Normal file
393
crates/kreuzberg-paddle-ocr/src/crnn_net.rs
Normal file
@@ -0,0 +1,393 @@
|
||||
use ndarray::Array4;
|
||||
use ort::session::Session;
|
||||
use ort::value::Tensor;
|
||||
use ort::{inputs, session::builder::SessionBuilder};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
base_net::BaseNet,
|
||||
constants::{CRNN_MEAN_VALUES, CRNN_NORM_VALUES},
|
||||
ocr_error::OcrError,
|
||||
ocr_result::TextLine,
|
||||
ocr_utils::OcrUtils,
|
||||
};
|
||||
|
||||
const CRNN_DST_HEIGHT: u32 = 48;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CrnnNet {
|
||||
session: Option<Session>,
|
||||
keys: Vec<String>,
|
||||
input_names: Vec<String>,
|
||||
}
|
||||
|
||||
impl BaseNet for CrnnNet {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
session: None,
|
||||
keys: Vec::new(),
|
||||
input_names: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_input_names(&mut self, input_names: Vec<String>) {
|
||||
self.input_names = input_names;
|
||||
}
|
||||
|
||||
fn set_session(&mut self, session: Option<Session>) {
|
||||
self.session = session;
|
||||
}
|
||||
}
|
||||
|
||||
impl CrnnNet {
|
||||
pub fn init_model(
|
||||
&mut self,
|
||||
path: &str,
|
||||
num_thread: usize,
|
||||
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
|
||||
) -> Result<(), OcrError> {
|
||||
BaseNet::init_model(self, path, num_thread, builder_fn)?;
|
||||
|
||||
self.keys = self.get_keys()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_model_dict_file(
|
||||
&mut self,
|
||||
path: &str,
|
||||
num_thread: usize,
|
||||
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
|
||||
dict_file_path: &str,
|
||||
) -> Result<(), OcrError> {
|
||||
BaseNet::init_model(self, path, num_thread, builder_fn)?;
|
||||
|
||||
self.read_keys_from_file(dict_file_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_model_from_memory(
|
||||
&mut self,
|
||||
model_bytes: &[u8],
|
||||
num_thread: usize,
|
||||
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
|
||||
) -> Result<(), OcrError> {
|
||||
BaseNet::init_model_from_memory(self, model_bytes, num_thread, builder_fn)?;
|
||||
|
||||
self.keys = self.get_keys()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_keys(&mut self) -> Result<Vec<String>, OcrError> {
|
||||
let session = self.session.as_ref().ok_or(OcrError::SessionNotInitialized)?;
|
||||
|
||||
let metadata = session.metadata()?;
|
||||
let model_charater_list = metadata.custom("character").ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
"crnn_net character not found in metadata",
|
||||
))
|
||||
})?;
|
||||
|
||||
// PP-OCRv5 model metadata already includes the CTC blank token ("#") at
|
||||
// index 0 and the space token (" ") at the end. Do NOT prepend/append
|
||||
// extra tokens — doing so shifts every character index by one and
|
||||
// produces garbled output.
|
||||
let keys: Vec<String> = model_charater_list.split('\n').map(|s: &str| s.to_string()).collect();
|
||||
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn read_keys_from_file(&mut self, path: &str) -> Result<(), OcrError> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
|
||||
// PP-OCRv5 dict files already include the CTC blank token ("#") at
|
||||
// index 0 and the space token (" ") at the end. Do NOT prepend/append
|
||||
// extra tokens — doing so shifts every character index by one and
|
||||
// produces garbled output.
|
||||
let keys: Vec<String> = content.split('\n').map(|s| s.to_string()).collect();
|
||||
|
||||
self.keys = keys;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_text_lines(
|
||||
&self,
|
||||
part_imgs: &[image::RgbImage],
|
||||
angle_rollback_records: &HashMap<usize, image::RgbImage>,
|
||||
angle_rollback_threshold: f32,
|
||||
batch_size: u32,
|
||||
) -> Result<Vec<TextLine>, OcrError> {
|
||||
if part_imgs.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Batch recognition: sort by aspect ratio, batch, pad to max width
|
||||
let mut text_lines = self.get_text_lines_batched(part_imgs, batch_size)?;
|
||||
|
||||
// Angle rollback: re-recognize individual images that scored poorly
|
||||
for (index, text_line) in text_lines.iter_mut().enumerate() {
|
||||
if (text_line.text_score.is_nan() || text_line.text_score < angle_rollback_threshold)
|
||||
&& let Some(angle_rollback_record) = angle_rollback_records.get(&index)
|
||||
{
|
||||
*text_line = self.get_text_line(angle_rollback_record)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(text_lines)
|
||||
}
|
||||
|
||||
/// Batch recognition: sort crops by width, group into batches, pad to max width,
|
||||
/// run single ONNX inference per batch. Matches PaddleOCR/RapidOCR batching strategy.
|
||||
fn get_text_lines_batched(
|
||||
&self,
|
||||
part_imgs: &[image::RgbImage],
|
||||
batch_size: u32,
|
||||
) -> Result<Vec<TextLine>, OcrError> {
|
||||
let session = self.session.as_ref().ok_or(OcrError::SessionNotInitialized)?;
|
||||
let batch_size = (batch_size as usize).max(1);
|
||||
|
||||
// Compute target widths and sort indices by aspect ratio (width/height)
|
||||
let mut indexed_widths: Vec<(usize, u32)> = part_imgs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, img)| {
|
||||
let scale = CRNN_DST_HEIGHT as f32 / img.height().max(1) as f32;
|
||||
let dst_width = (img.width() as f32 * scale).ceil() as u32;
|
||||
(i, dst_width.max(1))
|
||||
})
|
||||
.collect();
|
||||
indexed_widths.sort_by_key(|&(_, w)| w);
|
||||
|
||||
let mut results: Vec<(usize, TextLine)> = Vec::with_capacity(part_imgs.len());
|
||||
|
||||
// Process in batches
|
||||
for chunk in indexed_widths.chunks(batch_size) {
|
||||
if chunk.len() == 1 {
|
||||
// Single image — use existing path (no padding overhead)
|
||||
let (orig_idx, _) = chunk[0];
|
||||
let text_line = self.get_text_line(&part_imgs[orig_idx])?;
|
||||
results.push((orig_idx, text_line));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find max width in this batch
|
||||
let max_width = chunk.iter().map(|&(_, w)| w).max().unwrap_or(1);
|
||||
|
||||
// Build batch tensor [N, 3, 48, max_width] with zero-padding
|
||||
let n = chunk.len();
|
||||
let mut batch_data = Array4::<f32>::zeros((n, 3, CRNN_DST_HEIGHT as usize, max_width as usize));
|
||||
|
||||
for (batch_idx, &(orig_idx, dst_width)) in chunk.iter().enumerate() {
|
||||
let img = &part_imgs[orig_idx];
|
||||
let resized =
|
||||
image::imageops::resize(img, dst_width, CRNN_DST_HEIGHT, image::imageops::FilterType::Triangle);
|
||||
|
||||
// Normalize and fill into batch tensor (zero-padded on right).
|
||||
// Use raw slice access instead of per-pixel get_pixel() to
|
||||
// eliminate millions of bounds checks in the hot loop.
|
||||
let cols = resized.width() as usize;
|
||||
let rows = resized.height() as usize;
|
||||
let raw = resized.as_raw();
|
||||
assert_eq!(raw.len(), rows * cols * 3, "unexpected image buffer size");
|
||||
let adjusted = [
|
||||
CRNN_MEAN_VALUES[0] * CRNN_NORM_VALUES[0],
|
||||
CRNN_MEAN_VALUES[1] * CRNN_NORM_VALUES[1],
|
||||
CRNN_MEAN_VALUES[2] * CRNN_NORM_VALUES[2],
|
||||
];
|
||||
for r in 0..rows {
|
||||
for c in 0..cols {
|
||||
let base = r * cols * 3 + c * 3;
|
||||
for ch in 0..3 {
|
||||
batch_data[[batch_idx, ch, r, c]] =
|
||||
raw[base + ch] as f32 * CRNN_NORM_VALUES[ch] - adjusted[ch];
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remaining columns stay zero (padding)
|
||||
}
|
||||
|
||||
let input_tensor = Tensor::from_array(batch_data)?;
|
||||
|
||||
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
|
||||
#[allow(unsafe_code)]
|
||||
let outputs = unsafe {
|
||||
let session_ptr = session as *const Session as *mut Session;
|
||||
(*session_ptr).run(inputs![self.input_names[0].as_str() => input_tensor])?
|
||||
};
|
||||
|
||||
let (_, output_value) = outputs.iter().next().ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"No output tensors found in batched CRNN session output",
|
||||
))
|
||||
})?;
|
||||
|
||||
let (shape, flat_data) = output_value.try_extract_tensor::<f32>()?;
|
||||
// Shape: [batch, timesteps, num_classes]
|
||||
let batch_dim = *shape.first().unwrap_or(&1) as usize;
|
||||
let timesteps = *shape.get(1).unwrap_or(&0) as usize;
|
||||
let num_classes = *shape.get(2).unwrap_or(&0) as usize;
|
||||
|
||||
for (batch_idx, item) in chunk.iter().enumerate().take(batch_dim.min(n)) {
|
||||
let offset = batch_idx * timesteps * num_classes;
|
||||
let slice = &flat_data[offset..offset + timesteps * num_classes];
|
||||
let text_line = Self::score_to_text_line(slice, timesteps, num_classes, &self.keys)?;
|
||||
results.push((item.0, text_line));
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder results back to original index order
|
||||
results.sort_by_key(|&(idx, _)| idx);
|
||||
Ok(results.into_iter().map(|(_, tl)| tl).collect())
|
||||
}
|
||||
|
||||
fn get_text_line(&self, img_src: &image::RgbImage) -> Result<TextLine, OcrError> {
|
||||
let Some(session) = &self.session else {
|
||||
return Err(OcrError::SessionNotInitialized);
|
||||
};
|
||||
|
||||
let scale = CRNN_DST_HEIGHT as f32 / img_src.height() as f32;
|
||||
let dst_width = (img_src.width() as f32 * scale).ceil() as u32;
|
||||
|
||||
let src_resize = image::imageops::resize(
|
||||
img_src,
|
||||
dst_width,
|
||||
CRNN_DST_HEIGHT,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
|
||||
let input_tensors = OcrUtils::substract_mean_normalize(&src_resize, &CRNN_MEAN_VALUES, &CRNN_NORM_VALUES);
|
||||
|
||||
let input_tensors = Tensor::from_array(input_tensors)?;
|
||||
|
||||
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
|
||||
#[allow(unsafe_code)]
|
||||
let outputs = unsafe {
|
||||
let session_ptr = session as *const Session as *mut Session;
|
||||
(*session_ptr).run(inputs![self.input_names[0].as_str() => input_tensors])?
|
||||
};
|
||||
|
||||
let (_, red_data) = outputs.iter().next().ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"No output tensors found in CRNN session output",
|
||||
))
|
||||
})?;
|
||||
|
||||
let (shape, src_data) = red_data.try_extract_tensor::<f32>()?;
|
||||
let dimensions = shape;
|
||||
let height = *dimensions.get(1).ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"CRNN output tensor missing height dimension (index 1)",
|
||||
))
|
||||
})? as usize;
|
||||
let width = *dimensions.get(2).ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"CRNN output tensor missing width dimension (index 2)",
|
||||
))
|
||||
})? as usize;
|
||||
let src_data: Vec<f32> = src_data.to_vec();
|
||||
|
||||
Self::score_to_text_line(&src_data, height, width, &self.keys)
|
||||
}
|
||||
|
||||
fn score_to_text_line(
|
||||
output_data: &[f32],
|
||||
height: usize,
|
||||
width: usize,
|
||||
keys: &[String],
|
||||
) -> Result<TextLine, OcrError> {
|
||||
let mut text_line = TextLine::default();
|
||||
let mut last_index = 0;
|
||||
|
||||
let mut text_score_sum = 0.0;
|
||||
let mut text_score_count = 0;
|
||||
for i in 0..height {
|
||||
let start = i * width;
|
||||
let stop = (i + 1) * width;
|
||||
let slice = &output_data[start..stop.min(output_data.len())];
|
||||
|
||||
let (max_index, max_value) =
|
||||
slice
|
||||
.iter()
|
||||
.enumerate()
|
||||
.fold((0, f32::MIN), |(max_idx, max_val), (idx, &val)| {
|
||||
if val > max_val { (idx, val) } else { (max_idx, max_val) }
|
||||
});
|
||||
|
||||
if max_index > 0 && max_index < keys.len() && !(i > 0 && max_index == last_index) {
|
||||
text_line.text.push_str(&keys[max_index]);
|
||||
text_score_sum += max_value;
|
||||
text_score_count += 1;
|
||||
}
|
||||
last_index = max_index;
|
||||
}
|
||||
|
||||
// Avoid division by zero: handle case where no characters were found
|
||||
text_line.text_score = if text_score_count > 0 {
|
||||
text_score_sum / text_score_count as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
Ok(text_line)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_score_to_text_line_skips_blank_index() {
|
||||
// keys[0] = "#" (CTC blank), keys[1] = "a", keys[2] = "b"
|
||||
let keys = vec!["#".to_string(), "a".to_string(), "b".to_string()];
|
||||
// 3 timesteps, 3 classes each. Simulate: blank, "a", "b"
|
||||
let output = vec![
|
||||
1.0, 0.0, 0.0, // timestep 0: max at index 0 (blank) -> skip
|
||||
0.0, 0.9, 0.1, // timestep 1: max at index 1 ("a")
|
||||
0.0, 0.1, 0.8, // timestep 2: max at index 2 ("b")
|
||||
];
|
||||
let result = CrnnNet::score_to_text_line(&output, 3, 3, &keys).unwrap();
|
||||
assert_eq!(result.text, "ab");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_score_to_text_line_deduplicates_consecutive() {
|
||||
let keys = vec!["#".to_string(), "h".to_string(), "i".to_string()];
|
||||
// 4 timesteps: "h", "h", "i", "i" -> should deduplicate to "hi"
|
||||
let output = vec![
|
||||
0.0, 0.9, 0.0, // "h"
|
||||
0.0, 0.8, 0.0, // "h" again (same index, skip)
|
||||
0.0, 0.0, 0.9, // "i"
|
||||
0.0, 0.0, 0.8, // "i" again (same index, skip)
|
||||
];
|
||||
let result = CrnnNet::score_to_text_line(&output, 4, 3, &keys).unwrap();
|
||||
assert_eq!(result.text, "hi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_keys_from_file_preserves_dict_layout() {
|
||||
let dir = std::env::temp_dir().join("kreuzberg_test_dict");
|
||||
std::fs::create_dir_all(&dir).unwrap();
|
||||
let dict_path = dir.join("test_dict.txt");
|
||||
// PP-OCRv5 dict files already include "#" (blank) at start and " " at end.
|
||||
std::fs::write(&dict_path, "#\na\nb\nc\n ").unwrap();
|
||||
|
||||
let mut net = CrnnNet::new();
|
||||
net.read_keys_from_file(dict_path.to_str().unwrap()).unwrap();
|
||||
|
||||
// Dict is loaded as-is: ["#", "a", "b", "c", " "]
|
||||
assert_eq!(net.keys[0], "#");
|
||||
assert_eq!(net.keys[1], "a");
|
||||
assert_eq!(net.keys[2], "b");
|
||||
assert_eq!(net.keys[3], "c");
|
||||
assert_eq!(net.keys[net.keys.len() - 1], " ");
|
||||
|
||||
std::fs::remove_dir_all(&dir).ok();
|
||||
}
|
||||
}
|
||||
421
crates/kreuzberg-paddle-ocr/src/db_net.rs
Normal file
421
crates/kreuzberg-paddle-ocr/src/db_net.rs
Normal file
@@ -0,0 +1,421 @@
|
||||
use crate::{
|
||||
base_net::BaseNet,
|
||||
constants::{IMAGENET_MEAN_VALUES, IMAGENET_NORM_VALUES},
|
||||
ocr_error::OcrError,
|
||||
ocr_result::{self, TextBox},
|
||||
ocr_utils::OcrUtils,
|
||||
scale_param::ScaleParam,
|
||||
};
|
||||
use geo_clipper::{Clipper, EndType, JoinType};
|
||||
use geo_types::{Coord, LineString, Polygon};
|
||||
use ort::{inputs, session::SessionOutputs};
|
||||
use ort::{session::Session, value::Tensor};
|
||||
use std::cmp::Ordering;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DbNet {
|
||||
session: Option<Session>,
|
||||
input_names: Vec<String>,
|
||||
}
|
||||
|
||||
impl BaseNet for DbNet {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
session: None,
|
||||
input_names: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_input_names(&mut self, input_names: Vec<String>) {
|
||||
self.input_names = input_names;
|
||||
}
|
||||
|
||||
fn set_session(&mut self, session: Option<Session>) {
|
||||
self.session = session;
|
||||
}
|
||||
}
|
||||
|
||||
impl DbNet {
|
||||
pub fn get_text_boxes(
|
||||
&self,
|
||||
img_src: &image::RgbImage,
|
||||
scale: &ScaleParam,
|
||||
box_score_thresh: f32,
|
||||
box_thresh: f32,
|
||||
un_clip_ratio: f32,
|
||||
thresh: f32,
|
||||
) -> Result<Vec<TextBox>, OcrError> {
|
||||
let Some(session) = &self.session else {
|
||||
return Err(OcrError::SessionNotInitialized);
|
||||
};
|
||||
|
||||
let src_resize = image::imageops::resize(
|
||||
img_src,
|
||||
scale.dst_width,
|
||||
scale.dst_height,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
|
||||
let input_tensors =
|
||||
OcrUtils::substract_mean_normalize(&src_resize, &IMAGENET_MEAN_VALUES, &IMAGENET_NORM_VALUES);
|
||||
|
||||
let tensor = Tensor::from_array(input_tensors)?;
|
||||
|
||||
// SAFETY: ONNX Runtime's C API (OrtRun) is thread-safe for concurrent inference
|
||||
// on the same session. The ort crate's `&mut self` requirement is overly
|
||||
// conservative. This matches the pattern used in kreuzberg's embedding engine.
|
||||
#[allow(unsafe_code)]
|
||||
let outputs = unsafe {
|
||||
let session_ptr = session as *const Session as *mut Session;
|
||||
(*session_ptr).run(inputs![self.input_names[0].as_str() => tensor])?
|
||||
};
|
||||
|
||||
let text_boxes = Self::get_text_boxes_core(
|
||||
&outputs,
|
||||
src_resize.height(),
|
||||
src_resize.width(),
|
||||
&ScaleParam::new(
|
||||
scale.src_width,
|
||||
scale.src_height,
|
||||
scale.dst_width,
|
||||
scale.dst_height,
|
||||
scale.scale_width,
|
||||
scale.scale_height,
|
||||
),
|
||||
box_score_thresh,
|
||||
box_thresh,
|
||||
un_clip_ratio,
|
||||
thresh,
|
||||
)?;
|
||||
|
||||
Ok(text_boxes)
|
||||
}
|
||||
|
||||
fn get_text_boxes_core(
|
||||
output_tensor: &SessionOutputs,
|
||||
rows: u32,
|
||||
cols: u32,
|
||||
s: &ScaleParam,
|
||||
box_score_thresh: f32,
|
||||
_box_thresh: f32,
|
||||
un_clip_ratio: f32,
|
||||
thresh: f32,
|
||||
) -> Result<Vec<TextBox>, OcrError> {
|
||||
let max_side_thresh = 3.0;
|
||||
|
||||
let (_, red_data) = output_tensor.iter().next().ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"No output tensors found in session output",
|
||||
))
|
||||
})?;
|
||||
|
||||
let pred_data: Vec<f32> = red_data.try_extract_tensor::<f32>()?.1.to_vec();
|
||||
|
||||
let cbuf_data: Vec<u8> = pred_data.iter().map(|pixel| (pixel * 255.0) as u8).collect();
|
||||
|
||||
let pred_img: image::ImageBuffer<image::Luma<f32>, Vec<f32>> =
|
||||
image::ImageBuffer::from_vec(cols, rows, pred_data).ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"Failed to create image buffer from predictions: {} x {} dimensions may be invalid",
|
||||
cols, rows
|
||||
),
|
||||
))
|
||||
})?;
|
||||
|
||||
let cbuf_img = image::GrayImage::from_vec(cols, rows, cbuf_data).ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"Failed to create grayscale image buffer: {} x {} dimensions may be invalid",
|
||||
cols, rows
|
||||
),
|
||||
))
|
||||
})?;
|
||||
|
||||
let threshold_img = imageproc::contrast::threshold(
|
||||
&cbuf_img,
|
||||
(thresh * 255.0) as u8,
|
||||
imageproc::contrast::ThresholdType::Binary,
|
||||
);
|
||||
|
||||
// RapidOCR and PaddleOCR reference do NOT apply dilation before contour extraction.
|
||||
// Dilation merges adjacent text regions, causing word concatenation.
|
||||
let img_contours: Vec<imageproc::contours::Contour<i32>> = imageproc::contours::find_contours(&threshold_img);
|
||||
|
||||
// Pre-allocate based on contour count to avoid repeated reallocations.
|
||||
let mut rs_boxes = Vec::with_capacity(img_contours.len());
|
||||
|
||||
for contour in img_contours {
|
||||
if contour.points.len() <= 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut max_side = 0.0;
|
||||
let min_box = Self::get_mini_box(&contour.points, &mut max_side)?;
|
||||
if max_side < max_side_thresh {
|
||||
continue;
|
||||
}
|
||||
|
||||
let score = Self::get_score(&contour, &pred_img)?;
|
||||
if score < box_score_thresh {
|
||||
continue;
|
||||
}
|
||||
|
||||
let clip_box = Self::unclip(&min_box, un_clip_ratio)?;
|
||||
if clip_box.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut clip_contour = Vec::new();
|
||||
for point in &clip_box {
|
||||
clip_contour.push(*point);
|
||||
}
|
||||
|
||||
let mut max_side_clip = 0.0;
|
||||
let clip_min_box = Self::get_mini_box(&clip_contour, &mut max_side_clip)?;
|
||||
if max_side_clip < max_side_thresh + 2.0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut final_points = Vec::new();
|
||||
for item in clip_min_box {
|
||||
let x = (item.x / s.scale_width) as u32;
|
||||
let ptx = x.min(s.src_width);
|
||||
|
||||
let y = (item.y / s.scale_height) as u32;
|
||||
let pty = y.min(s.src_height);
|
||||
|
||||
final_points.push(ocr_result::Point { x: ptx, y: pty });
|
||||
}
|
||||
|
||||
let text_box = TextBox {
|
||||
score,
|
||||
points: final_points,
|
||||
};
|
||||
|
||||
rs_boxes.push(text_box);
|
||||
}
|
||||
|
||||
Ok(rs_boxes)
|
||||
}
|
||||
|
||||
fn get_mini_box(
|
||||
contour_points: &[imageproc::point::Point<i32>],
|
||||
min_edge_size: &mut f32,
|
||||
) -> Result<Vec<imageproc::point::Point<f32>>, OcrError> {
|
||||
let rect = imageproc::geometry::min_area_rect(contour_points);
|
||||
|
||||
let mut rect_points: Vec<imageproc::point::Point<f32>> = rect
|
||||
.iter()
|
||||
.map(|p| imageproc::point::Point::new(p.x as f32, p.y as f32))
|
||||
.collect();
|
||||
|
||||
// Direct multiplication instead of .powi(2) — avoids function call overhead.
|
||||
let dx_w = rect_points[0].x - rect_points[1].x;
|
||||
let dy_w = rect_points[0].y - rect_points[1].y;
|
||||
let width = (dx_w * dx_w + dy_w * dy_w).sqrt();
|
||||
let dx_h = rect_points[1].x - rect_points[2].x;
|
||||
let dy_h = rect_points[1].y - rect_points[2].y;
|
||||
let height = (dx_h * dx_h + dy_h * dy_h).sqrt();
|
||||
|
||||
*min_edge_size = width.min(height);
|
||||
|
||||
rect_points.sort_by(|a, b| {
|
||||
if a.x > b.x {
|
||||
return Ordering::Greater;
|
||||
}
|
||||
if a.x == b.x {
|
||||
return Ordering::Equal;
|
||||
}
|
||||
Ordering::Less
|
||||
});
|
||||
|
||||
let mut box_points = Vec::new();
|
||||
let index_1;
|
||||
let index_4;
|
||||
if rect_points[1].y > rect_points[0].y {
|
||||
index_1 = 0;
|
||||
index_4 = 1;
|
||||
} else {
|
||||
index_1 = 1;
|
||||
index_4 = 0;
|
||||
}
|
||||
|
||||
let index_2;
|
||||
let index_3;
|
||||
if rect_points[3].y > rect_points[2].y {
|
||||
index_2 = 2;
|
||||
index_3 = 3;
|
||||
} else {
|
||||
index_2 = 3;
|
||||
index_3 = 2;
|
||||
}
|
||||
|
||||
box_points.push(rect_points[index_1]);
|
||||
box_points.push(rect_points[index_2]);
|
||||
box_points.push(rect_points[index_3]);
|
||||
box_points.push(rect_points[index_4]);
|
||||
|
||||
Ok(box_points)
|
||||
}
|
||||
|
||||
fn get_score(
|
||||
contour: &imageproc::contours::Contour<i32>,
|
||||
f_map_mat: &image::ImageBuffer<image::Luma<f32>, Vec<f32>>,
|
||||
) -> Result<f32, OcrError> {
|
||||
// Initialize boundary values
|
||||
let mut xmin = i32::MAX;
|
||||
let mut xmax = i32::MIN;
|
||||
let mut ymin = i32::MAX;
|
||||
let mut ymax = i32::MIN;
|
||||
|
||||
// Find contour bounding box
|
||||
for point in contour.points.iter() {
|
||||
let x = point.x;
|
||||
let y = point.y;
|
||||
|
||||
if x < xmin {
|
||||
xmin = x;
|
||||
}
|
||||
if x > xmax {
|
||||
xmax = x;
|
||||
}
|
||||
if y < ymin {
|
||||
ymin = y;
|
||||
}
|
||||
if y > ymax {
|
||||
ymax = y;
|
||||
}
|
||||
}
|
||||
|
||||
let width = f_map_mat.width() as i32;
|
||||
let height = f_map_mat.height() as i32;
|
||||
|
||||
xmin = xmin.max(0).min(width - 1);
|
||||
xmax = xmax.max(0).min(width - 1);
|
||||
ymin = ymin.max(0).min(height - 1);
|
||||
ymax = ymax.max(0).min(height - 1);
|
||||
|
||||
let roi_width = xmax - xmin + 1;
|
||||
let roi_height = ymax - ymin + 1;
|
||||
|
||||
if roi_width <= 0 || roi_height <= 0 {
|
||||
return Ok(0.0);
|
||||
}
|
||||
|
||||
let mut mask = image::GrayImage::new(roi_width as u32, roi_height as u32);
|
||||
|
||||
let mut pts = Vec::<imageproc::point::Point<i32>>::new();
|
||||
for point in contour.points.iter() {
|
||||
pts.push(imageproc::point::Point::new(point.x - xmin, point.y - ymin));
|
||||
}
|
||||
|
||||
imageproc::drawing::draw_polygon_mut(&mut mask, pts.as_slice(), image::Luma([255]));
|
||||
|
||||
let cropped_img =
|
||||
image::imageops::crop_imm(f_map_mat, xmin as u32, ymin as u32, roi_width as u32, roi_height as u32)
|
||||
.to_image();
|
||||
|
||||
let mean = OcrUtils::calculate_mean_with_mask(&cropped_img, &mask);
|
||||
|
||||
Ok(mean)
|
||||
}
|
||||
|
||||
fn unclip(
|
||||
box_points: &[imageproc::point::Point<f32>],
|
||||
unclip_ratio: f32,
|
||||
) -> Result<Vec<imageproc::point::Point<i32>>, OcrError> {
|
||||
// Direct multiplication instead of .powi(2) — avoids function call overhead.
|
||||
let dx_w = box_points[0].x - box_points[1].x;
|
||||
let dy_w = box_points[0].y - box_points[1].y;
|
||||
let clip_rect_width = (dx_w * dx_w + dy_w * dy_w).sqrt();
|
||||
let dx_h = box_points[1].x - box_points[2].x;
|
||||
let dy_h = box_points[1].y - box_points[2].y;
|
||||
let clip_rect_height = (dx_h * dx_h + dy_h * dy_h).sqrt();
|
||||
|
||||
if clip_rect_height < 1.001 && clip_rect_width < 1.001 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut the_cliper_pts = Vec::new();
|
||||
for pt in box_points {
|
||||
let a1 = Coord {
|
||||
x: pt.x as f64,
|
||||
y: pt.y as f64,
|
||||
};
|
||||
the_cliper_pts.push(a1);
|
||||
}
|
||||
|
||||
let area = Self::signed_polygon_area(box_points).abs();
|
||||
let length = Self::length_of_points(box_points);
|
||||
let distance = area * unclip_ratio / length as f32;
|
||||
|
||||
let co = Polygon::new(LineString::new(the_cliper_pts), vec![]);
|
||||
let solution = co
|
||||
.offset(distance as f64, JoinType::Round(2.0), EndType::ClosedPolygon, 1.0)
|
||||
.0;
|
||||
|
||||
if solution.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let first_polygon = solution.first().ok_or_else(|| {
|
||||
OcrError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Polygon solution list was empty after offset operation",
|
||||
))
|
||||
})?;
|
||||
|
||||
let ret_pts: Vec<_> = first_polygon
|
||||
.exterior()
|
||||
.points()
|
||||
.map(|ip| imageproc::point::Point::new(ip.x() as i32, ip.y() as i32))
|
||||
.collect();
|
||||
|
||||
Ok(ret_pts)
|
||||
}
|
||||
|
||||
fn signed_polygon_area(points: &[imageproc::point::Point<f32>]) -> f32 {
|
||||
let num_points = points.len();
|
||||
let mut pts = Vec::with_capacity(num_points + 1);
|
||||
pts.extend_from_slice(points);
|
||||
pts.push(points[0]);
|
||||
|
||||
let mut area = 0.0;
|
||||
for i in 0..num_points {
|
||||
area += (pts[i + 1].x - pts[i].x) * (pts[i + 1].y + pts[i].y) / 2.0;
|
||||
}
|
||||
|
||||
area
|
||||
}
|
||||
|
||||
fn length_of_points(box_points: &[imageproc::point::Point<f32>]) -> f64 {
|
||||
if box_points.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut length = 0.0;
|
||||
let mut x0 = box_points[0].x as f64;
|
||||
let mut y0 = box_points[0].y as f64;
|
||||
|
||||
for pt in &box_points[1..] {
|
||||
let x1 = pt.x as f64;
|
||||
let y1 = pt.y as f64;
|
||||
let dx = x1 - x0;
|
||||
let dy = y1 - y0;
|
||||
length += (dx * dx + dy * dy).sqrt();
|
||||
x0 = x1;
|
||||
y0 = y1;
|
||||
}
|
||||
|
||||
// Closing segment back to first point
|
||||
let dx = box_points[0].x as f64 - x0;
|
||||
let dy = box_points[0].y as f64 - y0;
|
||||
length += (dx * dx + dy * dy).sqrt();
|
||||
|
||||
length
|
||||
}
|
||||
}
|
||||
32
crates/kreuzberg-paddle-ocr/src/lib.rs
Normal file
32
crates/kreuzberg-paddle-ocr/src/lib.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
//! # kreuzberg-paddle-ocr
|
||||
//!
|
||||
//! PaddleOCR via ONNX Runtime for Kreuzberg - high-performance text detection and recognition.
|
||||
//!
|
||||
//! This crate is vendored from [paddle-ocr-rs](https://github.com/mg-chao/paddle-ocr-rs)
|
||||
//! by mg-chao, with modifications for Kreuzberg integration.
|
||||
//!
|
||||
//! ## ONNX Runtime Requirement
|
||||
//!
|
||||
//! Requires **ONNX Runtime 1.24+** at runtime.
|
||||
//!
|
||||
//! ## Original License
|
||||
//!
|
||||
//! The original paddle-ocr-rs is licensed under Apache-2.0.
|
||||
//! This vendored version is relicensed to MIT with the original author's copyright retained.
|
||||
|
||||
#![allow(clippy::too_many_arguments)]
|
||||
|
||||
pub mod angle_net;
|
||||
pub mod base_net;
|
||||
pub(crate) mod constants;
|
||||
pub mod crnn_net;
|
||||
pub mod db_net;
|
||||
pub mod ocr_error;
|
||||
pub mod ocr_lite;
|
||||
pub mod ocr_result;
|
||||
pub mod ocr_utils;
|
||||
pub mod scale_param;
|
||||
|
||||
pub use ocr_error::OcrError;
|
||||
pub use ocr_lite::OcrLite;
|
||||
pub use ocr_result::{Angle, OcrResult, Point, TextBlock, TextBox, TextLine};
|
||||
13
crates/kreuzberg-paddle-ocr/src/ocr_error.rs
Normal file
13
crates/kreuzberg-paddle-ocr/src/ocr_error.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum OcrError {
|
||||
#[error("Ort error: {0}")]
|
||||
Ort(#[from] ort::Error),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Image error: {0}")]
|
||||
ImageError(#[from] image::ImageError),
|
||||
#[error("Session not initialized")]
|
||||
SessionNotInitialized,
|
||||
}
|
||||
447
crates/kreuzberg-paddle-ocr/src/ocr_lite.rs
Normal file
447
crates/kreuzberg-paddle-ocr/src/ocr_lite.rs
Normal file
@@ -0,0 +1,447 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use image::ImageBuffer;
|
||||
use ort::session::builder::SessionBuilder;
|
||||
|
||||
use crate::{
|
||||
angle_net::AngleNet,
|
||||
base_net::BaseNet,
|
||||
crnn_net::CrnnNet,
|
||||
db_net::DbNet,
|
||||
ocr_error::OcrError,
|
||||
ocr_result::{OcrResult, Point, TextBlock, TextBox},
|
||||
ocr_utils::OcrUtils,
|
||||
scale_param::ScaleParam,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OcrLite {
|
||||
db_net: DbNet,
|
||||
angle_net: AngleNet,
|
||||
crnn_net: CrnnNet,
|
||||
}
|
||||
|
||||
// SAFETY: OcrLite inference methods (&self) use unsafe pointer casts to call
|
||||
// ort Session::run, which is thread-safe at the ONNX Runtime C API level.
|
||||
// After initialization (&mut self), no mutable state is accessed during inference.
|
||||
unsafe impl Send for OcrLite {}
|
||||
unsafe impl Sync for OcrLite {}
|
||||
|
||||
impl Default for OcrLite {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl OcrLite {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
db_net: DbNet::new(),
|
||||
angle_net: AngleNet::new(),
|
||||
crnn_net: CrnnNet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_models(
|
||||
&mut self,
|
||||
det_path: &str,
|
||||
cls_path: &str,
|
||||
rec_path: &str,
|
||||
num_thread: usize,
|
||||
) -> Result<(), OcrError> {
|
||||
self.db_net.init_model(det_path, num_thread, None)?;
|
||||
self.angle_net.init_model(cls_path, num_thread, None)?;
|
||||
self.crnn_net.init_model(rec_path, num_thread, None)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_models_with_dict(
|
||||
&mut self,
|
||||
det_path: &str,
|
||||
cls_path: &str,
|
||||
rec_path: &str,
|
||||
dict_path: &str,
|
||||
num_thread: usize,
|
||||
) -> Result<(), OcrError> {
|
||||
self.db_net.init_model(det_path, num_thread, None)?;
|
||||
self.angle_net.init_model(cls_path, num_thread, None)?;
|
||||
self.crnn_net
|
||||
.init_model_dict_file(rec_path, num_thread, None, dict_path)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_models_custom(
|
||||
&mut self,
|
||||
det_path: &str,
|
||||
cls_path: &str,
|
||||
rec_path: &str,
|
||||
builder_fn: fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>,
|
||||
) -> Result<(), OcrError> {
|
||||
self.db_net.init_model(det_path, 0, Some(builder_fn))?;
|
||||
self.angle_net.init_model(cls_path, 0, Some(builder_fn))?;
|
||||
self.crnn_net.init_model(rec_path, 0, Some(builder_fn))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize models with dictionary file and custom session builder.
|
||||
///
|
||||
/// Combines `init_models_with_dict` and `init_models_custom`: loads the
|
||||
/// dictionary for the recognition model while applying a custom ORT
|
||||
/// session builder (e.g. for GPU execution providers).
|
||||
pub fn init_models_with_dict_custom(
|
||||
&mut self,
|
||||
det_path: &str,
|
||||
cls_path: &str,
|
||||
rec_path: &str,
|
||||
dict_path: &str,
|
||||
num_thread: usize,
|
||||
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
|
||||
) -> Result<(), OcrError> {
|
||||
self.db_net.init_model(det_path, num_thread, builder_fn)?;
|
||||
self.angle_net.init_model(cls_path, num_thread, builder_fn)?;
|
||||
self.crnn_net
|
||||
.init_model_dict_file(rec_path, num_thread, builder_fn, dict_path)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_models_from_memory(
|
||||
&mut self,
|
||||
det_bytes: &[u8],
|
||||
cls_bytes: &[u8],
|
||||
rec_bytes: &[u8],
|
||||
num_thread: usize,
|
||||
) -> Result<(), OcrError> {
|
||||
self.db_net.init_model_from_memory(det_bytes, num_thread, None)?;
|
||||
self.angle_net.init_model_from_memory(cls_bytes, num_thread, None)?;
|
||||
self.crnn_net.init_model_from_memory(rec_bytes, num_thread, None)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_models_from_memory_custom(
|
||||
&mut self,
|
||||
det_bytes: &[u8],
|
||||
cls_bytes: &[u8],
|
||||
rec_bytes: &[u8],
|
||||
builder_fn: fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>,
|
||||
) -> Result<(), OcrError> {
|
||||
self.db_net.init_model_from_memory(det_bytes, 0, Some(builder_fn))?;
|
||||
self.angle_net.init_model_from_memory(cls_bytes, 0, Some(builder_fn))?;
|
||||
self.crnn_net.init_model_from_memory(rec_bytes, 0, Some(builder_fn))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn detect_base(
|
||||
&self,
|
||||
img_src: &image::RgbImage,
|
||||
padding: u32,
|
||||
max_side_len: u32,
|
||||
box_score_thresh: f32,
|
||||
box_thresh: f32,
|
||||
un_clip_ratio: f32,
|
||||
do_angle: bool,
|
||||
most_angle: bool,
|
||||
angle_rollback: bool,
|
||||
angle_rollback_threshold: f32,
|
||||
cls_thresh: f32,
|
||||
thresh: f32,
|
||||
) -> Result<OcrResult, OcrError> {
|
||||
let origin_max_side = img_src.width().max(img_src.height());
|
||||
let mut resize;
|
||||
if max_side_len == 0 || max_side_len > origin_max_side {
|
||||
resize = origin_max_side;
|
||||
} else {
|
||||
resize = max_side_len;
|
||||
}
|
||||
resize += 2 * padding;
|
||||
|
||||
// Cow avoids cloning the image when padding=0 (the common case).
|
||||
let padding_src = OcrUtils::make_padding(img_src, padding)?;
|
||||
|
||||
let scale = ScaleParam::get_scale_param(&padding_src, resize);
|
||||
|
||||
self.detect_once(
|
||||
&padding_src,
|
||||
&scale,
|
||||
padding,
|
||||
box_score_thresh,
|
||||
box_thresh,
|
||||
un_clip_ratio,
|
||||
do_angle,
|
||||
most_angle,
|
||||
angle_rollback,
|
||||
angle_rollback_threshold,
|
||||
cls_thresh,
|
||||
thresh,
|
||||
)
|
||||
}
|
||||
|
||||
/// Detect text in image
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `img_src` - Input image
|
||||
/// - `padding` - Padding width added during image transformation (improves detection)
|
||||
/// - `max_side_len` - Maximum side length after transformation (larger images will be scaled down)
|
||||
/// - `box_score_thresh` - Score threshold for text region detection
|
||||
/// - `box_thresh` - Box threshold
|
||||
/// - `un_clip_ratio` - Unclip ratio
|
||||
/// - `do_angle` - Whether to perform angle detection
|
||||
/// - `most_angle` - Use most common angle for all text regions
|
||||
const DEFAULT_CLS_THRESH: f32 = 0.9;
|
||||
const DEFAULT_THRESH: f32 = 0.3;
|
||||
const DEFAULT_REC_BATCH_SIZE: u32 = 6;
|
||||
|
||||
pub fn detect(
|
||||
&self,
|
||||
img_src: &image::RgbImage,
|
||||
padding: u32,
|
||||
max_side_len: u32,
|
||||
box_score_thresh: f32,
|
||||
box_thresh: f32,
|
||||
un_clip_ratio: f32,
|
||||
do_angle: bool,
|
||||
most_angle: bool,
|
||||
) -> Result<OcrResult, OcrError> {
|
||||
self.detect_base(
|
||||
img_src,
|
||||
padding,
|
||||
max_side_len,
|
||||
box_score_thresh,
|
||||
box_thresh,
|
||||
un_clip_ratio,
|
||||
do_angle,
|
||||
most_angle,
|
||||
false,
|
||||
0.0,
|
||||
Self::DEFAULT_CLS_THRESH,
|
||||
Self::DEFAULT_THRESH,
|
||||
)
|
||||
}
|
||||
|
||||
/// Detect text with angle rollback support
|
||||
///
|
||||
/// When `do_angle` is true, if the image was angle-corrected but recognition
|
||||
/// result is poor, the angle correction will be reverted.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `img_src` - Input image
|
||||
/// - `padding` - Padding width added during image transformation
|
||||
/// - `max_side_len` - Maximum side length after transformation
|
||||
/// - `box_score_thresh` - Score threshold for text region detection
|
||||
/// - `box_thresh` - Box threshold
|
||||
/// - `un_clip_ratio` - Unclip ratio
|
||||
/// - `do_angle` - Whether to perform angle detection
|
||||
/// - `most_angle` - Use most common angle
|
||||
/// - `angle_rollback_threshold` - If text score is below this value (or NaN), angle correction is reverted
|
||||
pub fn detect_angle_rollback(
|
||||
&self,
|
||||
img_src: &image::RgbImage,
|
||||
padding: u32,
|
||||
max_side_len: u32,
|
||||
box_score_thresh: f32,
|
||||
box_thresh: f32,
|
||||
un_clip_ratio: f32,
|
||||
do_angle: bool,
|
||||
most_angle: bool,
|
||||
angle_rollback_threshold: f32,
|
||||
) -> Result<OcrResult, OcrError> {
|
||||
self.detect_base(
|
||||
img_src,
|
||||
padding,
|
||||
max_side_len,
|
||||
box_score_thresh,
|
||||
box_thresh,
|
||||
un_clip_ratio,
|
||||
do_angle,
|
||||
most_angle,
|
||||
true,
|
||||
angle_rollback_threshold,
|
||||
Self::DEFAULT_CLS_THRESH,
|
||||
Self::DEFAULT_THRESH,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn detect_from_path(
|
||||
&self,
|
||||
img_path: &str,
|
||||
padding: u32,
|
||||
max_side_len: u32,
|
||||
box_score_thresh: f32,
|
||||
box_thresh: f32,
|
||||
un_clip_ratio: f32,
|
||||
do_angle: bool,
|
||||
most_angle: bool,
|
||||
) -> Result<OcrResult, OcrError> {
|
||||
let img_src = image::open(img_path)?.to_rgb8();
|
||||
|
||||
self.detect(
|
||||
&img_src,
|
||||
padding,
|
||||
max_side_len,
|
||||
box_score_thresh,
|
||||
box_thresh,
|
||||
un_clip_ratio,
|
||||
do_angle,
|
||||
most_angle,
|
||||
)
|
||||
}
|
||||
|
||||
/// Sort text boxes in reading order: top-to-bottom, left-to-right.
|
||||
///
|
||||
/// Sorts by top-left Y coordinate first, then by top-left X coordinate within
|
||||
/// the same Y. Matches PaddleOCR Python's `sorted_boxes` primary ordering.
|
||||
fn sort_text_boxes(text_boxes: &mut [TextBox]) {
|
||||
text_boxes.sort_by(|a, b| {
|
||||
let ay = a.points.first().map_or(0, |p| p.y);
|
||||
let ax = a.points.first().map_or(0, |p| p.x);
|
||||
let by = b.points.first().map_or(0, |p| p.y);
|
||||
let bx = b.points.first().map_or(0, |p| p.x);
|
||||
(ay, ax).cmp(&(by, bx))
|
||||
});
|
||||
}
|
||||
|
||||
fn detect_once(
|
||||
&self,
|
||||
img_src: &image::RgbImage,
|
||||
scale: &ScaleParam,
|
||||
padding: u32,
|
||||
box_score_thresh: f32,
|
||||
box_thresh: f32,
|
||||
un_clip_ratio: f32,
|
||||
do_angle: bool,
|
||||
most_angle: bool,
|
||||
angle_rollback: bool,
|
||||
angle_rollback_threshold: f32,
|
||||
cls_thresh: f32,
|
||||
thresh: f32,
|
||||
) -> Result<OcrResult, OcrError> {
|
||||
let mut text_boxes =
|
||||
self.db_net
|
||||
.get_text_boxes(img_src, scale, box_score_thresh, box_thresh, un_clip_ratio, thresh)?;
|
||||
|
||||
// Sort boxes in reading order (top-to-bottom, left-to-right)
|
||||
Self::sort_text_boxes(&mut text_boxes);
|
||||
|
||||
let part_images = OcrUtils::get_part_images(img_src, &text_boxes);
|
||||
|
||||
let angles = self
|
||||
.angle_net
|
||||
.get_angles(&part_images, do_angle, most_angle, cls_thresh)?;
|
||||
|
||||
let mut rotated_images: Vec<image::RgbImage> = Vec::with_capacity(part_images.len());
|
||||
|
||||
// Angle correction rollback
|
||||
let mut angle_rollback_records = HashMap::<usize, ImageBuffer<image::Rgb<u8>, Vec<u8>>>::new();
|
||||
|
||||
for (index, (angle, mut part_image)) in angles.iter().zip(part_images).enumerate() {
|
||||
if angle.index == 1 {
|
||||
if angle_rollback {
|
||||
// Keep original copy
|
||||
angle_rollback_records.insert(index, part_image.clone());
|
||||
}
|
||||
|
||||
OcrUtils::mat_rotate_clock_wise_180(&mut part_image);
|
||||
}
|
||||
rotated_images.push(part_image);
|
||||
}
|
||||
|
||||
let text_lines = self.crnn_net.get_text_lines(
|
||||
&rotated_images,
|
||||
&angle_rollback_records,
|
||||
angle_rollback_threshold,
|
||||
Self::DEFAULT_REC_BATCH_SIZE,
|
||||
)?;
|
||||
|
||||
let mut text_blocks = Vec::with_capacity(text_lines.len());
|
||||
for (i, text_line) in text_lines.into_iter().enumerate() {
|
||||
text_blocks.push(TextBlock {
|
||||
box_points: text_boxes[i]
|
||||
.points
|
||||
.iter()
|
||||
.map(|p| Point {
|
||||
x: ((p.x as f32) - padding as f32) as u32,
|
||||
y: ((p.y as f32) - padding as f32) as u32,
|
||||
})
|
||||
.collect(),
|
||||
box_score: text_boxes[i].score,
|
||||
angle_index: angles[i].index,
|
||||
angle_score: angles[i].score,
|
||||
text: text_line.text,
|
||||
text_score: text_line.text_score,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(OcrResult { text_blocks })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ocr_result::TextBox;
|
||||
|
||||
fn make_box(x: u32, y: u32) -> TextBox {
|
||||
TextBox {
|
||||
points: vec![
|
||||
Point { x, y },
|
||||
Point { x: x + 100, y },
|
||||
Point { x: x + 100, y: y + 20 },
|
||||
Point { x, y: y + 20 },
|
||||
],
|
||||
score: 0.9,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_text_boxes_top_to_bottom() {
|
||||
let mut boxes = vec![make_box(10, 100), make_box(10, 50), make_box(10, 10)];
|
||||
OcrLite::sort_text_boxes(&mut boxes);
|
||||
assert_eq!(boxes[0].points[0].y, 10);
|
||||
assert_eq!(boxes[1].points[0].y, 50);
|
||||
assert_eq!(boxes[2].points[0].y, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_text_boxes_same_line_left_to_right() {
|
||||
// Boxes with the same Y are sorted left-to-right by X
|
||||
let mut boxes = vec![make_box(200, 10), make_box(100, 10), make_box(50, 10)];
|
||||
OcrLite::sort_text_boxes(&mut boxes);
|
||||
assert_eq!(boxes[0].points[0].x, 50);
|
||||
assert_eq!(boxes[1].points[0].x, 100);
|
||||
assert_eq!(boxes[2].points[0].x, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_text_boxes_multi_line() {
|
||||
// Boxes sorted strictly by (y, x): y=50/x=50, y=50/x=300, y=100/x=100, y=100/x=200
|
||||
let mut boxes = vec![
|
||||
make_box(300, 50), // line 1, right
|
||||
make_box(100, 100), // line 2, left
|
||||
make_box(50, 50), // line 1, left (same y=50)
|
||||
make_box(200, 100), // line 2, right (same y=100)
|
||||
];
|
||||
OcrLite::sort_text_boxes(&mut boxes);
|
||||
|
||||
// Line 1 (y=50): left first, then right
|
||||
assert_eq!(boxes[0].points[0].x, 50);
|
||||
assert_eq!(boxes[1].points[0].x, 300);
|
||||
// Line 2 (y=100): left first, then right
|
||||
assert_eq!(boxes[2].points[0].x, 100);
|
||||
assert_eq!(boxes[3].points[0].x, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_text_boxes_empty() {
|
||||
let mut boxes: Vec<TextBox> = vec![];
|
||||
OcrLite::sort_text_boxes(&mut boxes);
|
||||
assert!(boxes.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_text_boxes_single() {
|
||||
let mut boxes = vec![make_box(10, 20)];
|
||||
OcrLite::sort_text_boxes(&mut boxes);
|
||||
assert_eq!(boxes.len(), 1);
|
||||
}
|
||||
}
|
||||
105
crates/kreuzberg-paddle-ocr/src/ocr_result.rs
Normal file
105
crates/kreuzberg-paddle-ocr/src/ocr_result.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use std::fmt::{self, Write};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct Point {
|
||||
pub x: u32,
|
||||
pub y: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TextBox {
|
||||
pub points: Vec<Point>,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
impl fmt::Display for TextBox {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
// SAFETY: We must have at least 4 points in a valid TextBox
|
||||
// This is enforced at the OCR processing level, but we check bounds here for safety
|
||||
if self.points.len() < 4 {
|
||||
return write!(
|
||||
f,
|
||||
"TextBox [score({}), points_count({})]",
|
||||
self.score,
|
||||
self.points.len()
|
||||
);
|
||||
}
|
||||
|
||||
write!(
|
||||
f,
|
||||
"TextBox [score({}), [x: {}, y: {}], [x: {}, y: {}], [x: {}, y: {}], [x: {}, y: {}]]",
|
||||
self.score,
|
||||
self.points[0].x,
|
||||
self.points[0].y,
|
||||
self.points[1].x,
|
||||
self.points[1].y,
|
||||
self.points[2].x,
|
||||
self.points[2].y,
|
||||
self.points[3].x,
|
||||
self.points[3].y,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Angle {
|
||||
pub index: i32,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
impl fmt::Display for Angle {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let header = if self.index >= 0 { "Angle" } else { "AngleDisabled" };
|
||||
write!(f, "{}[Index({}), Score({})]", header, self.index, self.score)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct TextLine {
|
||||
pub text: String,
|
||||
pub text_score: f32,
|
||||
}
|
||||
|
||||
impl fmt::Display for TextLine {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "TextLine[Text({}),TextScore({})]", self.text, self.text_score)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TextBlock {
|
||||
pub box_points: Vec<Point>,
|
||||
pub box_score: f32,
|
||||
|
||||
pub angle_index: i32,
|
||||
pub angle_score: f32,
|
||||
|
||||
pub text: String,
|
||||
pub text_score: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct OcrResult {
|
||||
pub text_blocks: Vec<TextBlock>,
|
||||
}
|
||||
|
||||
impl fmt::Display for OcrResult {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let mut str_builder = String::with_capacity(0);
|
||||
for text_block in &self.text_blocks {
|
||||
write!(
|
||||
str_builder,
|
||||
"TextBlock[BoxPointsLen({}), BoxScore({}), AngleIndex({}), AngleScore({}), Text({}), TextScore({})]",
|
||||
text_block.box_points.len(),
|
||||
text_block.box_score,
|
||||
text_block.angle_index,
|
||||
text_block.angle_score,
|
||||
text_block.text,
|
||||
text_block.text_score
|
||||
)?;
|
||||
}
|
||||
f.write_str(&str_builder)
|
||||
}
|
||||
}
|
||||
206
crates/kreuzberg-paddle-ocr/src/ocr_utils.rs
Normal file
206
crates/kreuzberg-paddle-ocr/src/ocr_utils.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
use std::borrow::Cow;
|
||||
|
||||
use crate::{
|
||||
ocr_error::OcrError,
|
||||
ocr_result::{Point, TextBox},
|
||||
};
|
||||
use image::imageops;
|
||||
use imageproc::geometric_transformations::{Interpolation, Projection};
|
||||
use ndarray::{Array, Array4};
|
||||
|
||||
pub struct OcrUtils;
|
||||
|
||||
impl OcrUtils {
|
||||
/// Normalize image pixels and transpose from HWC (row-major RGB) to CHW tensor format.
|
||||
///
|
||||
/// Formula per pixel: `output[ch] = pixel[ch] * norm[ch] - mean[ch] * norm[ch]`
|
||||
///
|
||||
/// This is a hot path called once per page. Key optimizations:
|
||||
/// - Pre-computes `mean * norm` constants (avoids repeated multiply)
|
||||
/// - Writes each channel plane contiguously via `as_slice_mut()`, enabling
|
||||
/// LLVM auto-vectorization (NEON on ARM64, SSE/AVX on x86-64). The previous
|
||||
/// approach used `tensor[[0, ch, r, c]]` which scattered writes across planes
|
||||
/// and prevented any vectorization.
|
||||
pub fn substract_mean_normalize(img_src: &image::RgbImage, mean_vals: &[f32], norm_vals: &[f32]) -> Array4<f32> {
|
||||
let cols = img_src.width() as usize;
|
||||
let rows = img_src.height() as usize;
|
||||
let pixel_count = rows * cols;
|
||||
|
||||
let mut input_tensor = Array::zeros((1, 3, rows, cols));
|
||||
|
||||
let adjusted = [
|
||||
mean_vals[0] * norm_vals[0],
|
||||
mean_vals[1] * norm_vals[1],
|
||||
mean_vals[2] * norm_vals[2],
|
||||
];
|
||||
|
||||
let raw = img_src.as_raw();
|
||||
|
||||
// Write each channel plane as a contiguous slice. ndarray stores (1,3,H,W)
|
||||
// in C-contiguous (row-major) order, so plane [0,ch] is a contiguous H*W block.
|
||||
// This enables LLVM to auto-vectorize the inner loop (4-8 f32 ops per cycle).
|
||||
for ch in 0..3 {
|
||||
let norm = norm_vals[ch];
|
||||
let adj = adjusted[ch];
|
||||
let plane = input_tensor
|
||||
.slice_mut(ndarray::s![0, ch, .., ..])
|
||||
.into_shape_with_order(pixel_count)
|
||||
.expect("contiguous plane slice");
|
||||
let plane_slice = plane.into_slice().expect("contiguous memory");
|
||||
|
||||
for (i, out) in plane_slice.iter_mut().enumerate() {
|
||||
// raw is HWC: pixel i has R at raw[i*3], G at raw[i*3+1], B at raw[i*3+2]
|
||||
*out = raw[i * 3 + ch] as f32 * norm - adj;
|
||||
}
|
||||
}
|
||||
|
||||
input_tensor
|
||||
}
|
||||
|
||||
/// Add white padding around the image, or borrow it unchanged when padding=0.
|
||||
/// Returns Cow to avoid cloning the image in the common no-padding case.
|
||||
pub fn make_padding<'a>(img_src: &'a image::RgbImage, padding: u32) -> Result<Cow<'a, image::RgbImage>, OcrError> {
|
||||
if padding == 0 {
|
||||
return Ok(Cow::Borrowed(img_src));
|
||||
}
|
||||
|
||||
let width = img_src.width();
|
||||
let height = img_src.height();
|
||||
|
||||
let mut padding_src = image::RgbImage::new(width + 2 * padding, height + 2 * padding);
|
||||
imageproc::drawing::draw_filled_rect_mut(
|
||||
&mut padding_src,
|
||||
imageproc::rect::Rect::at(0, 0).of_size(width + 2 * padding, height + 2 * padding),
|
||||
image::Rgb([255, 255, 255]),
|
||||
);
|
||||
|
||||
image::imageops::replace(&mut padding_src, img_src, padding as i64, padding as i64);
|
||||
|
||||
Ok(Cow::Owned(padding_src))
|
||||
}
|
||||
|
||||
pub fn get_part_images(img_src: &image::RgbImage, text_boxes: &[TextBox]) -> Vec<image::RgbImage> {
|
||||
text_boxes
|
||||
.iter()
|
||||
.map(|text_box| Self::get_rotate_crop_image(img_src, &text_box.points))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_rotate_crop_image(img_src: &image::RgbImage, box_points: &[Point]) -> image::RgbImage {
|
||||
let mut points = box_points.to_vec();
|
||||
|
||||
// Calculate bounding box
|
||||
let (min_x, min_y, max_x, max_y) = points.iter().fold(
|
||||
(u32::MAX, u32::MAX, 0u32, 0u32),
|
||||
|(min_x, min_y, max_x, max_y), point| {
|
||||
(
|
||||
min_x.min(point.x),
|
||||
min_y.min(point.y),
|
||||
max_x.max(point.x),
|
||||
max_y.max(point.y),
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
// Crop image
|
||||
let img_crop = imageops::crop_imm(img_src, min_x, min_y, max_x - min_x, max_y - min_y).to_image();
|
||||
|
||||
for point in &mut points {
|
||||
point.x = point.x.saturating_sub(min_x);
|
||||
point.y = point.y.saturating_sub(min_y);
|
||||
}
|
||||
|
||||
// Ensure we have enough points for transformation
|
||||
if points.len() < 4 {
|
||||
// Fallback: return the cropped image as-is if we don't have 4 points
|
||||
return img_crop;
|
||||
}
|
||||
|
||||
// Direct multiplication instead of .pow(2) — avoids integer power function overhead.
|
||||
let dx_w = (points[0].x as i32 - points[1].x as i32) as f32;
|
||||
let dy_w = (points[0].y as i32 - points[1].y as i32) as f32;
|
||||
let img_crop_width = (dx_w * dx_w + dy_w * dy_w).sqrt() as u32;
|
||||
let dx_h = (points[0].x as i32 - points[3].x as i32) as f32;
|
||||
let dy_h = (points[0].y as i32 - points[3].y as i32) as f32;
|
||||
let img_crop_height = (dx_h * dx_h + dy_h * dy_h).sqrt() as u32;
|
||||
|
||||
// Ensure dimensions are valid (non-zero)
|
||||
if img_crop_width == 0 || img_crop_height == 0 {
|
||||
return img_crop;
|
||||
}
|
||||
|
||||
let src_points = [
|
||||
(points[0].x as f32, points[0].y as f32),
|
||||
(points[1].x as f32, points[1].y as f32),
|
||||
(points[2].x as f32, points[2].y as f32),
|
||||
(points[3].x as f32, points[3].y as f32),
|
||||
];
|
||||
|
||||
let dst_points = [
|
||||
(0.0, 0.0),
|
||||
(img_crop_width as f32, 0.0),
|
||||
(img_crop_width as f32, img_crop_height as f32),
|
||||
(0.0, img_crop_height as f32),
|
||||
];
|
||||
|
||||
let projection = match Projection::from_control_points(src_points, dst_points) {
|
||||
Some(proj) => proj,
|
||||
None => {
|
||||
// If projection cannot be created, return the cropped image as fallback
|
||||
return img_crop;
|
||||
}
|
||||
};
|
||||
|
||||
let mut part_img = image::RgbImage::new(img_crop_width, img_crop_height);
|
||||
imageproc::geometric_transformations::warp_into(
|
||||
&img_crop,
|
||||
&projection,
|
||||
Interpolation::Nearest,
|
||||
image::Rgb([255, 255, 255]),
|
||||
&mut part_img,
|
||||
);
|
||||
|
||||
// Rotate image if needed
|
||||
if part_img.height() >= part_img.width() * 3 / 2 {
|
||||
let mut rotated = image::RgbImage::new(part_img.height(), part_img.width());
|
||||
|
||||
for (x, y, pixel) in part_img.enumerate_pixels() {
|
||||
rotated.put_pixel(y, part_img.width() - 1 - x, *pixel);
|
||||
}
|
||||
|
||||
rotated
|
||||
} else {
|
||||
part_img
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mat_rotate_clock_wise_180(src: &mut image::RgbImage) {
|
||||
imageops::rotate180_in_place(src);
|
||||
}
|
||||
|
||||
/// Compute mean of f32 image values where mask > 0.
|
||||
///
|
||||
/// Uses raw slice access instead of per-pixel get_pixel() for better
|
||||
/// cache behavior and to enable auto-vectorization of the reduction.
|
||||
pub fn calculate_mean_with_mask(
|
||||
img: &image::ImageBuffer<image::Luma<f32>, Vec<f32>>,
|
||||
mask: &image::ImageBuffer<image::Luma<u8>, Vec<u8>>,
|
||||
) -> f32 {
|
||||
assert_eq!(img.width(), mask.width());
|
||||
assert_eq!(img.height(), mask.height());
|
||||
|
||||
let img_raw = img.as_raw();
|
||||
let mask_raw = mask.as_raw();
|
||||
let mut sum: f32 = 0.0;
|
||||
let mut count: u32 = 0;
|
||||
|
||||
for (px, &m) in img_raw.iter().zip(mask_raw.iter()) {
|
||||
if m > 0 {
|
||||
sum += *px;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if count == 0 { 0.0 } else { sum / count as f32 }
|
||||
}
|
||||
}
|
||||
69
crates/kreuzberg-paddle-ocr/src/scale_param.rs
Normal file
69
crates/kreuzberg-paddle-ocr/src/scale_param.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
#[derive(Debug)]
|
||||
pub struct ScaleParam {
|
||||
pub src_width: u32,
|
||||
pub src_height: u32,
|
||||
pub dst_width: u32,
|
||||
pub dst_height: u32,
|
||||
pub scale_width: f32,
|
||||
pub scale_height: f32,
|
||||
}
|
||||
|
||||
impl ScaleParam {
|
||||
pub fn new(
|
||||
src_width: u32,
|
||||
src_height: u32,
|
||||
dst_width: u32,
|
||||
dst_height: u32,
|
||||
scale_width: f32,
|
||||
scale_height: f32,
|
||||
) -> Self {
|
||||
Self {
|
||||
src_width,
|
||||
src_height,
|
||||
dst_width,
|
||||
dst_height,
|
||||
scale_width,
|
||||
scale_height,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_scale_param(src: &image::RgbImage, target_size: u32) -> Self {
|
||||
let src_width = src.width();
|
||||
let src_height = src.height();
|
||||
let mut dst_width;
|
||||
let mut dst_height;
|
||||
|
||||
let ratio: f32 = if src_width > src_height {
|
||||
target_size as f32 / src_width as f32
|
||||
} else {
|
||||
target_size as f32 / src_height as f32
|
||||
};
|
||||
|
||||
dst_width = (src_width as f32 * ratio) as u32;
|
||||
dst_height = (src_height as f32 * ratio) as u32;
|
||||
|
||||
if dst_width % 32 != 0 {
|
||||
dst_width = (dst_width / 32) * 32;
|
||||
dst_width = dst_width.max(32);
|
||||
}
|
||||
if dst_height % 32 != 0 {
|
||||
dst_height = (dst_height / 32) * 32;
|
||||
dst_height = dst_height.max(32);
|
||||
}
|
||||
|
||||
let scale_width = dst_width as f32 / src_width as f32;
|
||||
let scale_height = dst_height as f32 / src_height as f32;
|
||||
|
||||
Self::new(src_width, src_height, dst_width, dst_height, scale_width, scale_height)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ScaleParam {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"src_width:{},src_height:{},dst_width:{},dst_height:{},scale_width:{},scale_height:{}",
|
||||
self.src_width, self.src_height, self.dst_width, self.dst_height, self.scale_width, self.scale_height
|
||||
)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user