Nomad changes
All checks were successful
Deploy fil (kreuzberg) / deploy (push) Successful in 49s

This commit is contained in:
Henrik Jess Nielsen
2026-06-01 23:40:55 +02:00
parent 72b1a0a6ed
commit b4c07d3693
5723 changed files with 1130655 additions and 0 deletions

View File

@@ -0,0 +1,41 @@
[package]
name = "kreuzberg-paddle-ocr"
version.workspace = true
edition = "2024"
rust-version.workspace = true
authors.workspace = true
description = "PaddleOCR via ONNX Runtime for Kreuzberg - high-performance text recognition"
license = "MIT"
repository.workspace = true
homepage = "https://kreuzberg.dev"
documentation = "https://docs.rs/kreuzberg-paddle-ocr"
readme = "README.md"
keywords = ["paddle", "ocr", "onnx", "recognition", "detection"]
categories = ["computer-vision", "text-processing"]
exclude = ["tests/*", ".github/*"]
[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
[lib]
name = "kreuzberg_paddle_ocr"
crate-type = ["lib"]
[features]
default = []
load-dynamic = ["ort/load-dynamic"]
[dependencies]
geo-clipper = "0.9"
geo-types = "0.7"
image = { workspace = true }
# Crate-specific dependencies (not in workspace)
# Disable rayon - OCR parallelism is handled at higher level
imageproc = { version = "0.26", default-features = false }
ndarray = "0.17"
ort = { workspace = true, features = ["ndarray"] }
# Workspace dependencies
serde = { workspace = true }
thiserror = { workspace = true }

View File

@@ -0,0 +1,22 @@
MIT License
Copyright (c) 2024 mg-chao
Copyright (c) 2025 Na'aman Hirschfeld
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,57 @@
# kreuzberg-paddle-ocr
[![Bindings](https://img.shields.io/badge/Bindings-alef%20%D7%90-007ec6)](https://github.com/kreuzberg-dev/alef)
PaddleOCR via ONNX Runtime for Kreuzberg - high-performance text detection and recognition using PaddlePaddle's OCR models.
Based on the original [paddle-ocr-rs](https://github.com/mg-chao/paddle-ocr-rs) by [mg-chao](https://github.com/mg-chao), this vendored version includes improvements for Kreuzberg integration:
- **Workspace Dependency Alignment**: Uses Kreuzberg's workspace dependencies for consistency
- **Edition 2024**: Updated to Rust 2024 edition
- **ndarray Compatibility**: Aligned with Kreuzberg's ndarray version requirements
- **Integration**: Designed to work seamlessly with Kreuzberg's OCR backend system
## Features
- Text detection using DBNet (Differentiable Binarization)
- Text recognition using CRNN (Convolutional Recurrent Neural Network)
- Angle detection for rotated text
- Support for multiple languages via PaddleOCR models
- ONNX Runtime for efficient CPU inference
## ONNX Runtime Requirement
This crate requires **ONNX Runtime 1.24+** at runtime.
Install it:
- **macOS (Homebrew)**: `brew install onnxruntime`
- **Linux**: Download from [ONNX Runtime releases](https://github.com/microsoft/onnxruntime/releases)
- **Windows**: Download from [ONNX Runtime releases](https://github.com/microsoft/onnxruntime/releases)
## Usage
This crate is used internally by Kreuzberg when the `paddle-ocr` feature is enabled:
```toml
[dependencies]
kreuzberg = { version = "4.2", features = ["paddle-ocr"] }
```
## Models
PaddleOCR models are automatically downloaded and cached on first use. Supported models include:
- PP-OCRv5 server detection model
- PP-OCRv5 per-family recognition models (11 script families)
- PPOCRv2 mobile angle classification model
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## Acknowledgements
This project is based on the original [paddle-ocr-rs](https://github.com/mg-chao/paddle-ocr-rs) by [mg-chao](https://github.com/mg-chao), originally licensed under Apache-2.0. We are grateful for the foundational work that made this integration possible.
The original paddle-ocr-rs provides Rust bindings for PaddlePaddle's OCR models via ONNX Runtime, enabling efficient text detection and recognition without Python dependencies.

View File

@@ -0,0 +1,139 @@
use crate::{
base_net::BaseNet,
constants::{IMAGENET_MEAN_VALUES, IMAGENET_NORM_VALUES},
ocr_error::OcrError,
ocr_result::Angle,
ocr_utils::OcrUtils,
};
use ort::{
inputs,
session::{Session, SessionOutputs},
value::Tensor,
};
// PP-LCNet_x1_0_textline_ori preprocessing (ImageNet normalization).
// Input: resize to 160×80 (W×H), normalize with ImageNet mean/std.
// Formula in substract_mean_normalize: (pixel - MEAN) * NORM
// For ImageNet: (pixel/255 - mean) / std = (pixel - mean*255) * (1/(std*255))
// V2 PP-LCNet angle classifier expects [3, 80, 160] input (NCHW).
const ANGLE_DST_WIDTH: u32 = 160;
const ANGLE_DST_HEIGHT: u32 = 80;
const ANGLE_COLS: usize = 2;
#[derive(Debug)]
pub struct AngleNet {
session: Option<Session>,
input_names: Vec<String>,
}
impl BaseNet for AngleNet {
fn new() -> Self {
Self {
session: None,
input_names: Vec::new(),
}
}
fn set_input_names(&mut self, input_names: Vec<String>) {
self.input_names = input_names;
}
fn set_session(&mut self, session: Option<Session>) {
self.session = session;
}
}
impl AngleNet {
pub fn get_angles(
&self,
part_imgs: &[image::RgbImage],
do_angle: bool,
most_angle: bool,
cls_thresh: f32,
) -> Result<Vec<Angle>, OcrError> {
// Pre-allocate — we know exact count upfront.
let mut angles = Vec::with_capacity(part_imgs.len());
if do_angle {
for img in part_imgs {
let angle = self.get_angle(img, cls_thresh)?;
angles.push(angle);
}
} else {
angles.extend(part_imgs.iter().map(|_| Angle::default()));
}
if do_angle && most_angle {
let sum: i32 = angles.iter().map(|x| x.index).sum();
let half_percent = angles.len() as f32 / 2.0;
let most_angle_index = if (sum as f32) < half_percent { 0 } else { 1 };
for angle in angles.iter_mut() {
angle.index = most_angle_index;
}
}
Ok(angles)
}
fn get_angle(&self, img_src: &image::RgbImage, cls_thresh: f32) -> Result<Angle, OcrError> {
let Some(session) = &self.session else {
return Err(OcrError::SessionNotInitialized);
};
let angle_img = image::imageops::resize(
img_src,
ANGLE_DST_WIDTH,
ANGLE_DST_HEIGHT,
image::imageops::FilterType::Triangle,
);
let input_tensors =
OcrUtils::substract_mean_normalize(&angle_img, &IMAGENET_MEAN_VALUES, &IMAGENET_NORM_VALUES);
let input_tensors = Tensor::from_array(input_tensors)?;
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = session as *const Session as *mut Session;
(*session_ptr).run(inputs![self.input_names[0].as_str() => input_tensors])?
};
let mut angle = Self::score_to_angle(&outputs, ANGLE_COLS)?;
// Only apply rotation if confidence exceeds threshold (matches PaddleOCR's cls_thresh=0.9)
if angle.score < cls_thresh {
angle.index = 0; // Keep original orientation when confidence is low
}
Ok(angle)
}
fn score_to_angle(output_tensor: &SessionOutputs, angle_cols: usize) -> Result<Angle, OcrError> {
let (_, red_data) = output_tensor.iter().next().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output tensors found in angle classification session output",
))
})?;
let src_data: Vec<f32> = red_data.try_extract_tensor::<f32>()?.1.to_vec();
let mut angle = Angle::default();
let mut max_value = f32::MIN;
let mut angle_index = 0;
for (i, value) in src_data.iter().take(angle_cols).enumerate() {
if *value > max_value {
max_value = *value;
angle_index = i as i32;
}
}
angle.index = angle_index;
angle.score = max_value;
Ok(angle)
}
}

View File

@@ -0,0 +1,78 @@
use ort::session::{
Session,
builder::{GraphOptimizationLevel, SessionBuilder},
};
use crate::ocr_error::OcrError;
pub trait BaseNet {
fn new() -> Self;
fn get_session_builder(
&self,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<SessionBuilder, OcrError> {
let builder = Session::builder()?;
let builder = match builder_fn {
Some(custom) => custom(builder)?,
None => builder
.with_optimization_level(GraphOptimizationLevel::All)
.map_err(|e| OcrError::Ort(ort::Error::new(e.message())))?
.with_intra_threads(num_thread)
.map_err(|e| OcrError::Ort(ort::Error::new(e.message())))?
.with_inter_threads(1)
.map_err(|e| OcrError::Ort(ort::Error::new(e.message())))?,
};
Ok(builder)
}
fn set_input_names(&mut self, input_names: Vec<String>);
fn set_session(&mut self, session: Option<Session>);
fn init(&mut self, session: Session) {
let input_names: Vec<String> = session.inputs().iter().map(|input| input.name().to_string()).collect();
self.set_input_names(input_names);
self.set_session(Some(session));
}
fn init_model(
&mut self,
path: &str,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
// Wrap ORT session creation in catch_unwind to prevent mutex poisoning
// on platforms where ORT initialization can panic (notably Windows).
let path_owned = path.to_string();
let session = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut builder = self.get_session_builder(num_thread, builder_fn)?;
builder.commit_from_file(&path_owned).map_err(OcrError::from)
}))
.map_err(|_| OcrError::Ort(ort::Error::new("ORT session initialization panicked")))??;
self.init(session);
Ok(())
}
fn init_model_from_memory(
&mut self,
model_bytes: &[u8],
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
// Wrap ORT session creation in catch_unwind to prevent mutex poisoning
// on platforms where ORT initialization can panic (notably Windows).
let session = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut builder = self.get_session_builder(num_thread, builder_fn)?;
builder.commit_from_memory(model_bytes).map_err(OcrError::from)
}))
.map_err(|_| OcrError::Ort(ort::Error::new("ORT session initialization panicked")))??;
self.init(session);
Ok(())
}
}

View File

@@ -0,0 +1,33 @@
//! Shared normalization constants for PaddleOCR preprocessing.
//!
//! Two normalization schemes are used:
//!
//! - **ImageNet** (`IMAGENET_MEAN_VALUES` / `IMAGENET_NORM_VALUES`): used by the text
//! detection network (`DbNet`) and the angle classifier (`AngleNet`).
//! Formula: `(pixel - mean * 255) * (1 / (std * 255))`.
//!
//! - **CRNN** (`CRNN_MEAN_VALUES` / `CRNN_NORM_VALUES`): used by the text recognition
//! network (`CrnnNet`).
//! Formula: `(pixel - 127.5) * (1 / 127.5)`.
/// ImageNet channel means (R, G, B), pre-multiplied by 255.
///
/// Derived from `[0.485, 0.456, 0.406]` (per-channel ImageNet means).
/// Used by `DbNet` (text detection) and `AngleNet` (angle classification).
pub(crate) const IMAGENET_MEAN_VALUES: [f32; 3] = [0.485 * 255.0, 0.456 * 255.0, 0.406 * 255.0];
/// ImageNet channel normalization factors (R, G, B), equal to `1 / (std * 255)`.
///
/// Derived from `[0.229, 0.224, 0.225]` (per-channel ImageNet standard deviations).
/// Used by `DbNet` (text detection) and `AngleNet` (angle classification).
pub(crate) const IMAGENET_NORM_VALUES: [f32; 3] = [1.0 / (0.229 * 255.0), 1.0 / (0.224 * 255.0), 1.0 / (0.225 * 255.0)];
/// CRNN channel means (R, G, B): `127.5` for all channels.
///
/// Used by `CrnnNet` (text recognition).
pub(crate) const CRNN_MEAN_VALUES: [f32; 3] = [127.5, 127.5, 127.5];
/// CRNN channel normalization factors (R, G, B): `1 / 127.5` for all channels.
///
/// Used by `CrnnNet` (text recognition).
pub(crate) const CRNN_NORM_VALUES: [f32; 3] = [1.0 / 127.5, 1.0 / 127.5, 1.0 / 127.5];

View File

@@ -0,0 +1,393 @@
use ndarray::Array4;
use ort::session::Session;
use ort::value::Tensor;
use ort::{inputs, session::builder::SessionBuilder};
use std::collections::HashMap;
use crate::{
base_net::BaseNet,
constants::{CRNN_MEAN_VALUES, CRNN_NORM_VALUES},
ocr_error::OcrError,
ocr_result::TextLine,
ocr_utils::OcrUtils,
};
const CRNN_DST_HEIGHT: u32 = 48;
#[derive(Debug)]
pub struct CrnnNet {
session: Option<Session>,
keys: Vec<String>,
input_names: Vec<String>,
}
impl BaseNet for CrnnNet {
fn new() -> Self {
Self {
session: None,
keys: Vec::new(),
input_names: Vec::new(),
}
}
fn set_input_names(&mut self, input_names: Vec<String>) {
self.input_names = input_names;
}
fn set_session(&mut self, session: Option<Session>) {
self.session = session;
}
}
impl CrnnNet {
pub fn init_model(
&mut self,
path: &str,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
BaseNet::init_model(self, path, num_thread, builder_fn)?;
self.keys = self.get_keys()?;
Ok(())
}
pub fn init_model_dict_file(
&mut self,
path: &str,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
dict_file_path: &str,
) -> Result<(), OcrError> {
BaseNet::init_model(self, path, num_thread, builder_fn)?;
self.read_keys_from_file(dict_file_path)?;
Ok(())
}
pub fn init_model_from_memory(
&mut self,
model_bytes: &[u8],
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
BaseNet::init_model_from_memory(self, model_bytes, num_thread, builder_fn)?;
self.keys = self.get_keys()?;
Ok(())
}
fn get_keys(&mut self) -> Result<Vec<String>, OcrError> {
let session = self.session.as_ref().ok_or(OcrError::SessionNotInitialized)?;
let metadata = session.metadata()?;
let model_charater_list = metadata.custom("character").ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::NotFound,
"crnn_net character not found in metadata",
))
})?;
// PP-OCRv5 model metadata already includes the CTC blank token ("#") at
// index 0 and the space token (" ") at the end. Do NOT prepend/append
// extra tokens — doing so shifts every character index by one and
// produces garbled output.
let keys: Vec<String> = model_charater_list.split('\n').map(|s: &str| s.to_string()).collect();
Ok(keys)
}
fn read_keys_from_file(&mut self, path: &str) -> Result<(), OcrError> {
let content = std::fs::read_to_string(path)?;
// PP-OCRv5 dict files already include the CTC blank token ("#") at
// index 0 and the space token (" ") at the end. Do NOT prepend/append
// extra tokens — doing so shifts every character index by one and
// produces garbled output.
let keys: Vec<String> = content.split('\n').map(|s| s.to_string()).collect();
self.keys = keys;
Ok(())
}
pub fn get_text_lines(
&self,
part_imgs: &[image::RgbImage],
angle_rollback_records: &HashMap<usize, image::RgbImage>,
angle_rollback_threshold: f32,
batch_size: u32,
) -> Result<Vec<TextLine>, OcrError> {
if part_imgs.is_empty() {
return Ok(Vec::new());
}
// Batch recognition: sort by aspect ratio, batch, pad to max width
let mut text_lines = self.get_text_lines_batched(part_imgs, batch_size)?;
// Angle rollback: re-recognize individual images that scored poorly
for (index, text_line) in text_lines.iter_mut().enumerate() {
if (text_line.text_score.is_nan() || text_line.text_score < angle_rollback_threshold)
&& let Some(angle_rollback_record) = angle_rollback_records.get(&index)
{
*text_line = self.get_text_line(angle_rollback_record)?;
}
}
Ok(text_lines)
}
/// Batch recognition: sort crops by width, group into batches, pad to max width,
/// run single ONNX inference per batch. Matches PaddleOCR/RapidOCR batching strategy.
fn get_text_lines_batched(
&self,
part_imgs: &[image::RgbImage],
batch_size: u32,
) -> Result<Vec<TextLine>, OcrError> {
let session = self.session.as_ref().ok_or(OcrError::SessionNotInitialized)?;
let batch_size = (batch_size as usize).max(1);
// Compute target widths and sort indices by aspect ratio (width/height)
let mut indexed_widths: Vec<(usize, u32)> = part_imgs
.iter()
.enumerate()
.map(|(i, img)| {
let scale = CRNN_DST_HEIGHT as f32 / img.height().max(1) as f32;
let dst_width = (img.width() as f32 * scale).ceil() as u32;
(i, dst_width.max(1))
})
.collect();
indexed_widths.sort_by_key(|&(_, w)| w);
let mut results: Vec<(usize, TextLine)> = Vec::with_capacity(part_imgs.len());
// Process in batches
for chunk in indexed_widths.chunks(batch_size) {
if chunk.len() == 1 {
// Single image — use existing path (no padding overhead)
let (orig_idx, _) = chunk[0];
let text_line = self.get_text_line(&part_imgs[orig_idx])?;
results.push((orig_idx, text_line));
continue;
}
// Find max width in this batch
let max_width = chunk.iter().map(|&(_, w)| w).max().unwrap_or(1);
// Build batch tensor [N, 3, 48, max_width] with zero-padding
let n = chunk.len();
let mut batch_data = Array4::<f32>::zeros((n, 3, CRNN_DST_HEIGHT as usize, max_width as usize));
for (batch_idx, &(orig_idx, dst_width)) in chunk.iter().enumerate() {
let img = &part_imgs[orig_idx];
let resized =
image::imageops::resize(img, dst_width, CRNN_DST_HEIGHT, image::imageops::FilterType::Triangle);
// Normalize and fill into batch tensor (zero-padded on right).
// Use raw slice access instead of per-pixel get_pixel() to
// eliminate millions of bounds checks in the hot loop.
let cols = resized.width() as usize;
let rows = resized.height() as usize;
let raw = resized.as_raw();
assert_eq!(raw.len(), rows * cols * 3, "unexpected image buffer size");
let adjusted = [
CRNN_MEAN_VALUES[0] * CRNN_NORM_VALUES[0],
CRNN_MEAN_VALUES[1] * CRNN_NORM_VALUES[1],
CRNN_MEAN_VALUES[2] * CRNN_NORM_VALUES[2],
];
for r in 0..rows {
for c in 0..cols {
let base = r * cols * 3 + c * 3;
for ch in 0..3 {
batch_data[[batch_idx, ch, r, c]] =
raw[base + ch] as f32 * CRNN_NORM_VALUES[ch] - adjusted[ch];
}
}
}
// Remaining columns stay zero (padding)
}
let input_tensor = Tensor::from_array(batch_data)?;
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = session as *const Session as *mut Session;
(*session_ptr).run(inputs![self.input_names[0].as_str() => input_tensor])?
};
let (_, output_value) = outputs.iter().next().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output tensors found in batched CRNN session output",
))
})?;
let (shape, flat_data) = output_value.try_extract_tensor::<f32>()?;
// Shape: [batch, timesteps, num_classes]
let batch_dim = *shape.first().unwrap_or(&1) as usize;
let timesteps = *shape.get(1).unwrap_or(&0) as usize;
let num_classes = *shape.get(2).unwrap_or(&0) as usize;
for (batch_idx, item) in chunk.iter().enumerate().take(batch_dim.min(n)) {
let offset = batch_idx * timesteps * num_classes;
let slice = &flat_data[offset..offset + timesteps * num_classes];
let text_line = Self::score_to_text_line(slice, timesteps, num_classes, &self.keys)?;
results.push((item.0, text_line));
}
}
// Reorder results back to original index order
results.sort_by_key(|&(idx, _)| idx);
Ok(results.into_iter().map(|(_, tl)| tl).collect())
}
fn get_text_line(&self, img_src: &image::RgbImage) -> Result<TextLine, OcrError> {
let Some(session) = &self.session else {
return Err(OcrError::SessionNotInitialized);
};
let scale = CRNN_DST_HEIGHT as f32 / img_src.height() as f32;
let dst_width = (img_src.width() as f32 * scale).ceil() as u32;
let src_resize = image::imageops::resize(
img_src,
dst_width,
CRNN_DST_HEIGHT,
image::imageops::FilterType::Triangle,
);
let input_tensors = OcrUtils::substract_mean_normalize(&src_resize, &CRNN_MEAN_VALUES, &CRNN_NORM_VALUES);
let input_tensors = Tensor::from_array(input_tensors)?;
// SAFETY: ONNX Runtime C API is thread-safe for concurrent inference.
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = session as *const Session as *mut Session;
(*session_ptr).run(inputs![self.input_names[0].as_str() => input_tensors])?
};
let (_, red_data) = outputs.iter().next().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output tensors found in CRNN session output",
))
})?;
let (shape, src_data) = red_data.try_extract_tensor::<f32>()?;
let dimensions = shape;
let height = *dimensions.get(1).ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"CRNN output tensor missing height dimension (index 1)",
))
})? as usize;
let width = *dimensions.get(2).ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"CRNN output tensor missing width dimension (index 2)",
))
})? as usize;
let src_data: Vec<f32> = src_data.to_vec();
Self::score_to_text_line(&src_data, height, width, &self.keys)
}
fn score_to_text_line(
output_data: &[f32],
height: usize,
width: usize,
keys: &[String],
) -> Result<TextLine, OcrError> {
let mut text_line = TextLine::default();
let mut last_index = 0;
let mut text_score_sum = 0.0;
let mut text_score_count = 0;
for i in 0..height {
let start = i * width;
let stop = (i + 1) * width;
let slice = &output_data[start..stop.min(output_data.len())];
let (max_index, max_value) =
slice
.iter()
.enumerate()
.fold((0, f32::MIN), |(max_idx, max_val), (idx, &val)| {
if val > max_val { (idx, val) } else { (max_idx, max_val) }
});
if max_index > 0 && max_index < keys.len() && !(i > 0 && max_index == last_index) {
text_line.text.push_str(&keys[max_index]);
text_score_sum += max_value;
text_score_count += 1;
}
last_index = max_index;
}
// Avoid division by zero: handle case where no characters were found
text_line.text_score = if text_score_count > 0 {
text_score_sum / text_score_count as f32
} else {
0.0
};
Ok(text_line)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_score_to_text_line_skips_blank_index() {
// keys[0] = "#" (CTC blank), keys[1] = "a", keys[2] = "b"
let keys = vec!["#".to_string(), "a".to_string(), "b".to_string()];
// 3 timesteps, 3 classes each. Simulate: blank, "a", "b"
let output = vec![
1.0, 0.0, 0.0, // timestep 0: max at index 0 (blank) -> skip
0.0, 0.9, 0.1, // timestep 1: max at index 1 ("a")
0.0, 0.1, 0.8, // timestep 2: max at index 2 ("b")
];
let result = CrnnNet::score_to_text_line(&output, 3, 3, &keys).unwrap();
assert_eq!(result.text, "ab");
}
#[test]
fn test_score_to_text_line_deduplicates_consecutive() {
let keys = vec!["#".to_string(), "h".to_string(), "i".to_string()];
// 4 timesteps: "h", "h", "i", "i" -> should deduplicate to "hi"
let output = vec![
0.0, 0.9, 0.0, // "h"
0.0, 0.8, 0.0, // "h" again (same index, skip)
0.0, 0.0, 0.9, // "i"
0.0, 0.0, 0.8, // "i" again (same index, skip)
];
let result = CrnnNet::score_to_text_line(&output, 4, 3, &keys).unwrap();
assert_eq!(result.text, "hi");
}
#[test]
fn test_read_keys_from_file_preserves_dict_layout() {
let dir = std::env::temp_dir().join("kreuzberg_test_dict");
std::fs::create_dir_all(&dir).unwrap();
let dict_path = dir.join("test_dict.txt");
// PP-OCRv5 dict files already include "#" (blank) at start and " " at end.
std::fs::write(&dict_path, "#\na\nb\nc\n ").unwrap();
let mut net = CrnnNet::new();
net.read_keys_from_file(dict_path.to_str().unwrap()).unwrap();
// Dict is loaded as-is: ["#", "a", "b", "c", " "]
assert_eq!(net.keys[0], "#");
assert_eq!(net.keys[1], "a");
assert_eq!(net.keys[2], "b");
assert_eq!(net.keys[3], "c");
assert_eq!(net.keys[net.keys.len() - 1], " ");
std::fs::remove_dir_all(&dir).ok();
}
}

View File

@@ -0,0 +1,421 @@
use crate::{
base_net::BaseNet,
constants::{IMAGENET_MEAN_VALUES, IMAGENET_NORM_VALUES},
ocr_error::OcrError,
ocr_result::{self, TextBox},
ocr_utils::OcrUtils,
scale_param::ScaleParam,
};
use geo_clipper::{Clipper, EndType, JoinType};
use geo_types::{Coord, LineString, Polygon};
use ort::{inputs, session::SessionOutputs};
use ort::{session::Session, value::Tensor};
use std::cmp::Ordering;
#[derive(Debug)]
pub struct DbNet {
session: Option<Session>,
input_names: Vec<String>,
}
impl BaseNet for DbNet {
fn new() -> Self {
Self {
session: None,
input_names: Vec::new(),
}
}
fn set_input_names(&mut self, input_names: Vec<String>) {
self.input_names = input_names;
}
fn set_session(&mut self, session: Option<Session>) {
self.session = session;
}
}
impl DbNet {
pub fn get_text_boxes(
&self,
img_src: &image::RgbImage,
scale: &ScaleParam,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
thresh: f32,
) -> Result<Vec<TextBox>, OcrError> {
let Some(session) = &self.session else {
return Err(OcrError::SessionNotInitialized);
};
let src_resize = image::imageops::resize(
img_src,
scale.dst_width,
scale.dst_height,
image::imageops::FilterType::Triangle,
);
let input_tensors =
OcrUtils::substract_mean_normalize(&src_resize, &IMAGENET_MEAN_VALUES, &IMAGENET_NORM_VALUES);
let tensor = Tensor::from_array(input_tensors)?;
// SAFETY: ONNX Runtime's C API (OrtRun) is thread-safe for concurrent inference
// on the same session. The ort crate's `&mut self` requirement is overly
// conservative. This matches the pattern used in kreuzberg's embedding engine.
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = session as *const Session as *mut Session;
(*session_ptr).run(inputs![self.input_names[0].as_str() => tensor])?
};
let text_boxes = Self::get_text_boxes_core(
&outputs,
src_resize.height(),
src_resize.width(),
&ScaleParam::new(
scale.src_width,
scale.src_height,
scale.dst_width,
scale.dst_height,
scale.scale_width,
scale.scale_height,
),
box_score_thresh,
box_thresh,
un_clip_ratio,
thresh,
)?;
Ok(text_boxes)
}
fn get_text_boxes_core(
output_tensor: &SessionOutputs,
rows: u32,
cols: u32,
s: &ScaleParam,
box_score_thresh: f32,
_box_thresh: f32,
un_clip_ratio: f32,
thresh: f32,
) -> Result<Vec<TextBox>, OcrError> {
let max_side_thresh = 3.0;
let (_, red_data) = output_tensor.iter().next().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output tensors found in session output",
))
})?;
let pred_data: Vec<f32> = red_data.try_extract_tensor::<f32>()?.1.to_vec();
let cbuf_data: Vec<u8> = pred_data.iter().map(|pixel| (pixel * 255.0) as u8).collect();
let pred_img: image::ImageBuffer<image::Luma<f32>, Vec<f32>> =
image::ImageBuffer::from_vec(cols, rows, pred_data).ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Failed to create image buffer from predictions: {} x {} dimensions may be invalid",
cols, rows
),
))
})?;
let cbuf_img = image::GrayImage::from_vec(cols, rows, cbuf_data).ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Failed to create grayscale image buffer: {} x {} dimensions may be invalid",
cols, rows
),
))
})?;
let threshold_img = imageproc::contrast::threshold(
&cbuf_img,
(thresh * 255.0) as u8,
imageproc::contrast::ThresholdType::Binary,
);
// RapidOCR and PaddleOCR reference do NOT apply dilation before contour extraction.
// Dilation merges adjacent text regions, causing word concatenation.
let img_contours: Vec<imageproc::contours::Contour<i32>> = imageproc::contours::find_contours(&threshold_img);
// Pre-allocate based on contour count to avoid repeated reallocations.
let mut rs_boxes = Vec::with_capacity(img_contours.len());
for contour in img_contours {
if contour.points.len() <= 2 {
continue;
}
let mut max_side = 0.0;
let min_box = Self::get_mini_box(&contour.points, &mut max_side)?;
if max_side < max_side_thresh {
continue;
}
let score = Self::get_score(&contour, &pred_img)?;
if score < box_score_thresh {
continue;
}
let clip_box = Self::unclip(&min_box, un_clip_ratio)?;
if clip_box.is_empty() {
continue;
}
let mut clip_contour = Vec::new();
for point in &clip_box {
clip_contour.push(*point);
}
let mut max_side_clip = 0.0;
let clip_min_box = Self::get_mini_box(&clip_contour, &mut max_side_clip)?;
if max_side_clip < max_side_thresh + 2.0 {
continue;
}
let mut final_points = Vec::new();
for item in clip_min_box {
let x = (item.x / s.scale_width) as u32;
let ptx = x.min(s.src_width);
let y = (item.y / s.scale_height) as u32;
let pty = y.min(s.src_height);
final_points.push(ocr_result::Point { x: ptx, y: pty });
}
let text_box = TextBox {
score,
points: final_points,
};
rs_boxes.push(text_box);
}
Ok(rs_boxes)
}
fn get_mini_box(
contour_points: &[imageproc::point::Point<i32>],
min_edge_size: &mut f32,
) -> Result<Vec<imageproc::point::Point<f32>>, OcrError> {
let rect = imageproc::geometry::min_area_rect(contour_points);
let mut rect_points: Vec<imageproc::point::Point<f32>> = rect
.iter()
.map(|p| imageproc::point::Point::new(p.x as f32, p.y as f32))
.collect();
// Direct multiplication instead of .powi(2) — avoids function call overhead.
let dx_w = rect_points[0].x - rect_points[1].x;
let dy_w = rect_points[0].y - rect_points[1].y;
let width = (dx_w * dx_w + dy_w * dy_w).sqrt();
let dx_h = rect_points[1].x - rect_points[2].x;
let dy_h = rect_points[1].y - rect_points[2].y;
let height = (dx_h * dx_h + dy_h * dy_h).sqrt();
*min_edge_size = width.min(height);
rect_points.sort_by(|a, b| {
if a.x > b.x {
return Ordering::Greater;
}
if a.x == b.x {
return Ordering::Equal;
}
Ordering::Less
});
let mut box_points = Vec::new();
let index_1;
let index_4;
if rect_points[1].y > rect_points[0].y {
index_1 = 0;
index_4 = 1;
} else {
index_1 = 1;
index_4 = 0;
}
let index_2;
let index_3;
if rect_points[3].y > rect_points[2].y {
index_2 = 2;
index_3 = 3;
} else {
index_2 = 3;
index_3 = 2;
}
box_points.push(rect_points[index_1]);
box_points.push(rect_points[index_2]);
box_points.push(rect_points[index_3]);
box_points.push(rect_points[index_4]);
Ok(box_points)
}
fn get_score(
contour: &imageproc::contours::Contour<i32>,
f_map_mat: &image::ImageBuffer<image::Luma<f32>, Vec<f32>>,
) -> Result<f32, OcrError> {
// Initialize boundary values
let mut xmin = i32::MAX;
let mut xmax = i32::MIN;
let mut ymin = i32::MAX;
let mut ymax = i32::MIN;
// Find contour bounding box
for point in contour.points.iter() {
let x = point.x;
let y = point.y;
if x < xmin {
xmin = x;
}
if x > xmax {
xmax = x;
}
if y < ymin {
ymin = y;
}
if y > ymax {
ymax = y;
}
}
let width = f_map_mat.width() as i32;
let height = f_map_mat.height() as i32;
xmin = xmin.max(0).min(width - 1);
xmax = xmax.max(0).min(width - 1);
ymin = ymin.max(0).min(height - 1);
ymax = ymax.max(0).min(height - 1);
let roi_width = xmax - xmin + 1;
let roi_height = ymax - ymin + 1;
if roi_width <= 0 || roi_height <= 0 {
return Ok(0.0);
}
let mut mask = image::GrayImage::new(roi_width as u32, roi_height as u32);
let mut pts = Vec::<imageproc::point::Point<i32>>::new();
for point in contour.points.iter() {
pts.push(imageproc::point::Point::new(point.x - xmin, point.y - ymin));
}
imageproc::drawing::draw_polygon_mut(&mut mask, pts.as_slice(), image::Luma([255]));
let cropped_img =
image::imageops::crop_imm(f_map_mat, xmin as u32, ymin as u32, roi_width as u32, roi_height as u32)
.to_image();
let mean = OcrUtils::calculate_mean_with_mask(&cropped_img, &mask);
Ok(mean)
}
fn unclip(
box_points: &[imageproc::point::Point<f32>],
unclip_ratio: f32,
) -> Result<Vec<imageproc::point::Point<i32>>, OcrError> {
// Direct multiplication instead of .powi(2) — avoids function call overhead.
let dx_w = box_points[0].x - box_points[1].x;
let dy_w = box_points[0].y - box_points[1].y;
let clip_rect_width = (dx_w * dx_w + dy_w * dy_w).sqrt();
let dx_h = box_points[1].x - box_points[2].x;
let dy_h = box_points[1].y - box_points[2].y;
let clip_rect_height = (dx_h * dx_h + dy_h * dy_h).sqrt();
if clip_rect_height < 1.001 && clip_rect_width < 1.001 {
return Ok(Vec::new());
}
let mut the_cliper_pts = Vec::new();
for pt in box_points {
let a1 = Coord {
x: pt.x as f64,
y: pt.y as f64,
};
the_cliper_pts.push(a1);
}
let area = Self::signed_polygon_area(box_points).abs();
let length = Self::length_of_points(box_points);
let distance = area * unclip_ratio / length as f32;
let co = Polygon::new(LineString::new(the_cliper_pts), vec![]);
let solution = co
.offset(distance as f64, JoinType::Round(2.0), EndType::ClosedPolygon, 1.0)
.0;
if solution.is_empty() {
return Ok(Vec::new());
}
let first_polygon = solution.first().ok_or_else(|| {
OcrError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Polygon solution list was empty after offset operation",
))
})?;
let ret_pts: Vec<_> = first_polygon
.exterior()
.points()
.map(|ip| imageproc::point::Point::new(ip.x() as i32, ip.y() as i32))
.collect();
Ok(ret_pts)
}
fn signed_polygon_area(points: &[imageproc::point::Point<f32>]) -> f32 {
let num_points = points.len();
let mut pts = Vec::with_capacity(num_points + 1);
pts.extend_from_slice(points);
pts.push(points[0]);
let mut area = 0.0;
for i in 0..num_points {
area += (pts[i + 1].x - pts[i].x) * (pts[i + 1].y + pts[i].y) / 2.0;
}
area
}
fn length_of_points(box_points: &[imageproc::point::Point<f32>]) -> f64 {
if box_points.is_empty() {
return 0.0;
}
let mut length = 0.0;
let mut x0 = box_points[0].x as f64;
let mut y0 = box_points[0].y as f64;
for pt in &box_points[1..] {
let x1 = pt.x as f64;
let y1 = pt.y as f64;
let dx = x1 - x0;
let dy = y1 - y0;
length += (dx * dx + dy * dy).sqrt();
x0 = x1;
y0 = y1;
}
// Closing segment back to first point
let dx = box_points[0].x as f64 - x0;
let dy = box_points[0].y as f64 - y0;
length += (dx * dx + dy * dy).sqrt();
length
}
}

View File

@@ -0,0 +1,32 @@
//! # kreuzberg-paddle-ocr
//!
//! PaddleOCR via ONNX Runtime for Kreuzberg - high-performance text detection and recognition.
//!
//! This crate is vendored from [paddle-ocr-rs](https://github.com/mg-chao/paddle-ocr-rs)
//! by mg-chao, with modifications for Kreuzberg integration.
//!
//! ## ONNX Runtime Requirement
//!
//! Requires **ONNX Runtime 1.24+** at runtime.
//!
//! ## Original License
//!
//! The original paddle-ocr-rs is licensed under Apache-2.0.
//! This vendored version is relicensed to MIT with the original author's copyright retained.
#![allow(clippy::too_many_arguments)]
pub mod angle_net;
pub mod base_net;
pub(crate) mod constants;
pub mod crnn_net;
pub mod db_net;
pub mod ocr_error;
pub mod ocr_lite;
pub mod ocr_result;
pub mod ocr_utils;
pub mod scale_param;
pub use ocr_error::OcrError;
pub use ocr_lite::OcrLite;
pub use ocr_result::{Angle, OcrResult, Point, TextBlock, TextBox, TextLine};

View File

@@ -0,0 +1,13 @@
use thiserror::Error;
#[derive(Error, Debug)]
pub enum OcrError {
#[error("Ort error: {0}")]
Ort(#[from] ort::Error),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Image error: {0}")]
ImageError(#[from] image::ImageError),
#[error("Session not initialized")]
SessionNotInitialized,
}

View File

@@ -0,0 +1,447 @@
use std::collections::HashMap;
use image::ImageBuffer;
use ort::session::builder::SessionBuilder;
use crate::{
angle_net::AngleNet,
base_net::BaseNet,
crnn_net::CrnnNet,
db_net::DbNet,
ocr_error::OcrError,
ocr_result::{OcrResult, Point, TextBlock, TextBox},
ocr_utils::OcrUtils,
scale_param::ScaleParam,
};
#[derive(Debug)]
pub struct OcrLite {
db_net: DbNet,
angle_net: AngleNet,
crnn_net: CrnnNet,
}
// SAFETY: OcrLite inference methods (&self) use unsafe pointer casts to call
// ort Session::run, which is thread-safe at the ONNX Runtime C API level.
// After initialization (&mut self), no mutable state is accessed during inference.
unsafe impl Send for OcrLite {}
unsafe impl Sync for OcrLite {}
impl Default for OcrLite {
fn default() -> Self {
Self::new()
}
}
impl OcrLite {
pub fn new() -> Self {
Self {
db_net: DbNet::new(),
angle_net: AngleNet::new(),
crnn_net: CrnnNet::new(),
}
}
pub fn init_models(
&mut self,
det_path: &str,
cls_path: &str,
rec_path: &str,
num_thread: usize,
) -> Result<(), OcrError> {
self.db_net.init_model(det_path, num_thread, None)?;
self.angle_net.init_model(cls_path, num_thread, None)?;
self.crnn_net.init_model(rec_path, num_thread, None)?;
Ok(())
}
pub fn init_models_with_dict(
&mut self,
det_path: &str,
cls_path: &str,
rec_path: &str,
dict_path: &str,
num_thread: usize,
) -> Result<(), OcrError> {
self.db_net.init_model(det_path, num_thread, None)?;
self.angle_net.init_model(cls_path, num_thread, None)?;
self.crnn_net
.init_model_dict_file(rec_path, num_thread, None, dict_path)?;
Ok(())
}
pub fn init_models_custom(
&mut self,
det_path: &str,
cls_path: &str,
rec_path: &str,
builder_fn: fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>,
) -> Result<(), OcrError> {
self.db_net.init_model(det_path, 0, Some(builder_fn))?;
self.angle_net.init_model(cls_path, 0, Some(builder_fn))?;
self.crnn_net.init_model(rec_path, 0, Some(builder_fn))?;
Ok(())
}
/// Initialize models with dictionary file and custom session builder.
///
/// Combines `init_models_with_dict` and `init_models_custom`: loads the
/// dictionary for the recognition model while applying a custom ORT
/// session builder (e.g. for GPU execution providers).
pub fn init_models_with_dict_custom(
&mut self,
det_path: &str,
cls_path: &str,
rec_path: &str,
dict_path: &str,
num_thread: usize,
builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
) -> Result<(), OcrError> {
self.db_net.init_model(det_path, num_thread, builder_fn)?;
self.angle_net.init_model(cls_path, num_thread, builder_fn)?;
self.crnn_net
.init_model_dict_file(rec_path, num_thread, builder_fn, dict_path)?;
Ok(())
}
pub fn init_models_from_memory(
&mut self,
det_bytes: &[u8],
cls_bytes: &[u8],
rec_bytes: &[u8],
num_thread: usize,
) -> Result<(), OcrError> {
self.db_net.init_model_from_memory(det_bytes, num_thread, None)?;
self.angle_net.init_model_from_memory(cls_bytes, num_thread, None)?;
self.crnn_net.init_model_from_memory(rec_bytes, num_thread, None)?;
Ok(())
}
pub fn init_models_from_memory_custom(
&mut self,
det_bytes: &[u8],
cls_bytes: &[u8],
rec_bytes: &[u8],
builder_fn: fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>,
) -> Result<(), OcrError> {
self.db_net.init_model_from_memory(det_bytes, 0, Some(builder_fn))?;
self.angle_net.init_model_from_memory(cls_bytes, 0, Some(builder_fn))?;
self.crnn_net.init_model_from_memory(rec_bytes, 0, Some(builder_fn))?;
Ok(())
}
fn detect_base(
&self,
img_src: &image::RgbImage,
padding: u32,
max_side_len: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
angle_rollback: bool,
angle_rollback_threshold: f32,
cls_thresh: f32,
thresh: f32,
) -> Result<OcrResult, OcrError> {
let origin_max_side = img_src.width().max(img_src.height());
let mut resize;
if max_side_len == 0 || max_side_len > origin_max_side {
resize = origin_max_side;
} else {
resize = max_side_len;
}
resize += 2 * padding;
// Cow avoids cloning the image when padding=0 (the common case).
let padding_src = OcrUtils::make_padding(img_src, padding)?;
let scale = ScaleParam::get_scale_param(&padding_src, resize);
self.detect_once(
&padding_src,
&scale,
padding,
box_score_thresh,
box_thresh,
un_clip_ratio,
do_angle,
most_angle,
angle_rollback,
angle_rollback_threshold,
cls_thresh,
thresh,
)
}
/// Detect text in image
///
/// # Arguments
///
/// - `img_src` - Input image
/// - `padding` - Padding width added during image transformation (improves detection)
/// - `max_side_len` - Maximum side length after transformation (larger images will be scaled down)
/// - `box_score_thresh` - Score threshold for text region detection
/// - `box_thresh` - Box threshold
/// - `un_clip_ratio` - Unclip ratio
/// - `do_angle` - Whether to perform angle detection
/// - `most_angle` - Use most common angle for all text regions
const DEFAULT_CLS_THRESH: f32 = 0.9;
const DEFAULT_THRESH: f32 = 0.3;
const DEFAULT_REC_BATCH_SIZE: u32 = 6;
pub fn detect(
&self,
img_src: &image::RgbImage,
padding: u32,
max_side_len: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
) -> Result<OcrResult, OcrError> {
self.detect_base(
img_src,
padding,
max_side_len,
box_score_thresh,
box_thresh,
un_clip_ratio,
do_angle,
most_angle,
false,
0.0,
Self::DEFAULT_CLS_THRESH,
Self::DEFAULT_THRESH,
)
}
/// Detect text with angle rollback support
///
/// When `do_angle` is true, if the image was angle-corrected but recognition
/// result is poor, the angle correction will be reverted.
///
/// # Arguments
///
/// - `img_src` - Input image
/// - `padding` - Padding width added during image transformation
/// - `max_side_len` - Maximum side length after transformation
/// - `box_score_thresh` - Score threshold for text region detection
/// - `box_thresh` - Box threshold
/// - `un_clip_ratio` - Unclip ratio
/// - `do_angle` - Whether to perform angle detection
/// - `most_angle` - Use most common angle
/// - `angle_rollback_threshold` - If text score is below this value (or NaN), angle correction is reverted
pub fn detect_angle_rollback(
&self,
img_src: &image::RgbImage,
padding: u32,
max_side_len: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
angle_rollback_threshold: f32,
) -> Result<OcrResult, OcrError> {
self.detect_base(
img_src,
padding,
max_side_len,
box_score_thresh,
box_thresh,
un_clip_ratio,
do_angle,
most_angle,
true,
angle_rollback_threshold,
Self::DEFAULT_CLS_THRESH,
Self::DEFAULT_THRESH,
)
}
pub fn detect_from_path(
&self,
img_path: &str,
padding: u32,
max_side_len: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
) -> Result<OcrResult, OcrError> {
let img_src = image::open(img_path)?.to_rgb8();
self.detect(
&img_src,
padding,
max_side_len,
box_score_thresh,
box_thresh,
un_clip_ratio,
do_angle,
most_angle,
)
}
/// Sort text boxes in reading order: top-to-bottom, left-to-right.
///
/// Sorts by top-left Y coordinate first, then by top-left X coordinate within
/// the same Y. Matches PaddleOCR Python's `sorted_boxes` primary ordering.
fn sort_text_boxes(text_boxes: &mut [TextBox]) {
text_boxes.sort_by(|a, b| {
let ay = a.points.first().map_or(0, |p| p.y);
let ax = a.points.first().map_or(0, |p| p.x);
let by = b.points.first().map_or(0, |p| p.y);
let bx = b.points.first().map_or(0, |p| p.x);
(ay, ax).cmp(&(by, bx))
});
}
fn detect_once(
&self,
img_src: &image::RgbImage,
scale: &ScaleParam,
padding: u32,
box_score_thresh: f32,
box_thresh: f32,
un_clip_ratio: f32,
do_angle: bool,
most_angle: bool,
angle_rollback: bool,
angle_rollback_threshold: f32,
cls_thresh: f32,
thresh: f32,
) -> Result<OcrResult, OcrError> {
let mut text_boxes =
self.db_net
.get_text_boxes(img_src, scale, box_score_thresh, box_thresh, un_clip_ratio, thresh)?;
// Sort boxes in reading order (top-to-bottom, left-to-right)
Self::sort_text_boxes(&mut text_boxes);
let part_images = OcrUtils::get_part_images(img_src, &text_boxes);
let angles = self
.angle_net
.get_angles(&part_images, do_angle, most_angle, cls_thresh)?;
let mut rotated_images: Vec<image::RgbImage> = Vec::with_capacity(part_images.len());
// Angle correction rollback
let mut angle_rollback_records = HashMap::<usize, ImageBuffer<image::Rgb<u8>, Vec<u8>>>::new();
for (index, (angle, mut part_image)) in angles.iter().zip(part_images).enumerate() {
if angle.index == 1 {
if angle_rollback {
// Keep original copy
angle_rollback_records.insert(index, part_image.clone());
}
OcrUtils::mat_rotate_clock_wise_180(&mut part_image);
}
rotated_images.push(part_image);
}
let text_lines = self.crnn_net.get_text_lines(
&rotated_images,
&angle_rollback_records,
angle_rollback_threshold,
Self::DEFAULT_REC_BATCH_SIZE,
)?;
let mut text_blocks = Vec::with_capacity(text_lines.len());
for (i, text_line) in text_lines.into_iter().enumerate() {
text_blocks.push(TextBlock {
box_points: text_boxes[i]
.points
.iter()
.map(|p| Point {
x: ((p.x as f32) - padding as f32) as u32,
y: ((p.y as f32) - padding as f32) as u32,
})
.collect(),
box_score: text_boxes[i].score,
angle_index: angles[i].index,
angle_score: angles[i].score,
text: text_line.text,
text_score: text_line.text_score,
});
}
Ok(OcrResult { text_blocks })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ocr_result::TextBox;
fn make_box(x: u32, y: u32) -> TextBox {
TextBox {
points: vec![
Point { x, y },
Point { x: x + 100, y },
Point { x: x + 100, y: y + 20 },
Point { x, y: y + 20 },
],
score: 0.9,
}
}
#[test]
fn test_sort_text_boxes_top_to_bottom() {
let mut boxes = vec![make_box(10, 100), make_box(10, 50), make_box(10, 10)];
OcrLite::sort_text_boxes(&mut boxes);
assert_eq!(boxes[0].points[0].y, 10);
assert_eq!(boxes[1].points[0].y, 50);
assert_eq!(boxes[2].points[0].y, 100);
}
#[test]
fn test_sort_text_boxes_same_line_left_to_right() {
// Boxes with the same Y are sorted left-to-right by X
let mut boxes = vec![make_box(200, 10), make_box(100, 10), make_box(50, 10)];
OcrLite::sort_text_boxes(&mut boxes);
assert_eq!(boxes[0].points[0].x, 50);
assert_eq!(boxes[1].points[0].x, 100);
assert_eq!(boxes[2].points[0].x, 200);
}
#[test]
fn test_sort_text_boxes_multi_line() {
// Boxes sorted strictly by (y, x): y=50/x=50, y=50/x=300, y=100/x=100, y=100/x=200
let mut boxes = vec![
make_box(300, 50), // line 1, right
make_box(100, 100), // line 2, left
make_box(50, 50), // line 1, left (same y=50)
make_box(200, 100), // line 2, right (same y=100)
];
OcrLite::sort_text_boxes(&mut boxes);
// Line 1 (y=50): left first, then right
assert_eq!(boxes[0].points[0].x, 50);
assert_eq!(boxes[1].points[0].x, 300);
// Line 2 (y=100): left first, then right
assert_eq!(boxes[2].points[0].x, 100);
assert_eq!(boxes[3].points[0].x, 200);
}
#[test]
fn test_sort_text_boxes_empty() {
let mut boxes: Vec<TextBox> = vec![];
OcrLite::sort_text_boxes(&mut boxes);
assert!(boxes.is_empty());
}
#[test]
fn test_sort_text_boxes_single() {
let mut boxes = vec![make_box(10, 20)];
OcrLite::sort_text_boxes(&mut boxes);
assert_eq!(boxes.len(), 1);
}
}

View File

@@ -0,0 +1,105 @@
use std::fmt::{self, Write};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Point {
pub x: u32,
pub y: u32,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TextBox {
pub points: Vec<Point>,
pub score: f32,
}
impl fmt::Display for TextBox {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// SAFETY: We must have at least 4 points in a valid TextBox
// This is enforced at the OCR processing level, but we check bounds here for safety
if self.points.len() < 4 {
return write!(
f,
"TextBox [score({}), points_count({})]",
self.score,
self.points.len()
);
}
write!(
f,
"TextBox [score({}), [x: {}, y: {}], [x: {}, y: {}], [x: {}, y: {}], [x: {}, y: {}]]",
self.score,
self.points[0].x,
self.points[0].y,
self.points[1].x,
self.points[1].y,
self.points[2].x,
self.points[2].y,
self.points[3].x,
self.points[3].y,
)
}
}
#[derive(Debug, Default)]
pub struct Angle {
pub index: i32,
pub score: f32,
}
impl fmt::Display for Angle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let header = if self.index >= 0 { "Angle" } else { "AngleDisabled" };
write!(f, "{}[Index({}), Score({})]", header, self.index, self.score)
}
}
#[derive(Debug, Default)]
pub struct TextLine {
pub text: String,
pub text_score: f32,
}
impl fmt::Display for TextLine {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "TextLine[Text({}),TextScore({})]", self.text, self.text_score)
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TextBlock {
pub box_points: Vec<Point>,
pub box_score: f32,
pub angle_index: i32,
pub angle_score: f32,
pub text: String,
pub text_score: f32,
}
#[derive(Serialize, Deserialize)]
pub struct OcrResult {
pub text_blocks: Vec<TextBlock>,
}
impl fmt::Display for OcrResult {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut str_builder = String::with_capacity(0);
for text_block in &self.text_blocks {
write!(
str_builder,
"TextBlock[BoxPointsLen({}), BoxScore({}), AngleIndex({}), AngleScore({}), Text({}), TextScore({})]",
text_block.box_points.len(),
text_block.box_score,
text_block.angle_index,
text_block.angle_score,
text_block.text,
text_block.text_score
)?;
}
f.write_str(&str_builder)
}
}

View File

@@ -0,0 +1,206 @@
use std::borrow::Cow;
use crate::{
ocr_error::OcrError,
ocr_result::{Point, TextBox},
};
use image::imageops;
use imageproc::geometric_transformations::{Interpolation, Projection};
use ndarray::{Array, Array4};
pub struct OcrUtils;
impl OcrUtils {
/// Normalize image pixels and transpose from HWC (row-major RGB) to CHW tensor format.
///
/// Formula per pixel: `output[ch] = pixel[ch] * norm[ch] - mean[ch] * norm[ch]`
///
/// This is a hot path called once per page. Key optimizations:
/// - Pre-computes `mean * norm` constants (avoids repeated multiply)
/// - Writes each channel plane contiguously via `as_slice_mut()`, enabling
/// LLVM auto-vectorization (NEON on ARM64, SSE/AVX on x86-64). The previous
/// approach used `tensor[[0, ch, r, c]]` which scattered writes across planes
/// and prevented any vectorization.
pub fn substract_mean_normalize(img_src: &image::RgbImage, mean_vals: &[f32], norm_vals: &[f32]) -> Array4<f32> {
let cols = img_src.width() as usize;
let rows = img_src.height() as usize;
let pixel_count = rows * cols;
let mut input_tensor = Array::zeros((1, 3, rows, cols));
let adjusted = [
mean_vals[0] * norm_vals[0],
mean_vals[1] * norm_vals[1],
mean_vals[2] * norm_vals[2],
];
let raw = img_src.as_raw();
// Write each channel plane as a contiguous slice. ndarray stores (1,3,H,W)
// in C-contiguous (row-major) order, so plane [0,ch] is a contiguous H*W block.
// This enables LLVM to auto-vectorize the inner loop (4-8 f32 ops per cycle).
for ch in 0..3 {
let norm = norm_vals[ch];
let adj = adjusted[ch];
let plane = input_tensor
.slice_mut(ndarray::s![0, ch, .., ..])
.into_shape_with_order(pixel_count)
.expect("contiguous plane slice");
let plane_slice = plane.into_slice().expect("contiguous memory");
for (i, out) in plane_slice.iter_mut().enumerate() {
// raw is HWC: pixel i has R at raw[i*3], G at raw[i*3+1], B at raw[i*3+2]
*out = raw[i * 3 + ch] as f32 * norm - adj;
}
}
input_tensor
}
/// Add white padding around the image, or borrow it unchanged when padding=0.
/// Returns Cow to avoid cloning the image in the common no-padding case.
pub fn make_padding<'a>(img_src: &'a image::RgbImage, padding: u32) -> Result<Cow<'a, image::RgbImage>, OcrError> {
if padding == 0 {
return Ok(Cow::Borrowed(img_src));
}
let width = img_src.width();
let height = img_src.height();
let mut padding_src = image::RgbImage::new(width + 2 * padding, height + 2 * padding);
imageproc::drawing::draw_filled_rect_mut(
&mut padding_src,
imageproc::rect::Rect::at(0, 0).of_size(width + 2 * padding, height + 2 * padding),
image::Rgb([255, 255, 255]),
);
image::imageops::replace(&mut padding_src, img_src, padding as i64, padding as i64);
Ok(Cow::Owned(padding_src))
}
pub fn get_part_images(img_src: &image::RgbImage, text_boxes: &[TextBox]) -> Vec<image::RgbImage> {
text_boxes
.iter()
.map(|text_box| Self::get_rotate_crop_image(img_src, &text_box.points))
.collect()
}
pub fn get_rotate_crop_image(img_src: &image::RgbImage, box_points: &[Point]) -> image::RgbImage {
let mut points = box_points.to_vec();
// Calculate bounding box
let (min_x, min_y, max_x, max_y) = points.iter().fold(
(u32::MAX, u32::MAX, 0u32, 0u32),
|(min_x, min_y, max_x, max_y), point| {
(
min_x.min(point.x),
min_y.min(point.y),
max_x.max(point.x),
max_y.max(point.y),
)
},
);
// Crop image
let img_crop = imageops::crop_imm(img_src, min_x, min_y, max_x - min_x, max_y - min_y).to_image();
for point in &mut points {
point.x = point.x.saturating_sub(min_x);
point.y = point.y.saturating_sub(min_y);
}
// Ensure we have enough points for transformation
if points.len() < 4 {
// Fallback: return the cropped image as-is if we don't have 4 points
return img_crop;
}
// Direct multiplication instead of .pow(2) — avoids integer power function overhead.
let dx_w = (points[0].x as i32 - points[1].x as i32) as f32;
let dy_w = (points[0].y as i32 - points[1].y as i32) as f32;
let img_crop_width = (dx_w * dx_w + dy_w * dy_w).sqrt() as u32;
let dx_h = (points[0].x as i32 - points[3].x as i32) as f32;
let dy_h = (points[0].y as i32 - points[3].y as i32) as f32;
let img_crop_height = (dx_h * dx_h + dy_h * dy_h).sqrt() as u32;
// Ensure dimensions are valid (non-zero)
if img_crop_width == 0 || img_crop_height == 0 {
return img_crop;
}
let src_points = [
(points[0].x as f32, points[0].y as f32),
(points[1].x as f32, points[1].y as f32),
(points[2].x as f32, points[2].y as f32),
(points[3].x as f32, points[3].y as f32),
];
let dst_points = [
(0.0, 0.0),
(img_crop_width as f32, 0.0),
(img_crop_width as f32, img_crop_height as f32),
(0.0, img_crop_height as f32),
];
let projection = match Projection::from_control_points(src_points, dst_points) {
Some(proj) => proj,
None => {
// If projection cannot be created, return the cropped image as fallback
return img_crop;
}
};
let mut part_img = image::RgbImage::new(img_crop_width, img_crop_height);
imageproc::geometric_transformations::warp_into(
&img_crop,
&projection,
Interpolation::Nearest,
image::Rgb([255, 255, 255]),
&mut part_img,
);
// Rotate image if needed
if part_img.height() >= part_img.width() * 3 / 2 {
let mut rotated = image::RgbImage::new(part_img.height(), part_img.width());
for (x, y, pixel) in part_img.enumerate_pixels() {
rotated.put_pixel(y, part_img.width() - 1 - x, *pixel);
}
rotated
} else {
part_img
}
}
pub fn mat_rotate_clock_wise_180(src: &mut image::RgbImage) {
imageops::rotate180_in_place(src);
}
/// Compute mean of f32 image values where mask > 0.
///
/// Uses raw slice access instead of per-pixel get_pixel() for better
/// cache behavior and to enable auto-vectorization of the reduction.
pub fn calculate_mean_with_mask(
img: &image::ImageBuffer<image::Luma<f32>, Vec<f32>>,
mask: &image::ImageBuffer<image::Luma<u8>, Vec<u8>>,
) -> f32 {
assert_eq!(img.width(), mask.width());
assert_eq!(img.height(), mask.height());
let img_raw = img.as_raw();
let mask_raw = mask.as_raw();
let mut sum: f32 = 0.0;
let mut count: u32 = 0;
for (px, &m) in img_raw.iter().zip(mask_raw.iter()) {
if m > 0 {
sum += *px;
count += 1;
}
}
if count == 0 { 0.0 } else { sum / count as f32 }
}
}

View File

@@ -0,0 +1,69 @@
#[derive(Debug)]
pub struct ScaleParam {
pub src_width: u32,
pub src_height: u32,
pub dst_width: u32,
pub dst_height: u32,
pub scale_width: f32,
pub scale_height: f32,
}
impl ScaleParam {
pub fn new(
src_width: u32,
src_height: u32,
dst_width: u32,
dst_height: u32,
scale_width: f32,
scale_height: f32,
) -> Self {
Self {
src_width,
src_height,
dst_width,
dst_height,
scale_width,
scale_height,
}
}
pub fn get_scale_param(src: &image::RgbImage, target_size: u32) -> Self {
let src_width = src.width();
let src_height = src.height();
let mut dst_width;
let mut dst_height;
let ratio: f32 = if src_width > src_height {
target_size as f32 / src_width as f32
} else {
target_size as f32 / src_height as f32
};
dst_width = (src_width as f32 * ratio) as u32;
dst_height = (src_height as f32 * ratio) as u32;
if dst_width % 32 != 0 {
dst_width = (dst_width / 32) * 32;
dst_width = dst_width.max(32);
}
if dst_height % 32 != 0 {
dst_height = (dst_height / 32) * 32;
dst_height = dst_height.max(32);
}
let scale_width = dst_width as f32 / src_width as f32;
let scale_height = dst_height as f32 / src_height as f32;
Self::new(src_width, src_height, dst_width, dst_height, scale_width, scale_height)
}
}
impl std::fmt::Display for ScaleParam {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"src_width:{},src_height:{},dst_width:{},dst_height:{},scale_width:{},scale_height:{}",
self.src_width, self.src_height, self.dst_width, self.dst_height, self.scale_width, self.scale_height
)
}
}

View File

@@ -0,0 +1,436 @@
//! Diagnostic test to trace PaddleOCR detection pipeline.
//!
//! This test isolates each step to determine where empty results originate.
//! Since this crate doesn't have PNG/image decoder features, we create test
//! images programmatically.
use std::path::PathBuf;
fn get_workspace_root() -> PathBuf {
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
manifest_dir.parent().unwrap().parent().unwrap().to_path_buf()
}
fn get_model_dir() -> PathBuf {
get_workspace_root().join(".kreuzberg/paddle-ocr")
}
/// Create a simple test image with black text "HELLO" on white background.
/// This avoids needing PNG decoder features.
fn create_test_image() -> image::RgbImage {
let width = 200u32;
let height = 100u32;
let mut img = image::RgbImage::from_pixel(width, height, image::Rgb([255, 255, 255]));
// Draw a thick black rectangle to simulate text (a simple "block" pattern)
// This ensures the detection model has SOMETHING to detect
let black = image::Rgb([0, 0, 0]);
// Draw "H" shape (x: 20-60, y: 20-80)
for y in 20..80 {
img.put_pixel(20, y, black);
img.put_pixel(21, y, black);
img.put_pixel(22, y, black);
}
for y in 20..80 {
img.put_pixel(55, y, black);
img.put_pixel(56, y, black);
img.put_pixel(57, y, black);
}
for x in 20..58 {
img.put_pixel(x, 48, black);
img.put_pixel(x, 49, black);
img.put_pixel(x, 50, black);
}
// Draw thick solid block to be very obvious (x: 80-180, y: 30-70)
for y in 30..70 {
for x in 80..180 {
img.put_pixel(x, y, black);
}
}
img
}
#[test]
fn diagnostic_detection_pipeline() {
let model_dir = get_model_dir();
if !model_dir.join("det/model.onnx").exists() {
eprintln!("SKIP: Models not downloaded at {:?}", model_dir);
return;
}
// Discover ORT library
discover_ort();
eprintln!("=== PaddleOCR Diagnostic Test ===");
eprintln!("Model dir: {:?}", model_dir);
// Step 1: Create test image
let img = create_test_image();
eprintln!("Step 1 - Test image created: {}x{}", img.width(), img.height());
// Step 2: Initialize OcrLite
let mut ocr_lite = kreuzberg_paddle_ocr::OcrLite::new();
let det_path = model_dir.join("det/model.onnx");
let cls_path = model_dir.join("cls/model.onnx");
let rec_path = model_dir.join("rec/model.onnx");
let init_result = ocr_lite.init_models(
det_path.to_str().unwrap(),
cls_path.to_str().unwrap(),
rec_path.to_str().unwrap(),
1,
);
match &init_result {
Ok(()) => eprintln!("Step 2 - Models initialized successfully"),
Err(e) => {
eprintln!("Step 2 - FAILED to init models: {:?}", e);
panic!("Model initialization failed: {:?}", e);
}
}
// Step 3: Run detection with various parameter sets
let test_cases = vec![
("A: Default params", 50u32, 960u32, 0.3f32, 0.5f32, 1.6f32, true, false),
("B: Very low thresholds", 50, 960, 0.01, 0.01, 1.6, false, false),
("C: No padding + low", 0, 960, 0.01, 0.01, 1.6, false, false),
("D: Higher unclip ratio", 50, 960, 0.1, 0.1, 3.0, false, false),
("E: No padding + medium", 0, 960, 0.1, 0.3, 2.0, false, false),
];
let mut any_detected = false;
for (name, padding, max_side, box_score, box_thresh, unclip, do_angle, most_angle) in &test_cases {
eprintln!("\n--- Test {} ---", name);
eprintln!(
" padding={}, max_side={}, box_score={}, box_thresh={}, unclip={}",
padding, max_side, box_score, box_thresh, unclip
);
let result = ocr_lite.detect(
&img,
*padding,
*max_side,
*box_score,
*box_thresh,
*unclip,
*do_angle,
*most_angle,
);
match &result {
Ok(ocr_result) => {
eprintln!(" Result: {} text blocks", ocr_result.text_blocks.len());
for (i, block) in ocr_result.text_blocks.iter().enumerate() {
eprintln!(
" Block {}: text='{}', text_score={:.3}, box_score={:.3}",
i, block.text, block.text_score, block.box_score
);
any_detected = true;
}
}
Err(e) => {
eprintln!(" FAILED: {:?}", e);
}
}
}
eprintln!("\n=== Diagnosis ===");
if !any_detected {
eprintln!("RESULT: Detection model produces NO output regardless of thresholds.");
eprintln!("This strongly suggests an ORT version compatibility issue.");
eprintln!(" ort crate version: check Cargo.lock for current version");
eprintln!(" ORT_DYLIB_PATH: {:?}", std::env::var("ORT_DYLIB_PATH"));
} else {
eprintln!("RESULT: Detection works. Issue may be threshold-related or image-specific.");
}
}
/// Also test with raw ONNX inference to check if ORT works at all.
#[test]
fn diagnostic_raw_ort_inference() {
let model_dir = get_model_dir();
let det_model = model_dir.join("det/model.onnx");
if !det_model.exists() {
eprintln!("SKIP: Detection model not found at {:?}", det_model);
return;
}
discover_ort();
eprintln!("=== Raw ORT Inference Test ===");
// Load model directly via ort
use ort::session::Session;
let mut session = Session::builder().unwrap().commit_from_file(&det_model).unwrap();
eprintln!("Model loaded successfully");
eprintln!("Inputs:");
for input in session.inputs() {
eprintln!(" name='{}'", input.name());
}
eprintln!("Outputs:");
for output in session.outputs() {
eprintln!(" name='{}'", output.name());
}
// Create a small 32x32 test tensor (NCHW format: batch=1, channels=3, h=32, w=32)
let input_data: Vec<f32> = vec![0.5; 3 * 32 * 32];
let tensor =
ort::value::Tensor::from_array(ndarray::Array::from_shape_vec((1, 3, 32, 32), input_data).unwrap()).unwrap();
let input_name = session.inputs()[0].name().to_string();
eprintln!("\nRunning inference with 32x32 gray image...");
let outputs = session.run(ort::inputs![input_name => tensor]).unwrap();
// Check output
let (output_name, output_value) = outputs.iter().next().unwrap();
eprintln!("Output name: {}", output_name);
let output_tensor = output_value.try_extract_tensor::<f32>().unwrap();
let output_shape = output_tensor.0;
let output_data = output_tensor.1;
eprintln!("Output shape: {:?}", output_shape);
eprintln!("Output len: {}", output_data.len());
if !output_data.is_empty() {
let min = output_data.iter().cloned().fold(f32::INFINITY, f32::min);
let max = output_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let sum: f32 = output_data.iter().sum();
let mean = sum / output_data.len() as f32;
let non_zero = output_data.iter().filter(|&&v| v > 0.001).count();
eprintln!("Output stats: min={:.6}, max={:.6}, mean={:.6}", min, max, mean);
eprintln!("Non-zero values (>0.001): {} / {}", non_zero, output_data.len());
if max < 0.001 {
eprintln!("\nDIAGNOSIS: Model outputs are essentially all zeros.");
eprintln!("This confirms an ORT compatibility issue - model isn't executing correctly.");
} else {
eprintln!("\nDIAGNOSIS: Model produces non-zero output. ORT is working.");
}
}
}
/// Diagnostic: test the CRNN recognition model directly.
#[test]
fn diagnostic_crnn_model_output() {
let model_dir = get_model_dir();
let rec_model = model_dir.join("rec/model.onnx");
if !rec_model.exists() {
eprintln!("SKIP: Recognition model not found");
return;
}
discover_ort();
eprintln!("=== CRNN Recognition Model Diagnostic ===");
use ort::session::Session;
let mut session = Session::builder().unwrap().commit_from_file(&rec_model).unwrap();
eprintln!("Model loaded successfully");
eprintln!("Inputs:");
for input in session.inputs() {
eprintln!(" name='{}'", input.name());
}
eprintln!("Outputs:");
for output in session.outputs() {
eprintln!(" name='{}'", output.name());
}
// Check metadata for character list
{
let metadata = session.metadata().unwrap();
// Check all metadata custom keys
eprintln!("Model metadata:");
eprintln!(" description: {:?}", metadata.description());
eprintln!(" producer: {:?}", metadata.producer());
// Try to get the character key
match metadata.custom("character") {
Some(chars) => {
let bytes = chars.as_bytes();
let char_count = chars.split('\n').count();
eprintln!(
" custom('character'): len={}, bytes={}, split_count={}",
chars.len(),
bytes.len(),
char_count
);
if chars.len() < 500 {
eprintln!(" value: {:?}", chars);
} else {
let preview: String = chars.chars().take(100).collect();
eprintln!(" preview (first 100 chars): {:?}", preview);
}
// Check for null bytes or other encoding issues
let null_count = bytes.iter().filter(|&&b| b == 0).count();
if null_count > 0 {
eprintln!(" WARNING: {} null bytes found in character string!", null_count);
}
}
None => {
eprintln!(" ERROR: No 'character' key in model metadata!");
}
}
// Try other possible metadata keys
for key in [
"character",
"characters",
"dict",
"dictionary",
"labels",
"vocab",
"alphabet",
] {
if let Some(val) = metadata.custom(key) {
eprintln!(
" custom('{}'): len={}, preview={:?}",
key,
val.len(),
&val[..val.len().min(80)]
);
}
}
} // metadata dropped here
// Test 1: Run inference with a simple input (height=48, width=200)
// CRNN expects NCHW: [1, 3, 48, width]
let h = 48usize;
let w = 200usize;
// Create a pattern that looks like text (alternating black/white vertical stripes)
let mut input_data: Vec<f32> = vec![0.0; 3 * h * w];
for c in 0..3 {
for y in 10..38 {
for x in (20..180).step_by(2) {
input_data[c * h * w + y * w + x] = -1.0; // normalized black
}
}
}
let tensor =
ort::value::Tensor::from_array(ndarray::Array::from_shape_vec((1, 3, h, w), input_data).unwrap()).unwrap();
let input_name = session.inputs()[0].name().to_string();
eprintln!("\nRunning CRNN with striped pattern (48x200)...");
let outputs = session.run(ort::inputs![input_name => tensor]).unwrap();
let (_, output_value) = outputs.iter().next().unwrap();
let (shape, data) = output_value.try_extract_tensor::<f32>().unwrap();
eprintln!("Output shape: {:?}", shape);
eprintln!("Output total values: {}", data.len());
if shape.len() >= 3 {
let time_steps = shape[1] as usize;
let vocab_size = shape[2] as usize;
eprintln!("Time steps: {}, Vocabulary size: {}", time_steps, vocab_size);
// Check if outputs are meaningful
let data_vec: Vec<f32> = data.to_vec();
let min = data_vec.iter().cloned().fold(f32::INFINITY, f32::min);
let max = data_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mean: f32 = data_vec.iter().sum::<f32>() / data_vec.len() as f32;
eprintln!("Overall stats: min={:.6}, max={:.6}, mean={:.6}", min, max, mean);
// Check argmax distribution
let mut argmax_zero_count = 0;
let mut argmax_nonzero_count = 0;
for t in 0..time_steps {
let start = t * vocab_size;
let end = start + vocab_size;
let slice = &data_vec[start..end.min(data_vec.len())];
let (max_idx, max_val) =
slice.iter().enumerate().fold(
(0, f32::MIN),
|(mi, mv), (i, &v)| if v > mv { (i, v) } else { (mi, mv) },
);
if max_idx == 0 {
argmax_zero_count += 1;
} else {
argmax_nonzero_count += 1;
}
if t < 5 || (t > time_steps - 3) {
eprintln!(" Step {}: argmax={}, max_val={:.4}", t, max_idx, max_val);
} else if t == 5 {
eprintln!(" ... (skipping middle steps)");
}
}
eprintln!(
"\nArgmax distribution: {} blank (idx=0), {} non-blank",
argmax_zero_count, argmax_nonzero_count
);
if argmax_nonzero_count == 0 {
eprintln!("\nDIAGNOSIS: CRNN model outputs all blanks.");
eprintln!("Possible causes:");
eprintln!(" 1. ORT version incompatibility with CRNN model");
eprintln!(" 2. Model is not executing graph correctly");
eprintln!(" 3. Input normalization mismatch");
} else {
eprintln!("\nDIAGNOSIS: CRNN model produces non-blank output. Recognition works.");
}
}
// Drop outputs before reusing session
drop(outputs);
// Test 2: Run with a uniform white image (should produce all blanks - valid baseline)
let white_data: Vec<f32> = vec![1.0; 3 * h * w];
let white_tensor =
ort::value::Tensor::from_array(ndarray::Array::from_shape_vec((1, 3, h, w), white_data).unwrap()).unwrap();
let input_name2 = session.inputs()[0].name().to_string();
eprintln!("\nRunning CRNN with uniform white (48x200)...");
let white_outputs = session.run(ort::inputs![input_name2 => white_tensor]).unwrap();
let (_, white_val) = white_outputs.iter().next().unwrap();
let (_, white_data_out) = white_val.try_extract_tensor::<f32>().unwrap();
let white_vec: Vec<f32> = white_data_out.to_vec();
let white_max = white_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let white_min = white_vec.iter().cloned().fold(f32::INFINITY, f32::min);
eprintln!("White image output: min={:.6}, max={:.6}", white_min, white_max);
}
fn discover_ort() {
if let Ok(path) = std::env::var("ORT_DYLIB_PATH")
&& std::path::Path::new(&path).exists()
{
eprintln!("ORT found via ORT_DYLIB_PATH: {}", path);
return;
}
let candidates = [
"/opt/homebrew/lib/libonnxruntime.dylib",
"/usr/local/lib/libonnxruntime.dylib",
];
for candidate in &candidates {
if std::path::Path::new(candidate).exists() {
eprintln!("Setting ORT_DYLIB_PATH={}", candidate);
unsafe { std::env::set_var("ORT_DYLIB_PATH", candidate) };
return;
}
}
eprintln!("WARNING: Could not find ORT library!");
}