Nomad changes
All checks were successful
Deploy fil (kreuzberg) / deploy (push) Successful in 49s

This commit is contained in:
Henrik Jess Nielsen
2026-06-01 23:40:55 +02:00
parent 72b1a0a6ed
commit b4c07d3693
5723 changed files with 1130655 additions and 0 deletions

View File

@@ -0,0 +1,110 @@
//! Internal batch mode tracking using tokio task-local storage.
//!
//! This module provides a way to track whether we're in batch processing mode
//! without exposing it in the public API. Extractors check this flag to decide
//! whether to use `spawn_blocking` for CPU-intensive work.
use std::cell::Cell;
use tokio::task_local;
task_local! {
/// Task-local flag indicating batch processing mode.
///
/// When true, extractors use `spawn_blocking` for CPU-intensive work to enable
/// parallelism. When false (single-file mode), extractors run directly to avoid
/// spawn overhead.
static BATCH_MODE: Cell<bool>;
}
/// Check if we're currently in batch processing mode.
///
/// Returns `false` if the task-local is not set (single-file mode).
#[cfg(any(
feature = "pdf",
feature = "office",
feature = "excel",
feature = "excel-wasm",
feature = "archives"
))]
pub(crate) fn is_batch_mode() -> bool {
BATCH_MODE.try_with(|cell| cell.get()).unwrap_or(false)
}
/// Run a future with batch mode enabled.
///
/// This sets the task-local BATCH_MODE flag for the duration of the future.
pub(crate) async fn with_batch_mode<F, T>(future: F) -> T
where
F: std::future::Future<Output = T>,
{
BATCH_MODE.scope(Cell::new(true), future).await
}
#[cfg(all(
test,
any(
feature = "pdf",
feature = "office",
feature = "excel",
feature = "excel-wasm",
feature = "archives"
)
))]
mod tests {
use super::*;
#[tokio::test]
async fn test_batch_mode_not_set_by_default() {
let result = is_batch_mode();
assert!(!result, "batch mode should be false by default");
}
#[tokio::test]
async fn test_with_batch_mode_sets_flag() {
let result = with_batch_mode(async { is_batch_mode() }).await;
assert!(result, "batch mode should be true inside with_batch_mode");
}
#[tokio::test]
async fn test_batch_mode_scoped_to_future() {
assert!(!is_batch_mode(), "batch mode should be false before");
with_batch_mode(async {
assert!(is_batch_mode(), "batch mode should be true inside");
})
.await;
assert!(!is_batch_mode(), "batch mode should be false after future completes");
}
#[tokio::test]
async fn test_nested_batch_mode_calls() {
let result = with_batch_mode(async {
let outer = is_batch_mode();
let inner = with_batch_mode(async { is_batch_mode() }).await;
(outer, inner)
})
.await;
assert!(result.0, "outer batch mode should be true");
assert!(result.1, "inner batch mode should be true");
}
#[tokio::test]
async fn test_batch_mode_unaffected_after_with_batch_mode() {
with_batch_mode(async {
assert!(is_batch_mode(), "first call should set batch mode");
})
.await;
assert!(!is_batch_mode(), "batch mode should be false between calls");
with_batch_mode(async {
assert!(is_batch_mode(), "second call should set batch mode");
})
.await;
assert!(!is_batch_mode(), "batch mode should be false after all calls");
}
}

View File

@@ -0,0 +1,315 @@
//! Batch extraction optimizations using object pooling.
//!
//! This module provides optimized batch processing utilities that leverage
//! object pooling to reduce allocations during concurrent extraction of
//! multiple documents.
//!
//! # Performance Impact
//!
//! - Reuses temporary string/buffer allocations across documents
//! - Reduces garbage collection pressure by ~5-10%
//! - Overall throughput improvement of 5-10% for batch operations
//!
//! # Usage
//!
//! The batch extraction functions automatically use pooling internally.
//! For manual control, use `BatchProcessor` to create pools and manage
//! extraction with custom pool sizes.
use crate::utils::pool::{ByteBufferPool, StringBufferPool, create_byte_buffer_pool, create_string_buffer_pool};
use crate::utils::pool_sizing::PoolSizeHint;
use parking_lot::Mutex;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// Configuration for batch processing with pooling optimizations.
#[cfg_attr(alef, alef(skip))]
#[derive(Debug, Clone)]
pub struct BatchProcessorConfig {
/// Maximum number of string buffers to maintain in the pool
pub string_pool_size: usize,
/// Initial capacity for pooled string buffers in bytes
pub string_buffer_capacity: usize,
/// Maximum number of byte buffers to maintain in the pool
pub byte_pool_size: usize,
/// Initial capacity for pooled byte buffers in bytes
pub byte_buffer_capacity: usize,
/// Maximum concurrent extractions (for concurrency control)
pub max_concurrent: Option<usize>,
}
impl Default for BatchProcessorConfig {
fn default() -> Self {
BatchProcessorConfig {
string_pool_size: 10,
string_buffer_capacity: 8192,
byte_pool_size: 10,
byte_buffer_capacity: 65536,
max_concurrent: None,
}
}
}
/// Batch processor that manages object pools for optimized extraction.
///
/// This struct manages the lifecycle of reusable object pools used during
/// batch extraction. Pools are created lazily on first use and reused across
/// all documents processed by this batch processor.
///
/// # Lazy Initialization
///
/// Pools are initialized on demand to reduce memory usage for applications
/// that may not use batch processing immediately or at all.
#[cfg_attr(alef, alef(skip))]
pub struct BatchProcessor {
string_pool: Mutex<Option<Arc<StringBufferPool>>>,
byte_pool: Mutex<Option<Arc<ByteBufferPool>>>,
config: BatchProcessorConfig,
string_pool_initialized: AtomicBool,
byte_pool_initialized: AtomicBool,
}
impl BatchProcessor {
/// Create a new batch processor with default pool configuration.
///
/// # Returns
///
/// A new `BatchProcessor` ready to process documents.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::batch_optimizations::BatchProcessor;
///
/// let processor = BatchProcessor::new();
/// ```
pub fn new() -> Self {
Self::with_config(BatchProcessorConfig::default())
}
/// Create a new batch processor with custom pool configuration.
///
/// Pools are not created immediately but lazily on first access.
///
/// # Arguments
///
/// * `config` - Custom batch processor configuration
///
/// # Returns
///
/// A new `BatchProcessor` configured with the provided settings.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::batch_optimizations::{BatchProcessor, BatchProcessorConfig};
///
/// let mut config = BatchProcessorConfig::default();
/// config.string_pool_size = 20;
/// config.string_buffer_capacity = 16384;
/// let processor = BatchProcessor::with_config(config);
/// ```
pub fn with_config(config: BatchProcessorConfig) -> Self {
BatchProcessor {
string_pool: Mutex::new(None),
byte_pool: Mutex::new(None),
config,
string_pool_initialized: AtomicBool::new(false),
byte_pool_initialized: AtomicBool::new(false),
}
}
/// Create a batch processor with pool sizes optimized for a specific document.
///
/// This method uses a `PoolSizeHint` (derived from file size and MIME type)
/// to create a batch processor with appropriately sized pools. This reduces
/// memory waste by tailoring pool allocation to actual document complexity.
///
/// # Arguments
///
/// * `hint` - Pool sizing hint containing recommended buffer counts and capacities
///
/// # Returns
///
/// A new `BatchProcessor` configured with the hint-based pool sizes
///
/// # Example
///
/// ```ignore
/// use kreuzberg::core::batch_optimizations::BatchProcessor;
/// use kreuzberg::utils::pool_sizing::estimate_pool_size;
///
/// let hint = estimate_pool_size(5_000_000, "application/pdf");
/// let processor = BatchProcessor::with_pool_hint(&hint);
/// ```
pub fn with_pool_hint(hint: &PoolSizeHint) -> Self {
let config = BatchProcessorConfig {
string_pool_size: hint.string_buffer_count,
string_buffer_capacity: hint.string_buffer_capacity,
byte_pool_size: hint.byte_buffer_count,
byte_buffer_capacity: hint.byte_buffer_capacity,
max_concurrent: None,
};
Self::with_config(config)
}
/// Get a reference to the string buffer pool.
///
/// Creates the pool lazily on first access.
/// Useful for custom pooling implementations that need direct pool access.
pub fn string_pool(&self) -> Arc<StringBufferPool> {
if self.string_pool_initialized.load(Ordering::Acquire) {
return Arc::clone(self.string_pool.lock().as_ref().unwrap());
}
let mut pool_opt = self.string_pool.lock();
if pool_opt.is_none() {
let pool = Arc::new(create_string_buffer_pool(
self.config.string_pool_size,
self.config.string_buffer_capacity,
));
*pool_opt = Some(pool);
self.string_pool_initialized.store(true, Ordering::Release);
}
Arc::clone(pool_opt.as_ref().unwrap())
}
/// Get a reference to the byte buffer pool.
///
/// Creates the pool lazily on first access.
/// Useful for custom pooling implementations that need direct pool access.
pub fn byte_pool(&self) -> Arc<ByteBufferPool> {
if self.byte_pool_initialized.load(Ordering::Acquire) {
return Arc::clone(self.byte_pool.lock().as_ref().unwrap());
}
let mut pool_opt = self.byte_pool.lock();
if pool_opt.is_none() {
let pool = Arc::new(create_byte_buffer_pool(
self.config.byte_pool_size,
self.config.byte_buffer_capacity,
));
*pool_opt = Some(pool);
self.byte_pool_initialized.store(true, Ordering::Release);
}
Arc::clone(pool_opt.as_ref().unwrap())
}
/// Get the current configuration.
pub fn config(&self) -> &BatchProcessorConfig {
&self.config
}
/// Get the number of pooled string buffers currently available.
pub fn string_pool_size(&self) -> usize {
self.string_pool.lock().as_ref().map(|p| p.size()).unwrap_or(0)
}
/// Get the number of pooled byte buffers currently available.
pub fn byte_pool_size(&self) -> usize {
self.byte_pool.lock().as_ref().map(|p| p.size()).unwrap_or(0)
}
/// Clear all pooled objects, forcing new allocations on next acquire.
///
/// Useful for memory-constrained environments or to reclaim memory
/// after processing large batches.
pub fn clear_pools(&self) {
if let Some(pool) = self.string_pool.lock().as_ref() {
pool.clear();
}
if let Some(pool) = self.byte_pool.lock().as_ref() {
pool.clear();
}
}
}
impl Default for BatchProcessor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_batch_processor_creation() {
let processor = BatchProcessor::new();
assert_eq!(processor.string_pool_size(), 0);
assert_eq!(processor.byte_pool_size(), 0);
}
#[test]
fn test_batch_processor_with_config() {
let config = BatchProcessorConfig {
string_pool_size: 5,
string_buffer_capacity: 1024,
byte_pool_size: 3,
byte_buffer_capacity: 4096,
max_concurrent: None,
};
let processor = BatchProcessor::with_config(config);
assert_eq!(processor.config().string_pool_size, 5);
assert_eq!(processor.config().byte_pool_size, 3);
}
#[test]
fn test_batch_processor_string_pool_usage() {
let processor = BatchProcessor::new();
let pool = processor.string_pool();
{
let mut s = pool.acquire();
s.push_str("test");
}
{
let s = pool.acquire();
assert_eq!(s.len(), 0);
}
}
#[test]
fn test_batch_processor_byte_pool_usage() {
let processor = BatchProcessor::new();
let pool = processor.byte_pool();
{
let mut buf = pool.acquire();
buf.extend_from_slice(b"test");
}
{
let buf = pool.acquire();
assert_eq!(buf.len(), 0);
}
}
#[test]
fn test_batch_processor_clear_pools() {
let processor = BatchProcessor::new();
let s1 = processor.string_pool().acquire();
let s2 = processor.byte_pool().acquire();
drop(s1);
drop(s2);
assert!(processor.string_pool_size() > 0);
assert!(processor.byte_pool_size() > 0);
processor.clear_pools();
assert_eq!(processor.string_pool_size(), 0);
assert_eq!(processor.byte_pool_size(), 0);
}
}

View File

@@ -0,0 +1,55 @@
//! Acceleration configuration for ONNX Runtime execution providers.
use serde::{Deserialize, Serialize};
/// Hardware acceleration configuration for ONNX Runtime models.
///
/// Controls which execution provider (CPU, CoreML, CUDA, TensorRT) is used
/// for inference in layout detection and embedding generation.
///
/// # Example
///
/// ```rust
/// use kreuzberg::AccelerationConfig;
///
/// // Auto-select: CoreML on macOS, CUDA on Linux, CPU elsewhere
/// let config = AccelerationConfig::default();
///
/// // Force CPU only
/// let config = AccelerationConfig {
/// provider: kreuzberg::ExecutionProviderType::Cpu,
/// ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct AccelerationConfig {
/// Execution provider to use for ONNX inference.
#[serde(default)]
pub provider: ExecutionProviderType,
/// GPU device ID (for CUDA/TensorRT). Ignored for CPU/CoreML/Auto.
#[serde(default)]
pub device_id: u32,
}
/// ONNX Runtime execution provider type.
///
/// Determines which hardware backend is used for model inference.
/// `Auto` (default) selects the best available provider per platform.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ExecutionProviderType {
/// Auto-select: CoreML on macOS, CUDA on Linux, CPU elsewhere.
#[default]
Auto,
/// CPU execution provider (always available).
Cpu,
/// Apple CoreML (macOS/iOS Neural Engine + GPU).
#[serde(alias = "coreml")]
CoreMl,
/// NVIDIA CUDA GPU acceleration.
Cuda,
/// NVIDIA TensorRT (optimized CUDA inference).
#[serde(alias = "tensorrt")]
TensorRt,
}

View File

@@ -0,0 +1,140 @@
//! Concurrency and thread pool configuration.
use std::sync::Once;
use serde::{Deserialize, Serialize};
/// Controls thread usage for constrained environments.
///
/// Set `max_threads` to cap all internal thread pools (Rayon, ONNX Runtime
/// intra-op) and batch concurrency to a single limit.
///
/// # Example
///
/// ```rust
/// use kreuzberg::core::config::ConcurrencyConfig;
///
/// let config = ConcurrencyConfig {
/// max_threads: Some(2),
/// };
/// ```
#[cfg_attr(alef, alef(skip))]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct ConcurrencyConfig {
/// Maximum number of threads for all internal thread pools.
///
/// Caps Rayon global pool size, ONNX Runtime intra-op threads, and
/// (when `max_concurrent_extractions` is unset) the batch concurrency
/// semaphore. When `None`, system defaults are used.
pub max_threads: Option<usize>,
}
static POOL_INIT: Once = Once::new();
/// Resolve the effective thread budget from config or auto-detection.
///
/// User-set `max_threads` takes priority. Otherwise auto-detects from `num_cpus`,
/// capped at 8 for sane defaults in serverless environments.
///
/// # Example
///
/// ```ignore
/// use kreuzberg::core::config::ConcurrencyConfig;
/// use kreuzberg::core::config::concurrency::resolve_thread_budget;
///
/// let config = ConcurrencyConfig { max_threads: Some(4) };
/// assert_eq!(resolve_thread_budget(Some(&config)), 4);
/// assert!(resolve_thread_budget(None) >= 1);
/// ```
pub(crate) fn resolve_thread_budget(config: Option<&ConcurrencyConfig>) -> usize {
if let Some(n) = config.and_then(|c| c.max_threads) {
return n.max(1);
}
num_cpus::get().min(8)
}
/// Initialize the global Rayon thread pool with the given budget.
///
/// Safe to call multiple times — only the first call takes effect (subsequent
/// calls are silently ignored).
///
/// # Example
///
/// ```ignore
/// use kreuzberg::core::config::concurrency::init_thread_pools;
///
/// init_thread_pools(4);
/// init_thread_pools(2); // no-op: pool already initialized
/// ```
pub(crate) fn init_thread_pools(budget: usize) {
POOL_INIT.call_once(|| {
#[cfg(not(target_arch = "wasm32"))]
rayon::ThreadPoolBuilder::new().num_threads(budget).build_global().ok();
#[cfg(target_arch = "wasm32")]
let _ = budget;
});
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_resolve_thread_budget_none() {
let budget = resolve_thread_budget(None);
assert!(budget >= 1);
assert!(budget <= 8);
}
#[test]
fn test_resolve_thread_budget_with_config() {
let config = ConcurrencyConfig { max_threads: Some(4) };
assert_eq!(resolve_thread_budget(Some(&config)), 4);
}
#[test]
fn test_resolve_thread_budget_clamps_to_one() {
let config = ConcurrencyConfig { max_threads: Some(0) };
assert_eq!(resolve_thread_budget(Some(&config)), 1);
}
#[test]
fn test_resolve_thread_budget_no_max() {
let config = ConcurrencyConfig { max_threads: None };
let budget = resolve_thread_budget(Some(&config));
assert!(budget >= 1);
assert!(budget <= 8);
}
#[test]
fn test_init_thread_pools_idempotent() {
// Should not panic when called multiple times.
init_thread_pools(2);
init_thread_pools(4);
}
#[test]
fn test_default() {
let config = ConcurrencyConfig::default();
assert!(config.max_threads.is_none());
}
#[test]
fn test_serde_roundtrip() {
let json = r#"{"max_threads": 2}"#;
let config: ConcurrencyConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.max_threads, Some(2));
let serialized = serde_json::to_string(&config).unwrap();
let roundtripped: ConcurrencyConfig = serde_json::from_str(&serialized).unwrap();
assert_eq!(roundtripped.max_threads, Some(2));
}
#[test]
fn test_serde_empty() {
let json = r#"{}"#;
let config: ConcurrencyConfig = serde_json::from_str(json).unwrap();
assert!(config.max_threads.is_none());
}
}

View File

@@ -0,0 +1,80 @@
//! Cross-extractor content filtering configuration.
use serde::{Deserialize, Serialize};
fn default_true() -> bool {
true
}
/// Cross-extractor content filtering configuration.
///
/// Controls whether "furniture" content (headers, footers, page numbers,
/// watermarks, repeating text) is included in or stripped from extraction
/// results. Applies across all extractors (PDF, DOCX, RTF, ODT, HTML, etc.)
/// with format-specific implementation.
///
/// When `None` on `ExtractionConfig`, each extractor uses its current
/// default behavior unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentFilterConfig {
/// Include running headers in extraction output.
///
/// - PDF: Disables top-margin furniture stripping and prevents the layout
/// model from treating `PageHeader`-classified regions as furniture.
/// - DOCX: Includes document headers in text output.
/// - RTF/ODT: Headers already included; this is a no-op when true.
/// - HTML/EPUB: Keeps `<header>` element content.
///
/// Default: `false` (headers are stripped or excluded).
#[serde(default)]
pub include_headers: bool,
/// Include running footers in extraction output.
///
/// - PDF: Disables bottom-margin furniture stripping and prevents the layout
/// model from treating `PageFooter`-classified regions as furniture.
/// - DOCX: Includes document footers in text output.
/// - RTF/ODT: Footers already included; this is a no-op when true.
/// - HTML/EPUB: Keeps `<footer>` element content.
///
/// Default: `false` (footers are stripped or excluded).
#[serde(default)]
pub include_footers: bool,
/// Enable the heuristic cross-page repeating text detector.
///
/// When `true` (default), text that repeats verbatim across a supermajority
/// of pages is classified as furniture and stripped. Disable this if brand
/// names or repeated headings are being incorrectly removed by the heuristic.
///
/// Note: when a layout-detection model is active, the model may independently
/// classify page-header / page-footer regions as furniture on a per-page basis.
/// To preserve those regions, set `include_headers = true`, `include_footers = true`,
/// or both, in addition to disabling this flag.
///
/// Primarily affects PDF extraction.
///
/// Default: `true`.
#[serde(default = "default_true")]
pub strip_repeating_text: bool,
/// Include watermark text in extraction output.
///
/// - PDF: Keeps watermark artifacts and arXiv identifiers.
/// - Other formats: No effect currently.
///
/// Default: `false` (watermarks are stripped).
#[serde(default)]
pub include_watermarks: bool,
}
impl Default for ContentFilterConfig {
fn default() -> Self {
Self {
include_headers: false,
include_footers: false,
strip_repeating_text: true,
include_watermarks: false,
}
}
}

View File

@@ -0,0 +1,58 @@
//! Email extraction configuration.
use serde::{Deserialize, Serialize};
/// Configuration for email extraction.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct EmailConfig {
/// Windows codepage number to use when an MSG file contains no codepage property.
/// Defaults to `None`, which falls back to windows-1252.
///
/// If an unrecognized or invalid codepage number is supplied (including 0),
/// the behavior silently falls back to windows-1252 — the same as when the
/// MSG file itself contains an unrecognized codepage. No error or warning is
/// emitted. Users should verify output when supplying unusual values.
///
/// Common values:
/// - 1250: Central European (Polish, Czech, Hungarian, etc.)
/// - 1251: Cyrillic (Russian, Ukrainian, Bulgarian, etc.)
/// - 1252: Western European (default)
/// - 1253: Greek
/// - 1254: Turkish
/// - 1255: Hebrew
/// - 1256: Arabic
/// - 932: Japanese (Shift-JIS)
/// - 936: Simplified Chinese (GBK)
pub msg_fallback_codepage: Option<u32>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_email_config_default() {
let config = EmailConfig::default();
assert!(config.msg_fallback_codepage.is_none());
}
#[test]
fn test_email_config_serde_roundtrip() {
let json = r#"{"msg_fallback_codepage": 1251}"#;
let config: EmailConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.msg_fallback_codepage, Some(1251));
let serialized = serde_json::to_string(&config).unwrap();
let roundtripped: EmailConfig = serde_json::from_str(&serialized).unwrap();
assert_eq!(roundtripped.msg_fallback_codepage, Some(1251));
}
#[test]
fn test_email_config_serde_default_omitted() {
let json = r#"{}"#;
let config: EmailConfig = serde_json::from_str(json).unwrap();
assert!(config.msg_fallback_codepage.is_none());
}
}

View File

@@ -0,0 +1,728 @@
//! Main extraction configuration struct.
//!
//! This module contains the main `ExtractionConfig` struct that aggregates all
//! configuration options for the extraction process.
use serde::{Deserialize, Serialize};
use super::super::acceleration::AccelerationConfig;
use super::super::content_filter::ContentFilterConfig;
use super::super::formats::OutputFormat;
use super::super::ocr::OcrConfig;
use super::super::page::PageConfig;
use super::super::processing::{ChunkingConfig, PostProcessorConfig};
use super::file_config::FileExtractionConfig;
use super::types::{ImageExtractionConfig, LanguageDetectionConfig, TokenReductionOptions};
/// Main extraction configuration.
///
/// This struct contains all configuration options for the extraction process.
/// It can be loaded from TOML, YAML, or JSON files, or created programmatically.
///
/// # Example
///
/// ```rust
/// use kreuzberg::core::config::ExtractionConfig;
///
/// // Create with defaults
/// let config = ExtractionConfig::default();
///
/// // Load from TOML file
/// // let config = ExtractionConfig::from_toml_file("kreuzberg.toml")?;
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ExtractionConfig {
/// Enable caching of extraction results
#[serde(default = "default_true")]
pub use_cache: bool,
/// Enable quality post-processing
#[serde(default = "default_true")]
pub enable_quality_processing: bool,
/// OCR configuration (None = OCR disabled)
#[serde(default)]
pub ocr: Option<OcrConfig>,
/// Force OCR even for searchable PDFs
#[serde(default)]
pub force_ocr: bool,
/// Force OCR on specific pages only (1-indexed page numbers, must be >= 1).
///
/// When set, only the listed pages are OCR'd regardless of text layer quality.
/// Unlisted pages use native text extraction. Ignored when `force_ocr` is `true`.
/// Only applies to PDF documents. Duplicates are automatically deduplicated.
/// An `ocr` config is recommended for backend/language selection; defaults are used if absent.
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub force_ocr_pages: Option<Vec<u32>>,
/// Disable OCR entirely, even for images.
///
/// When `true`, OCR is skipped for all document types. Images return metadata
/// only (dimensions, format, EXIF) without text extraction. PDFs use only
/// native text extraction without OCR fallback.
///
/// Cannot be `true` simultaneously with `force_ocr`.
///
/// *Added in v4.7.0.*
#[serde(default)]
pub disable_ocr: bool,
/// Text chunking configuration (None = chunking disabled)
#[serde(default)]
pub chunking: Option<ChunkingConfig>,
/// Content filtering configuration (None = use extractor defaults).
///
/// Controls whether document "furniture" (headers, footers, watermarks,
/// repeating text) is included in or stripped from extraction results.
/// See [`ContentFilterConfig`] for per-field documentation.
#[serde(default)]
pub content_filter: Option<ContentFilterConfig>,
/// Image extraction configuration (None = no image extraction)
#[serde(default)]
pub images: Option<ImageExtractionConfig>,
/// PDF-specific options (None = use defaults)
#[cfg(feature = "pdf")]
#[serde(default)]
pub pdf_options: Option<super::super::pdf::PdfConfig>,
/// Token reduction configuration (None = no token reduction)
#[serde(default)]
pub token_reduction: Option<TokenReductionOptions>,
/// Language detection configuration (None = no language detection)
#[serde(default)]
pub language_detection: Option<LanguageDetectionConfig>,
/// Page extraction configuration (None = no page tracking)
#[serde(default)]
pub pages: Option<PageConfig>,
/// Keyword extraction configuration (None = no keyword extraction)
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
#[serde(default)]
pub keywords: Option<crate::keywords::KeywordConfig>,
/// Post-processor configuration (None = use defaults)
#[serde(default)]
pub postprocessor: Option<PostProcessorConfig>,
/// HTML to Markdown conversion options (None = use defaults)
///
/// Configure how HTML documents are converted to Markdown, including heading styles,
/// list formatting, code block styles, and preprocessing options.
#[cfg(feature = "html")]
#[serde(default)]
pub html_options: Option<html_to_markdown_rs::ConversionOptions>,
/// Styled HTML output configuration.
///
/// When set alongside `output_format = OutputFormat::Html`, the extraction
/// pipeline uses [`StyledHtmlRenderer`](crate::rendering::StyledHtmlRenderer)
/// which emits stable `kb-*` CSS class hooks on every structural element
/// and optionally embeds theme CSS or user-supplied CSS in a `<style>` block.
///
/// When `None`, the existing plain comrak-based HTML renderer is used.
#[cfg(feature = "html")]
#[serde(default)]
pub html_output: Option<crate::core::config::html_output::HtmlOutputConfig>,
/// Default per-file timeout in seconds for batch extraction.
///
/// When set, each file in a batch will be canceled after this duration
/// unless overridden by [`FileExtractionConfig::timeout_secs`].
///
/// Defaults to `Some(60)` to prevent pathological files (e.g. deeply
/// nested archives, documents with millions of cells) from running
/// indefinitely and exhausting caller resources. Set to `None` to
/// disable the timeout for trusted input or long-running workloads.
#[serde(default = "default_extraction_timeout")]
pub extraction_timeout_secs: Option<u64>,
/// Maximum concurrent extractions in batch operations (None = (num_cpus × 1.5).ceil()).
///
/// Limits parallelism to prevent resource exhaustion when processing
/// large batches. Defaults to (num_cpus × 1.5).ceil() when not set.
#[serde(default)]
pub max_concurrent_extractions: Option<usize>,
/// Result structure format
///
/// Controls whether results are returned in unified format (default) with all
/// content in the `content` field, or element-based format with semantic
/// elements (for Unstructured-compatible output).
#[serde(default)]
pub result_format: crate::types::ResultFormat,
/// Security limits for archive extraction.
///
/// Controls maximum archive size, compression ratio, file count, and other
/// security thresholds to prevent decompression bomb attacks. Also caps
/// nesting depth, iteration count, entity / token length, total
/// content size, and table cell count for every extraction path that
/// ingests user-controlled bytes.
/// When `None`, default limits are used.
#[serde(default)]
pub security_limits: Option<crate::extractors::security::SecurityLimits>,
/// Maximum uncompressed size in bytes for a single embedded file before
/// recursive extraction is attempted (default: 50 MiB).
///
/// Applies to embedded objects inside OOXML containers (DOCX, PPTX) and
/// to email attachments processed via recursive extraction. Files that
/// exceed this limit are skipped with a `ProcessingWarning` rather than
/// passed to the extraction pipeline, preventing a single oversized
/// embedded object from consuming unbounded memory or time.
///
/// Set to `None` to disable the per-embedded-file cap (falls back to
/// `security_limits.max_archive_size` as the only guard).
#[serde(default = "default_max_embedded_file_bytes")]
pub max_embedded_file_bytes: Option<u64>,
/// Content text format (default: Plain).
///
/// Controls the format of the extracted content:
/// - `Plain`: Raw extracted text (default)
/// - `Markdown`: Markdown formatted output
/// - `Djot`: Djot markup format (requires djot feature)
/// - `Html`: HTML formatted output
///
/// When set to a structured format, extraction results will include
/// formatted output. The `formatted_content` field may be populated
/// when format conversion is applied.
#[serde(default)]
pub output_format: OutputFormat,
/// Layout detection configuration (None = layout detection disabled).
///
/// When set, PDF pages and images are analyzed for document structure
/// (headings, code, formulas, tables, figures, etc.) using RT-DETR models
/// via ONNX Runtime. For PDFs, layout hints override paragraph classification
/// in the markdown pipeline. For images, per-region OCR is performed with
/// markdown formatting based on detected layout classes.
/// Requires the `layout-detection` feature to run inference; the field is
/// present whenever the `layout-types` feature is active (which includes
/// `layout-detection` as well as the no-ORT target groups).
#[cfg(feature = "layout-types")]
#[serde(default)]
pub layout: Option<super::super::layout::LayoutDetectionConfig>,
/// Run layout detection on the non-OCR PDF markdown path.
///
/// When `true` and `layout` is `Some(_)`, layout regions inform heading,
/// table, list, and figure detection in the structure pipeline that would
/// otherwise rely on font-clustering heuristics alone. Significantly
/// improves SF1 (structural F1) at the cost of inference latency
/// (~150-300ms/page CPU, ~20-50ms/page GPU). Default: `false`.
/// Requires the `layout-detection` feature.
#[serde(default)]
pub use_layout_for_markdown: bool,
/// Enable structured document tree output.
///
/// When true, populates the `document` field on `ExtractionResult` with a
/// hierarchical `DocumentStructure` containing heading-driven section nesting,
/// table grids, content layer classification, and inline annotations.
///
/// Independent of `result_format` — can be combined with Unified or ElementBased.
#[serde(default)]
pub include_document_structure: bool,
/// Hardware acceleration configuration for ONNX Runtime models.
///
/// Controls execution provider selection for layout detection and embedding
/// models. When `None`, uses platform defaults (CoreML on macOS, CUDA on
/// Linux, CPU on Windows).
#[serde(default)]
pub acceleration: Option<AccelerationConfig>,
/// Cache namespace for tenant isolation.
///
/// When set, cache entries are stored under `{cache_dir}/{namespace}/`.
/// Must be alphanumeric, hyphens, or underscores only (max 64 chars).
/// Different namespaces have isolated cache spaces on the same filesystem.
#[serde(default)]
pub cache_namespace: Option<String>,
/// Per-request cache TTL in seconds.
///
/// Overrides the global `max_age_days` for this specific extraction.
/// When `0`, caching is completely skipped (no read or write).
/// When `None`, the global TTL applies.
#[serde(default)]
pub cache_ttl_secs: Option<u64>,
/// Email extraction configuration (None = use defaults).
///
/// Currently supports configuring the fallback codepage for MSG files
/// that do not specify one. See [`crate::core::config::EmailConfig`] for details.
#[serde(default)]
pub email: Option<super::super::email::EmailConfig>,
/// Concurrency limits for constrained environments (None = use defaults).
///
/// Controls Rayon thread pool size, ONNX Runtime intra-op threads, and
/// (when `max_concurrent_extractions` is unset) the batch concurrency
/// semaphore. See [`crate::core::config::ConcurrencyConfig`] for details.
#[serde(default)]
pub concurrency: Option<super::super::concurrency::ConcurrencyConfig>,
/// Maximum recursion depth for archive extraction (default: 3).
/// Set to 0 to disable recursive extraction (legacy behavior).
#[serde(default = "default_archive_depth")]
pub max_archive_depth: usize,
/// Tree-sitter language pack configuration (None = tree-sitter disabled).
///
/// When set, enables code file extraction using tree-sitter parsers.
/// Controls grammar download behavior and code analysis options.
#[cfg(feature = "tree-sitter")]
#[serde(default)]
pub tree_sitter: Option<super::super::tree_sitter::TreeSitterConfig>,
/// Structured extraction via LLM (None = disabled).
///
/// When set, the extracted document content is sent to an LLM with the
/// provided JSON schema. The structured response is stored in
/// `ExtractionResult::structured_output`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub structured_extraction: Option<super::super::llm::StructuredExtractionConfig>,
/// Cancellation token for this extraction (None = no external cancellation).
///
/// Pass a [`CancellationToken`] clone here and call [`CancellationToken::cancel`]
/// from another thread / task to abort the extraction in progress. The extractor
/// checks the token at safe checkpoints (before lock acquisition, between pages,
/// between batch items) and returns [`KreuzbergError::Cancelled`] when set.
///
/// The field is excluded from serialization because `CancellationToken` is a
/// runtime handle, not a configuration value.
#[serde(skip)]
pub cancel_token: Option<crate::cancellation::CancellationToken>,
}
impl Default for ExtractionConfig {
fn default() -> Self {
Self {
use_cache: true,
enable_quality_processing: true,
ocr: None,
force_ocr: false,
force_ocr_pages: None,
disable_ocr: false,
chunking: None,
content_filter: None,
images: None,
#[cfg(feature = "pdf")]
pdf_options: None,
token_reduction: None,
language_detection: None,
pages: None,
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
keywords: None,
postprocessor: None,
#[cfg(feature = "html")]
html_options: None,
#[cfg(feature = "html")]
html_output: None,
extraction_timeout_secs: default_extraction_timeout(),
max_concurrent_extractions: None,
security_limits: None,
max_embedded_file_bytes: default_max_embedded_file_bytes(),
#[cfg(feature = "layout-types")]
layout: None,
use_layout_for_markdown: false,
result_format: crate::types::ResultFormat::Unified,
output_format: OutputFormat::Plain,
include_document_structure: false,
acceleration: None,
cache_namespace: None,
cache_ttl_secs: None,
email: None,
concurrency: None,
max_archive_depth: default_archive_depth(),
#[cfg(feature = "tree-sitter")]
tree_sitter: None,
structured_extraction: None,
cancel_token: None,
}
}
}
impl ExtractionConfig {
/// Create a new `ExtractionConfig` by applying per-file overrides from a
/// [`FileExtractionConfig`]. Fields that are `Some` in the override replace the
/// corresponding field in `self`; `None` fields keep the original value.
///
/// Batch-level fields (`max_concurrent_extractions`, `use_cache`, `acceleration`,
/// `security_limits`) are never affected by overrides.
///
/// # Example
///
/// ```ignore
/// use kreuzberg::{ExtractionConfig, FileExtractionConfig};
///
/// let base = ExtractionConfig::default();
/// let override_config = FileExtractionConfig {
/// force_ocr: Some(true),
/// ..Default::default()
/// };
/// let resolved = base.with_file_overrides(&override_config);
/// assert!(resolved.force_ocr);
/// ```
pub(crate) fn with_file_overrides(&self, overrides: &FileExtractionConfig) -> Self {
// Destructure to ensure compile-time exhaustiveness: adding a field to
// FileExtractionConfig without handling it here will produce a compile error.
let FileExtractionConfig {
ref enable_quality_processing,
ref ocr,
ref force_ocr,
ref force_ocr_pages,
ref disable_ocr,
ref chunking,
ref content_filter,
ref images,
#[cfg(feature = "pdf")]
ref pdf_options,
ref token_reduction,
ref language_detection,
ref pages,
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
ref keywords,
ref postprocessor,
#[cfg(feature = "html")]
ref html_options,
ref result_format,
ref output_format,
ref include_document_structure,
#[cfg(feature = "layout-types")]
ref layout,
ref timeout_secs,
#[cfg(feature = "tree-sitter")]
ref tree_sitter,
ref structured_extraction,
} = *overrides;
let mut config = self.clone();
if let Some(v) = enable_quality_processing {
config.enable_quality_processing = *v;
}
if let Some(v) = ocr {
config.ocr = Some(v.clone());
}
if let Some(v) = force_ocr {
config.force_ocr = *v;
}
if let Some(v) = force_ocr_pages {
config.force_ocr_pages = Some(v.clone());
}
if let Some(v) = disable_ocr {
config.disable_ocr = *v;
}
if let Some(v) = chunking {
config.chunking = Some(v.clone());
}
if let Some(v) = content_filter {
config.content_filter = Some(v.clone());
}
if let Some(v) = images {
config.images = Some(v.clone());
}
#[cfg(feature = "pdf")]
if let Some(v) = pdf_options {
config.pdf_options = Some(v.clone());
}
if let Some(v) = token_reduction {
config.token_reduction = Some(v.clone());
}
if let Some(v) = language_detection {
config.language_detection = Some(v.clone());
}
if let Some(v) = pages {
config.pages = Some(v.clone());
}
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
if let Some(v) = keywords {
config.keywords = Some(v.clone());
}
if let Some(v) = postprocessor {
config.postprocessor = Some(v.clone());
}
#[cfg(feature = "html")]
if let Some(v) = html_options {
config.html_options = Some(v.clone());
}
if let Some(v) = result_format {
config.result_format = *v;
}
if let Some(v) = output_format {
config.output_format = v.clone();
}
if let Some(v) = include_document_structure {
config.include_document_structure = *v;
}
#[cfg(feature = "layout-types")]
if let Some(v) = layout {
config.layout = Some(v.clone());
}
if let Some(v) = timeout_secs {
config.extraction_timeout_secs = Some(*v);
}
#[cfg(feature = "tree-sitter")]
if let Some(v) = tree_sitter {
config.tree_sitter = Some(v.clone());
}
if let Some(v) = structured_extraction {
config.structured_extraction = Some(v.clone());
}
config
}
/// Normalize configuration for implicit requirements.
///
/// Currently handles:
/// - Auto-enabling `extract_pages` when `result_format` is `ElementBased`, because
/// the element transformation requires per-page data to assign correct page numbers.
/// Without this, all elements would incorrectly get `page_number=1`.
/// - Auto-enabling `extract_pages` when chunking is configured, because the chunker
/// needs page boundaries to assign correct page numbers to chunks.
pub(crate) fn normalized(&self) -> std::borrow::Cow<'_, Self> {
let needs_pages = |cfg: &Self| -> bool {
match &cfg.pages {
Some(page_config) => !page_config.extract_pages,
None => true,
}
};
let needs_pages_for_elements =
self.result_format == crate::types::ResultFormat::ElementBased && needs_pages(self);
let needs_pages_for_chunking = self.chunking.is_some() && needs_pages(self);
if needs_pages_for_elements || needs_pages_for_chunking {
let mut config = self.clone();
let page_config = config.pages.get_or_insert_with(super::super::page::PageConfig::default);
page_config.extract_pages = true;
return std::borrow::Cow::Owned(config);
}
std::borrow::Cow::Borrowed(self)
}
/// Validate the configuration, returning an error if any settings are invalid.
///
/// Checks:
/// Returns the effective disable-OCR value, accounting for both the top-level
/// `disable_ocr` flag and the `ocr.enabled` shorthand on [`OcrConfig`].
///
/// Setting `ocr.enabled = false` in configuration is treated as equivalent to
/// `disable_ocr = true`. This method is the single source of truth for whether
/// OCR should be skipped.
pub(crate) fn effective_disable_ocr(&self) -> bool {
self.disable_ocr || self.ocr.as_ref().is_some_and(|o| !o.enabled)
}
/// Check if image processing is needed by examining OCR and image extraction settings.
///
/// Returns `true` if either OCR is enabled or image extraction is configured,
/// indicating that image decompression and processing should occur.
/// Returns `false` if both are disabled, allowing optimization to skip unnecessary
/// image decompression for text-only extraction workflows.
///
/// # Optimization Impact
/// For text-only extractions (no OCR, no image extraction), skipping image
/// decompression can improve CPU utilization by 5-10% by avoiding wasteful
/// image I/O and processing when results won't be used.
pub fn needs_image_processing(&self) -> bool {
let ocr_enabled = !self.effective_disable_ocr() && (self.ocr.is_some() || self.force_ocr);
let image_extraction_enabled = self.images.as_ref().map(|i| i.extract_images).unwrap_or(false);
#[cfg(feature = "layout-detection")]
let layout_enabled = self.layout.is_some();
#[cfg(not(feature = "layout-detection"))]
let layout_enabled = false;
ocr_enabled || image_extraction_enabled || layout_enabled
}
}
fn default_true() -> bool {
true
}
fn default_archive_depth() -> usize {
3
}
/// Default per-embedded-file cap: 50 MiB.
///
/// A single embedded object larger than this can consume significant memory
/// when the recursive extractor materialises it. 50 MiB is generous for
/// real-world embedded documents while still bounding worst-case allocation.
fn default_max_embedded_file_bytes() -> Option<u64> {
Some(50 * 1024 * 1024)
}
/// Default extraction timeout: 60 seconds.
///
/// Pathological files (deeply nested archives, sheets with millions of cells,
/// adversarial PDFs) can otherwise run indefinitely and exhaust caller
/// resources. 60 s is generous for legitimate documents while bounding the
/// worst-case cost of a single untrusted input.
fn default_extraction_timeout() -> Option<u64> {
Some(60)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::config::OcrConfig;
#[test]
fn test_effective_disable_ocr_from_top_level_flag() {
let config = ExtractionConfig {
disable_ocr: true,
..Default::default()
};
assert!(config.effective_disable_ocr());
}
#[test]
fn test_effective_disable_ocr_from_ocr_enabled_false() {
let config = ExtractionConfig {
ocr: Some(OcrConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
assert!(
config.effective_disable_ocr(),
"ocr.enabled = false should be treated as disable_ocr = true"
);
}
#[test]
fn test_effective_disable_ocr_default_is_false() {
let config = ExtractionConfig::default();
assert!(!config.effective_disable_ocr());
}
#[test]
fn test_effective_disable_ocr_ocr_enabled_true_does_not_disable() {
let config = ExtractionConfig {
ocr: Some(OcrConfig {
enabled: true,
..Default::default()
}),
..Default::default()
};
assert!(!config.effective_disable_ocr());
}
#[test]
fn test_ocr_enabled_false_deserialized_from_json() {
let json = r#"{"ocr": {"enabled": false}}"#;
let config: ExtractionConfig = serde_json::from_str(json).unwrap();
assert!(
config.effective_disable_ocr(),
"JSON ocr.enabled=false should disable OCR"
);
}
#[test]
fn test_ocr_enabled_defaults_to_true() {
let json = r#"{"ocr": {"backend": "tesseract"}}"#;
let config: ExtractionConfig = serde_json::from_str(json).unwrap();
assert!(!config.effective_disable_ocr(), "OCR should be enabled by default");
}
#[cfg(feature = "layout-detection")]
#[test]
fn test_use_layout_for_markdown_defaults_to_false() {
let config = ExtractionConfig::default();
assert!(!config.use_layout_for_markdown);
}
#[cfg(feature = "layout-detection")]
#[test]
fn test_use_layout_for_markdown_can_be_set_true() {
let config = ExtractionConfig {
use_layout_for_markdown: true,
..Default::default()
};
assert!(config.use_layout_for_markdown);
}
#[cfg(feature = "layout-detection")]
#[test]
fn test_use_layout_for_markdown_serde_round_trip() {
let config = ExtractionConfig {
use_layout_for_markdown: true,
..Default::default()
};
let json = serde_json::to_string(&config).unwrap();
let deserialized: ExtractionConfig = serde_json::from_str(&json).unwrap();
assert!(deserialized.use_layout_for_markdown);
}
#[cfg(feature = "layout-detection")]
#[test]
fn test_use_layout_for_markdown_serde_default_false() {
// Field absent in JSON → should default to false.
let json = r#"{}"#;
let config: ExtractionConfig = serde_json::from_str(json).unwrap();
assert!(!config.use_layout_for_markdown);
}
// --- extraction_timeout_secs defaults ----------------------------------
#[test]
fn test_default_extraction_timeout_is_sixty_seconds() {
let config = ExtractionConfig::default();
assert_eq!(
config.extraction_timeout_secs,
Some(60),
"default timeout must be Some(60) to prevent unbounded extraction"
);
}
#[test]
fn test_extraction_timeout_can_be_disabled_by_setting_none() {
let config = ExtractionConfig {
extraction_timeout_secs: None,
..Default::default()
};
assert_eq!(config.extraction_timeout_secs, None);
}
#[test]
fn test_extraction_timeout_serde_round_trip() {
let config = ExtractionConfig {
extraction_timeout_secs: Some(120),
..Default::default()
};
let json = serde_json::to_string(&config).unwrap();
let deserialized: ExtractionConfig = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.extraction_timeout_secs, Some(120));
}
#[test]
fn test_extraction_timeout_serde_absent_field_defaults_to_sixty() {
// When the JSON field is absent the serde default function must fire.
let json = r#"{}"#;
let config: ExtractionConfig = serde_json::from_str(json).unwrap();
assert_eq!(
config.extraction_timeout_secs,
Some(60),
"absent field must use default_extraction_timeout() -> Some(60)"
);
}
}

View File

@@ -0,0 +1,493 @@
//! Environment variable override support for extraction configuration.
//!
//! This module provides functionality to apply environment variable overrides
//! to extraction configuration, allowing runtime configuration changes.
use crate::{KreuzbergError, Result};
use super::super::ocr::OcrConfig;
use super::super::processing::ChunkingConfig;
use super::core::ExtractionConfig;
use super::types::TokenReductionOptions;
impl ExtractionConfig {
/// Apply environment variable overrides to configuration.
///
/// Environment variables have the highest precedence and will override any values
/// loaded from configuration files. This method supports the following environment variables:
///
/// - `KREUZBERG_OCR_LANGUAGE`: OCR language (ISO 639-1 or 639-3 code, e.g., "eng", "fra", "deu")
/// - `KREUZBERG_OCR_BACKEND`: OCR backend ("tesseract", "easyocr", or "paddleocr")
/// - `KREUZBERG_CHUNKING_MAX_CHARS`: Maximum characters per chunk (positive integer)
/// - `KREUZBERG_CHUNKING_MAX_OVERLAP`: Maximum overlap between chunks (non-negative integer)
/// - `KREUZBERG_CACHE_ENABLED`: Cache enabled flag ("true" or "false")
/// - `KREUZBERG_TOKEN_REDUCTION_MODE`: Token reduction mode ("off", "light", "moderate", "aggressive", or "maximum")
/// - `KREUZBERG_CHUNKING_TOKENIZER`: HuggingFace tokenizer model ID for token-based chunk sizing (requires `chunking-tokenizers` feature)
/// - `KREUZBERG_DISABLE_OCR`: Disable OCR entirely ("true" or "false")
/// - `KREUZBERG_LLM_MODEL`: LLM model for structured extraction (e.g., "openai/gpt-4o")
/// - `KREUZBERG_LLM_API_KEY`: API key for the structured extraction LLM provider
/// - `KREUZBERG_LLM_BASE_URL`: Custom base URL for the structured extraction LLM provider
/// - `KREUZBERG_VLM_OCR_MODEL`: VLM model for vision-based OCR (e.g., "openai/gpt-4o")
/// - `KREUZBERG_VLM_EMBEDDING_MODEL`: LLM model for embedding generation (e.g., "openai/text-embedding-3-small")
/// - `KREUZBERG_EMBEDDING_PLUGIN_NAME`: Name of an in-process embedding backend registered via `plugins::register_embedding_backend`
/// - `KREUZBERG_MSG_FALLBACK_CODEPAGE`: (deferred) Windows codepage for MSG PT_STRING8 fallback
///
/// # Behavior
///
/// - If an environment variable is set and valid, it overrides the current configuration value
/// - If a required parent config is `None` (e.g., `self.ocr` is None), it's created with defaults before applying the override
/// - Invalid values return a `KreuzbergError::Validation` with helpful error messages
/// - Missing or unset environment variables are silently ignored
///
/// # Example
///
/// ```rust
/// # use kreuzberg::core::config::ExtractionConfig;
/// # fn example() -> kreuzberg::Result<()> {
/// let mut config = ExtractionConfig::from_file("config.toml")?;
/// // Set KREUZBERG_OCR_LANGUAGE=fra before calling
/// config.apply_env_overrides()?; // OCR language is now "fra"
/// # Ok(())
/// # }
/// ```
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if:
/// - An environment variable contains an invalid value
/// - A number cannot be parsed as the expected type
/// - A boolean is not "true" or "false"
pub fn apply_env_overrides(&mut self) -> Result<()> {
use crate::core::config_validation::{
validate_chunking_params, validate_language_code, validate_ocr_backend, validate_token_reduction_level,
};
// KREUZBERG_OCR_LANGUAGE override
if let Ok(lang) = std::env::var("KREUZBERG_OCR_LANGUAGE") {
validate_language_code(&lang)?;
if self.ocr.is_none() {
self.ocr = Some(OcrConfig::default());
}
if let Some(ref mut ocr) = self.ocr {
ocr.language = lang;
}
}
// KREUZBERG_OCR_BACKEND override
if let Ok(backend) = std::env::var("KREUZBERG_OCR_BACKEND") {
validate_ocr_backend(&backend)?;
if self.ocr.is_none() {
self.ocr = Some(OcrConfig::default());
}
if let Some(ref mut ocr) = self.ocr {
ocr.backend = backend;
}
}
// KREUZBERG_CHUNKING_MAX_CHARS override
if let Ok(max_chars_str) = std::env::var("KREUZBERG_CHUNKING_MAX_CHARS") {
let max_chars: usize = max_chars_str.parse().map_err(|_| KreuzbergError::Validation {
message: format!(
"Invalid value for KREUZBERG_CHUNKING_MAX_CHARS: '{}'. Must be a positive integer.",
max_chars_str
),
source: None,
})?;
if max_chars == 0 {
return Err(KreuzbergError::Validation {
message: "KREUZBERG_CHUNKING_MAX_CHARS must be greater than 0".to_string(),
source: None,
});
}
if self.chunking.is_none() {
self.chunking = Some(ChunkingConfig::default());
}
if let Some(ref mut chunking) = self.chunking {
// Validate against current overlap before updating
validate_chunking_params(max_chars, chunking.overlap)?;
chunking.max_characters = max_chars;
}
}
// KREUZBERG_CHUNKING_MAX_OVERLAP override
if let Ok(max_overlap_str) = std::env::var("KREUZBERG_CHUNKING_MAX_OVERLAP") {
let max_overlap: usize = max_overlap_str.parse().map_err(|_| KreuzbergError::Validation {
message: format!(
"Invalid value for KREUZBERG_CHUNKING_MAX_OVERLAP: '{}'. Must be a non-negative integer.",
max_overlap_str
),
source: None,
})?;
if self.chunking.is_none() {
self.chunking = Some(ChunkingConfig::default());
}
if let Some(ref mut chunking) = self.chunking {
// Validate against current max_characters before updating
validate_chunking_params(chunking.max_characters, max_overlap)?;
chunking.overlap = max_overlap;
}
}
// KREUZBERG_CACHE_ENABLED override
if let Ok(cache_str) = std::env::var("KREUZBERG_CACHE_ENABLED") {
let cache_enabled = match cache_str.to_lowercase().as_str() {
"true" => true,
"false" => false,
_ => {
return Err(KreuzbergError::Validation {
message: format!(
"Invalid value for KREUZBERG_CACHE_ENABLED: '{}'. Must be 'true' or 'false'.",
cache_str
),
source: None,
});
}
};
self.use_cache = cache_enabled;
}
// KREUZBERG_TOKEN_REDUCTION_MODE override
if let Ok(mode) = std::env::var("KREUZBERG_TOKEN_REDUCTION_MODE") {
validate_token_reduction_level(&mode)?;
if self.token_reduction.is_none() {
self.token_reduction = Some(TokenReductionOptions {
mode: "off".to_string(),
preserve_important_words: true,
});
}
if let Some(ref mut token_reduction) = self.token_reduction {
token_reduction.mode = mode;
}
}
// KREUZBERG_OUTPUT_FORMAT override
if let Ok(val) = std::env::var("KREUZBERG_OUTPUT_FORMAT") {
self.output_format = val.parse().map_err(|e: String| KreuzbergError::Validation {
message: format!("Invalid value for KREUZBERG_OUTPUT_FORMAT: {}", e),
source: None,
})?;
}
// KREUZBERG_CHUNKING_TOKENIZER override
#[cfg(feature = "chunking-tokenizers")]
if let Ok(model) = std::env::var("KREUZBERG_CHUNKING_TOKENIZER") {
if model.is_empty() {
return Err(KreuzbergError::Validation {
message: "KREUZBERG_CHUNKING_TOKENIZER must not be empty".to_string(),
source: None,
});
}
if self.chunking.is_none() {
self.chunking = Some(ChunkingConfig::default());
}
if let Some(ref mut chunking) = self.chunking {
chunking.sizing = crate::core::config::processing::ChunkSizing::Tokenizer { model, cache_dir: None };
}
}
// KREUZBERG_LAYOUT_PRESET override (backward compat: enables layout detection).
// Only one model (RT-DETR) exists, so the specific preset value is ignored.
#[cfg(feature = "layout-detection")]
if let Ok(preset) = std::env::var("KREUZBERG_LAYOUT_PRESET") {
let lower = preset.to_lowercase();
if !["fast", "accurate", "yolo", "rtdetr", "rt-detr"].contains(&lower.as_str()) {
return Err(KreuzbergError::Validation {
message: format!(
"Invalid value for KREUZBERG_LAYOUT_PRESET: '{}'. Valid presets: fast, accurate",
preset
),
source: None,
});
}
if self.layout.is_none() {
self.layout = Some(super::super::layout::LayoutDetectionConfig::default());
}
// preset value is accepted but ignored -- only RT-DETR is available
let _ = lower;
}
// KREUZBERG_DISABLE_OCR override
if let Ok(val) = std::env::var("KREUZBERG_DISABLE_OCR") {
self.disable_ocr = match val.to_lowercase().as_str() {
"true" | "1" => true,
"false" | "0" => false,
_ => {
return Err(KreuzbergError::Validation {
message: format!(
"Invalid value for KREUZBERG_DISABLE_OCR: '{}'. Must be 'true' or 'false'.",
val
),
source: None,
});
}
};
}
// KREUZBERG_LLM_MODEL override
if let Ok(value) = std::env::var("KREUZBERG_LLM_MODEL") {
if value.is_empty() {
return Err(KreuzbergError::Validation {
message: "KREUZBERG_LLM_MODEL must not be empty".to_string(),
source: None,
});
}
if self.structured_extraction.is_none() {
self.structured_extraction = Some(super::super::llm::StructuredExtractionConfig {
schema: serde_json::Value::Object(Default::default()),
schema_name: "extraction".to_string(),
schema_description: None,
strict: false,
prompt: None,
llm: super::super::llm::LlmConfig {
model: value,
api_key: None,
base_url: None,
timeout_secs: None,
max_retries: None,
temperature: None,
max_tokens: None,
},
});
} else if let Some(ref mut config) = self.structured_extraction {
config.llm.model = value;
}
}
// KREUZBERG_LLM_API_KEY override
if let Ok(value) = std::env::var("KREUZBERG_LLM_API_KEY") {
if value.is_empty() {
return Err(KreuzbergError::Validation {
message: "KREUZBERG_LLM_API_KEY must not be empty".to_string(),
source: None,
});
}
if self.structured_extraction.is_none() {
self.structured_extraction = Some(super::super::llm::StructuredExtractionConfig {
schema: serde_json::Value::Object(Default::default()),
schema_name: "extraction".to_string(),
schema_description: None,
strict: false,
prompt: None,
llm: super::super::llm::LlmConfig {
model: String::new(),
api_key: Some(value),
base_url: None,
timeout_secs: None,
max_retries: None,
temperature: None,
max_tokens: None,
},
});
} else if let Some(ref mut config) = self.structured_extraction {
config.llm.api_key = Some(value);
}
}
// KREUZBERG_LLM_BASE_URL override
if let Ok(value) = std::env::var("KREUZBERG_LLM_BASE_URL") {
if value.is_empty() {
return Err(KreuzbergError::Validation {
message: "KREUZBERG_LLM_BASE_URL must not be empty".to_string(),
source: None,
});
}
if self.structured_extraction.is_none() {
self.structured_extraction = Some(super::super::llm::StructuredExtractionConfig {
schema: serde_json::Value::Object(Default::default()),
schema_name: "extraction".to_string(),
schema_description: None,
strict: false,
prompt: None,
llm: super::super::llm::LlmConfig {
model: String::new(),
api_key: None,
base_url: Some(value),
timeout_secs: None,
max_retries: None,
temperature: None,
max_tokens: None,
},
});
} else if let Some(ref mut config) = self.structured_extraction {
config.llm.base_url = Some(value);
}
}
// KREUZBERG_VLM_OCR_MODEL override
if let Ok(value) = std::env::var("KREUZBERG_VLM_OCR_MODEL") {
if value.is_empty() {
return Err(KreuzbergError::Validation {
message: "KREUZBERG_VLM_OCR_MODEL must not be empty".to_string(),
source: None,
});
}
if self.ocr.is_none() {
self.ocr = Some(OcrConfig::default());
}
if let Some(ref mut ocr) = self.ocr {
if ocr.vlm_config.is_none() {
ocr.vlm_config = Some(super::super::llm::LlmConfig {
model: value,
api_key: None,
base_url: None,
timeout_secs: None,
max_retries: None,
temperature: None,
max_tokens: None,
});
} else if let Some(ref mut vlm) = ocr.vlm_config {
vlm.model = value;
}
}
}
// KREUZBERG_VLM_EMBEDDING_MODEL override
if let Ok(value) = std::env::var("KREUZBERG_VLM_EMBEDDING_MODEL") {
if value.is_empty() {
return Err(KreuzbergError::Validation {
message: "KREUZBERG_VLM_EMBEDDING_MODEL must not be empty".to_string(),
source: None,
});
}
if self.chunking.is_none() {
self.chunking = Some(ChunkingConfig::default());
}
if let Some(ref mut chunking) = self.chunking {
chunking.embedding = Some(super::super::processing::EmbeddingConfig {
model: super::super::processing::EmbeddingModelType::Llm {
llm: super::super::llm::LlmConfig {
model: value,
api_key: None,
base_url: None,
timeout_secs: None,
max_retries: None,
temperature: None,
max_tokens: None,
},
},
..super::super::processing::EmbeddingConfig::default()
});
}
}
// KREUZBERG_EMBEDDING_PLUGIN_NAME override.
// Selects an already-registered in-process embedding backend by name.
// Setting this together with KREUZBERG_VLM_EMBEDDING_MODEL is rejected — they
// configure mutually-exclusive embedding sources and the result of "both set"
// would otherwise depend on source order in this function. Pick one.
let plugin_name = std::env::var("KREUZBERG_EMBEDDING_PLUGIN_NAME").ok();
if plugin_name.is_some() && std::env::var("KREUZBERG_VLM_EMBEDDING_MODEL").is_ok() {
return Err(KreuzbergError::Validation {
message:
"KREUZBERG_EMBEDDING_PLUGIN_NAME and KREUZBERG_VLM_EMBEDDING_MODEL are mutually exclusive — set one or the other, not both."
.to_string(),
source: None,
});
}
if let Some(value) = plugin_name {
if value.is_empty() {
return Err(KreuzbergError::Validation {
message: "KREUZBERG_EMBEDDING_PLUGIN_NAME must not be empty".to_string(),
source: None,
});
}
if self.chunking.is_none() {
self.chunking = Some(ChunkingConfig::default());
}
if let Some(ref mut chunking) = self.chunking {
chunking.embedding = Some(super::super::processing::EmbeddingConfig {
model: super::super::processing::EmbeddingModelType::Plugin { name: value },
..super::super::processing::EmbeddingConfig::default()
});
}
}
Ok(())
}
}
#[cfg(test)]
#[allow(unsafe_code)] // env mutation in 2024 edition is unsafe; tests serialize via ENV_LOCK
mod tests {
use super::*;
use crate::core::config::processing::EmbeddingModelType;
/// Lock guarding env-var mutation across tests in this module — `std::env::set_var`
/// is process-global and concurrent tests would race.
static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
fn clear_embedding_env() {
// SAFETY: callers hold ENV_LOCK so no other thread is reading these vars.
unsafe {
std::env::remove_var("KREUZBERG_EMBEDDING_PLUGIN_NAME");
std::env::remove_var("KREUZBERG_VLM_EMBEDDING_MODEL");
}
}
#[test]
fn embedding_plugin_and_vlm_embedding_model_are_mutually_exclusive() {
let _guard = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
clear_embedding_env();
// SAFETY: see clear_embedding_env.
unsafe {
std::env::set_var("KREUZBERG_EMBEDDING_PLUGIN_NAME", "my-embedder");
std::env::set_var("KREUZBERG_VLM_EMBEDDING_MODEL", "openai/text-embedding-3-small");
}
let mut config = ExtractionConfig::default();
let err = config
.apply_env_overrides()
.expect_err("should reject conflicting embedding env vars");
assert!(
matches!(err, KreuzbergError::Validation { .. }),
"expected Validation, got {err:?}"
);
let msg = err.to_string();
assert!(msg.contains("mutually exclusive"), "message: {msg}");
clear_embedding_env();
}
#[test]
fn empty_embedding_plugin_name_rejected() {
let _guard = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
clear_embedding_env();
// SAFETY: see clear_embedding_env.
unsafe { std::env::set_var("KREUZBERG_EMBEDDING_PLUGIN_NAME", "") };
let mut config = ExtractionConfig::default();
let err = config
.apply_env_overrides()
.expect_err("should reject empty plugin name");
assert!(
matches!(err, KreuzbergError::Validation { .. }),
"expected Validation, got {err:?}"
);
clear_embedding_env();
}
#[test]
fn embedding_plugin_env_sets_chunking_embedding_to_plugin_variant() {
let _guard = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
clear_embedding_env();
// SAFETY: see clear_embedding_env.
unsafe { std::env::set_var("KREUZBERG_EMBEDDING_PLUGIN_NAME", "my-embedder") };
let mut config = ExtractionConfig::default();
config
.apply_env_overrides()
.expect("should succeed with only plugin name set");
let chunking = config.chunking.as_ref().expect("chunking should be created");
let embedding = chunking.embedding.as_ref().expect("embedding should be set");
match &embedding.model {
EmbeddingModelType::Plugin { name } => {
assert_eq!(name, "my-embedder");
}
other => panic!("expected Plugin variant, got {other:?}"),
}
clear_embedding_env();
}
}

View File

@@ -0,0 +1,148 @@
//! Per-file extraction configuration overrides for batch processing.
//!
//! This module contains [`FileExtractionConfig`], a subset of [`super::ExtractionConfig`]
//! where every field is optional. When used with batch extraction functions, each file
//! can specify overrides that are merged with the batch-level default config.
//!
//! Fields that are batch-level concerns (concurrency, caching, acceleration, security)
//! are intentionally excluded and can only be set on the batch-level [`super::ExtractionConfig`].
use serde::{Deserialize, Serialize};
use super::super::formats::OutputFormat;
use super::super::ocr::OcrConfig;
use super::super::page::PageConfig;
use super::super::processing::{ChunkingConfig, PostProcessorConfig};
use super::types::{ImageExtractionConfig, LanguageDetectionConfig, TokenReductionOptions};
/// Per-file extraction configuration overrides for batch processing.
///
/// All fields are `Option<T>` — `None` means "use the batch-level default."
/// This type is used with [`crate::batch_extract_files`] and
/// [`crate::batch_extract_bytes`] to allow heterogeneous
/// extraction settings within a single batch.
///
/// # Excluded Fields
///
/// The following [`super::ExtractionConfig`] fields are batch-level only and
/// cannot be overridden per file:
/// - `max_concurrent_extractions` — controls batch parallelism
/// - `use_cache` — global caching policy
/// - `acceleration` — shared ONNX execution provider
/// - `security_limits` — global archive security policy
///
/// # Example
///
/// ```rust
/// use kreuzberg::FileExtractionConfig;
///
/// // Override just OCR forcing for a specific file
/// let config = FileExtractionConfig {
/// force_ocr: Some(true),
/// ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct FileExtractionConfig {
/// Override quality post-processing for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub enable_quality_processing: Option<bool>,
/// Override OCR configuration for this file (None in the Option = use batch default).
#[serde(skip_serializing_if = "Option::is_none")]
pub ocr: Option<OcrConfig>,
/// Override force OCR for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub force_ocr: Option<bool>,
/// Override force OCR pages for this file (1-indexed page numbers).
#[serde(skip_serializing_if = "Option::is_none")]
pub force_ocr_pages: Option<Vec<u32>>,
/// Override disable OCR for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub disable_ocr: Option<bool>,
/// Override chunking configuration for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub chunking: Option<ChunkingConfig>,
/// Override content filtering configuration for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub content_filter: Option<super::super::content_filter::ContentFilterConfig>,
/// Override image extraction configuration for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub images: Option<ImageExtractionConfig>,
/// Override PDF options for this file.
#[cfg(feature = "pdf")]
#[serde(skip_serializing_if = "Option::is_none")]
pub pdf_options: Option<super::super::pdf::PdfConfig>,
/// Override token reduction for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_reduction: Option<TokenReductionOptions>,
/// Override language detection for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub language_detection: Option<LanguageDetectionConfig>,
/// Override page extraction for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub pages: Option<PageConfig>,
/// Override keyword extraction for this file.
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
#[serde(skip_serializing_if = "Option::is_none")]
pub keywords: Option<crate::keywords::KeywordConfig>,
/// Override post-processor for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub postprocessor: Option<PostProcessorConfig>,
/// Override HTML conversion options for this file.
#[cfg(feature = "html")]
#[serde(skip_serializing_if = "Option::is_none")]
pub html_options: Option<html_to_markdown_rs::ConversionOptions>,
/// Override result format for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub result_format: Option<crate::types::ResultFormat>,
/// Override output content format for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub output_format: Option<OutputFormat>,
/// Override document structure output for this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub include_document_structure: Option<bool>,
/// Override layout detection for this file.
#[cfg(feature = "layout-types")]
#[serde(skip_serializing_if = "Option::is_none")]
pub layout: Option<super::super::layout::LayoutDetectionConfig>,
/// Override per-file extraction timeout in seconds.
///
/// When set, the extraction for this file will be canceled after the
/// specified duration. A timed-out file produces an error result without
/// affecting other files in the batch.
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout_secs: Option<u64>,
/// Override tree-sitter configuration for this file.
#[cfg(feature = "tree-sitter")]
#[serde(skip_serializing_if = "Option::is_none")]
pub tree_sitter: Option<super::super::tree_sitter::TreeSitterConfig>,
/// Override structured extraction configuration for this file.
///
/// When set, enables LLM-based structured extraction with a JSON schema
/// for this specific file. The extracted content is sent to a VLM/LLM
/// and the response is parsed according to the provided schema.
#[serde(skip_serializing_if = "Option::is_none")]
pub structured_extraction: Option<super::super::llm::StructuredExtractionConfig>,
}

View File

@@ -0,0 +1,87 @@
//! Configuration file loading.
//!
//! This module provides methods for loading extraction configuration from
//! TOML, YAML, and JSON files.
use crate::{KreuzbergError, Result};
use std::path::Path;
use super::core::ExtractionConfig;
impl ExtractionConfig {
/// Load configuration from a TOML file.
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if file doesn't exist or is invalid TOML.
pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
toml::from_str(&content)
.map_err(|e| KreuzbergError::validation(format!("Invalid TOML in {}: {}", path.display(), e)))
}
/// Load configuration from a YAML file.
pub fn from_yaml_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
serde_yaml_ng::from_str(&content)
.map_err(|e| KreuzbergError::validation(format!("Invalid YAML in {}: {}", path.display(), e)))
}
/// Load configuration from a JSON file.
pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
serde_json::from_str(&content)
.map_err(|e| KreuzbergError::validation(format!("Invalid JSON in {}: {}", path.display(), e)))
}
/// Load configuration from a file, auto-detecting format by extension.
///
/// Supported formats: `.toml`, `.yaml`, `.yml`, `.json`.
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let extension = path.extension().and_then(|ext| ext.to_str()).ok_or_else(|| {
KreuzbergError::validation(format!(
"Cannot determine file format: no extension found in {}",
path.display()
))
})?;
match extension.to_lowercase().as_str() {
"toml" => Self::from_toml_file(path),
"yaml" | "yml" => Self::from_yaml_file(path),
"json" => Self::from_json_file(path),
other => Err(KreuzbergError::validation(format!(
"Unsupported config file format: .{}. Supported formats: .toml, .yaml, .json",
other
))),
}
}
/// Discover configuration file in parent directories.
///
/// Searches for `kreuzberg.toml` in current directory and parent directories.
pub fn discover() -> Result<Option<Self>> {
let mut current = std::env::current_dir().map_err(KreuzbergError::Io)?;
loop {
let kreuzberg_toml = current.join("kreuzberg.toml");
if kreuzberg_toml.exists() {
return Ok(Some(Self::from_toml_file(kreuzberg_toml)?));
}
if let Some(parent) = current.parent() {
current = parent.to_path_buf();
} else {
break;
}
}
Ok(None)
}
}

View File

@@ -0,0 +1,46 @@
//! Main extraction configuration and environment variable handling.
//!
//! This module contains the main `ExtractionConfig` struct and related utilities
//! for loading configuration from files and applying environment variable overrides.
//!
//! The module is organized into focused submodules:
//! - `types`: Feature-specific configuration types (image, token reduction, language detection)
//! - `core`: Main ExtractionConfig struct and implementation
//! - `env`: Environment variable override support
//! - `loaders`: Configuration file loading with caching
mod core;
mod env;
mod file_config;
mod loaders;
mod types;
// Re-export all public types for backward compatibility
pub use self::core::ExtractionConfig;
pub use self::file_config::FileExtractionConfig;
pub use self::types::{
BatchBytesItem, BatchFileItem, ImageExtractionConfig, LanguageDetectionConfig, TokenReductionOptions,
};
#[cfg(test)]
mod tests {
use super::*;
use crate::core::config::ocr::OcrConfig;
#[test]
fn test_default_config() {
let config = ExtractionConfig::default();
assert!(config.use_cache);
assert!(config.enable_quality_processing);
assert!(config.ocr.is_none());
}
#[test]
fn test_needs_image_processing() {
let mut config = ExtractionConfig::default();
assert!(!config.needs_image_processing());
config.ocr = Some(OcrConfig::default());
assert!(config.needs_image_processing());
}
}

View File

@@ -0,0 +1,345 @@
//! Feature-specific configuration types for extraction.
//!
//! This module contains configuration structs for specific extraction features:
//! - Image extraction and processing
//! - Token reduction
//! - Language detection
//! - Batch extraction items
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Batch item for byte array extraction.
///
/// Used with [`crate::batch_extract_bytes`] and [`crate::batch_extract_bytes_sync`]
/// to represent a single item in a batch extraction job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchBytesItem {
/// The content bytes to extract from
pub content: Vec<u8>,
/// MIME type of the content (e.g., "application/pdf", "text/html")
pub mime_type: String,
/// Per-item configuration overrides (None uses batch-level defaults)
#[serde(skip_serializing_if = "Option::is_none")]
pub config: Option<super::FileExtractionConfig>,
}
/// Batch item for file extraction.
///
/// Used with [`crate::batch_extract_files`] and [`crate::batch_extract_files_sync`]
/// to represent a single file in a batch extraction job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchFileItem {
/// Path to the file to extract from
pub path: PathBuf,
/// Per-file configuration overrides (None uses batch-level defaults)
#[serde(skip_serializing_if = "Option::is_none")]
pub config: Option<super::FileExtractionConfig>,
}
/// Image extraction configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageExtractionConfig {
/// Extract images from documents
#[serde(default = "default_true")]
pub extract_images: bool,
/// Target DPI for image normalization
#[serde(default = "default_target_dpi")]
pub target_dpi: i32,
/// Maximum dimension for images (width or height)
#[serde(default = "default_max_dimension")]
pub max_image_dimension: i32,
/// Whether to inject image reference placeholders into markdown output.
/// When `true` (default), image references like `![Image 1](embedded:p1_i0)`
/// are appended to the markdown. Set to `false` to extract images as data
/// without polluting the markdown output.
#[serde(default = "default_true")]
pub inject_placeholders: bool,
/// Automatically adjust DPI based on image content
#[serde(default = "default_true")]
pub auto_adjust_dpi: bool,
/// Minimum DPI threshold
#[serde(default = "default_min_dpi")]
pub min_dpi: i32,
/// Maximum DPI threshold
#[serde(default = "default_max_dpi")]
pub max_dpi: i32,
/// Maximum number of image objects to extract per PDF page.
///
/// Some PDFs (e.g. technical diagrams stored as thousands of raster fragments)
/// can trigger extremely long or indefinite extraction times when every image
/// object on a dense page is decoded individually via the PDF extractor. Setting this
/// limit causes kreuzberg to stop collecting individual images once the count
/// per page reaches the cap and emit a warning instead.
///
/// `None` (default) means no limit — all images are extracted.
#[serde(default)]
pub max_images_per_page: Option<u32>,
/// When `true` (default), extracted images are classified by kind and grouped
/// into clusters where they appear to belong to one figure.
#[serde(default = "default_true")]
pub classify: bool,
/// When `true`, full-page renders produced during OCR preprocessing are captured
/// and returned as `ImageKind::PageRaster` entries in `ExtractionResult.images`.
///
/// **PDF + OCR only.** No rasters are captured for non-PDF inputs or when the
/// document-level OCR bypass is active (whole-document backend). When OCR is
/// enabled and this flag is set but the active backend skips per-page rendering,
/// a `ProcessingWarning` is emitted in `ExtractionResult.processing_warnings`.
///
/// Defaults to `false`. Enable when downstream consumers need page thumbnails
/// (e.g. citation previews, visual grounding).
#[serde(default)]
pub include_page_rasters: bool,
/// Run OCR on extracted images and include the recognized text in the document content.
///
/// When `true` (default) and `ExtractionConfig.ocr` is configured, extracted images
/// are processed with the configured OCR backend. Set to `false` to extract images
/// without OCR processing, even when OCR is enabled.
#[serde(default = "default_true")]
pub run_ocr_on_images: bool,
/// When `true`, image OCR results are rendered as plain text without the
/// `![...](...)` markdown placeholder. Only takes effect when `run_ocr_on_images`
/// is also `true`.
#[serde(default)]
pub ocr_text_only: bool,
/// When `true` and `ocr_text_only` is `false`, append the OCR text after
/// the image placeholder in the rendered output.
#[serde(default)]
pub append_ocr_text: bool,
}
/// Token reduction configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenReductionOptions {
/// Reduction mode: "off", "light", "moderate", "aggressive", "maximum"
#[serde(default = "default_reduction_mode")]
pub mode: String,
/// Preserve important words (capitalized, technical terms)
#[serde(default = "default_true")]
pub preserve_important_words: bool,
}
/// Language detection configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageDetectionConfig {
/// Enable language detection
#[serde(default = "default_true")]
pub enabled: bool,
/// Minimum confidence threshold (0.0-1.0)
#[serde(default = "default_confidence")]
pub min_confidence: f64,
/// Detect multiple languages in the document
#[serde(default)]
pub detect_multiple: bool,
}
impl Default for ImageExtractionConfig {
fn default() -> Self {
Self {
extract_images: true,
target_dpi: 300,
max_image_dimension: 4096,
inject_placeholders: true,
auto_adjust_dpi: true,
min_dpi: 72,
max_dpi: 600,
max_images_per_page: None,
classify: true,
include_page_rasters: false,
run_ocr_on_images: true,
ocr_text_only: false,
append_ocr_text: false,
}
}
}
impl Default for TokenReductionOptions {
fn default() -> Self {
Self {
mode: default_reduction_mode(),
preserve_important_words: true,
}
}
}
impl Default for LanguageDetectionConfig {
fn default() -> Self {
Self {
enabled: true,
min_confidence: 0.8,
detect_multiple: false,
}
}
}
// Default value functions
fn default_true() -> bool {
true
}
fn default_target_dpi() -> i32 {
300
}
fn default_max_dimension() -> i32 {
4096
}
fn default_min_dpi() -> i32 {
72
}
fn default_max_dpi() -> i32 {
600
}
fn default_reduction_mode() -> String {
"off".to_string()
}
fn default_confidence() -> f64 {
0.8
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_image_extraction_config_default_booleans_are_true() {
let cfg = ImageExtractionConfig::default();
assert!(cfg.extract_images, "extract_images must default to true");
assert!(cfg.inject_placeholders, "inject_placeholders must default to true");
assert!(cfg.auto_adjust_dpi, "auto_adjust_dpi must default to true");
assert!(cfg.classify, "classify must default to true");
assert_eq!(cfg.target_dpi, 300);
assert_eq!(cfg.max_image_dimension, 4096);
assert_eq!(cfg.min_dpi, 72);
assert_eq!(cfg.max_dpi, 600);
}
#[test]
fn test_image_extraction_config_defaults() {
let cfg = ImageExtractionConfig::default();
assert!(cfg.run_ocr_on_images, "run_ocr_on_images must default to true");
}
#[test]
fn test_image_extraction_config_explicit_false_disables_placeholders() {
let cfg = ImageExtractionConfig {
inject_placeholders: false,
..ImageExtractionConfig::default()
};
assert!(!cfg.inject_placeholders);
assert!(cfg.extract_images);
}
#[test]
fn test_image_extraction_config_explicit_false_disables_classify() {
let cfg = ImageExtractionConfig {
classify: false,
..ImageExtractionConfig::default()
};
assert!(!cfg.classify);
assert!(cfg.extract_images);
}
#[test]
fn test_image_extraction_config_absent_json_fields_get_canonical_defaults() {
let json = r#"{"extract_images": true}"#;
let cfg: ImageExtractionConfig = serde_json::from_str(json).unwrap();
assert!(
cfg.inject_placeholders,
"absent inject_placeholders must deserialize to true"
);
assert!(cfg.auto_adjust_dpi, "absent auto_adjust_dpi must deserialize to true");
assert_eq!(cfg.target_dpi, 300);
}
#[test]
fn test_max_images_per_page_defaults_none() {
let config = ImageExtractionConfig::default();
assert_eq!(config.max_images_per_page, None);
}
#[test]
fn test_max_images_per_page_serializes_as_null_when_none() {
let config = ImageExtractionConfig::default();
let json = serde_json::to_string(&config).unwrap();
assert!(json.contains("\"max_images_per_page\":null"));
}
#[test]
fn test_max_images_per_page_roundtrips_via_json() {
let config = ImageExtractionConfig {
max_images_per_page: Some(50),
..Default::default()
};
let json = serde_json::to_string(&config).unwrap();
let back: ImageExtractionConfig = serde_json::from_str(&json).unwrap();
assert_eq!(back.max_images_per_page, Some(50));
}
/// Regression test for issue #766: missing field in JSON must not break
/// deserialization (backwards-compat — existing configs without this key
/// must still deserialize cleanly).
#[test]
fn test_max_images_per_page_absent_in_json_deserializes_as_none() {
let json = r#"{"extract_images":true,"target_dpi":300,"max_image_dimension":4096,
"inject_placeholders":true,"auto_adjust_dpi":true,
"min_dpi":72,"max_dpi":600}"#;
let config: ImageExtractionConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.max_images_per_page, None);
}
#[test]
fn test_include_page_rasters_defaults_false() {
let config = ImageExtractionConfig::default();
assert!(
!config.include_page_rasters,
"include_page_rasters must default to false"
);
}
#[test]
fn test_include_page_rasters_absent_in_json_deserializes_as_false() {
let json = r#"{"extract_images":true,"target_dpi":300,"max_image_dimension":4096,
"inject_placeholders":true,"auto_adjust_dpi":true,
"min_dpi":72,"max_dpi":600}"#;
let config: ImageExtractionConfig = serde_json::from_str(json).unwrap();
assert!(
!config.include_page_rasters,
"absent include_page_rasters must deserialize to false (backward compat)"
);
}
#[test]
fn test_include_page_rasters_roundtrips_via_json() {
let config = ImageExtractionConfig {
include_page_rasters: true,
..Default::default()
};
let json = serde_json::to_string(&config).unwrap();
let back: ImageExtractionConfig = serde_json::from_str(&json).unwrap();
assert!(back.include_page_rasters);
}
}

View File

@@ -0,0 +1,201 @@
//! Output format configuration and validation.
//!
//! This module defines the `OutputFormat` enum for controlling how extraction
//! results are formatted (plain text, markdown, HTML, etc.) and provides
//! serialization/deserialization support.
use serde::{Deserialize, Serialize};
use std::str::FromStr;
/// Output format for extraction results.
///
/// Controls the format of the `content` field in `ExtractionResult`.
/// When set to `Markdown`, `Djot`, or `Html`, the output uses that format.
/// `Plain` returns the raw extracted text.
/// `Structured` returns JSON with full OCR element data including bounding
/// boxes and confidence scores.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OutputFormat {
/// Plain text content only (default)
#[default]
Plain,
/// Markdown format
Markdown,
/// Djot markup format
Djot,
/// HTML format
Html,
/// JSON tree format with heading-driven sections.
Json,
/// Structured JSON format with full OCR element metadata.
Structured,
/// Custom renderer registered via the RendererRegistry.
/// The string is the renderer name (e.g., "docx", "latex").
#[serde(untagged)]
Custom(String),
}
#[cfg(test)]
impl OutputFormat {
/// Get the renderer name for this format.
/// Returns `None` for formats that don't use the renderer registry
/// (Plain, Structured, Toon — these are handled differently).
pub(crate) fn renderer_name(&self) -> Option<&str> {
match self {
OutputFormat::Plain | OutputFormat::Json | OutputFormat::Structured => None,
OutputFormat::Markdown => Some("markdown"),
OutputFormat::Djot => Some("djot"),
OutputFormat::Html => Some("html"),
OutputFormat::Custom(name) => Some(name.as_str()),
}
}
}
impl std::fmt::Display for OutputFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OutputFormat::Plain => write!(f, "plain"),
OutputFormat::Markdown => write!(f, "markdown"),
OutputFormat::Djot => write!(f, "djot"),
OutputFormat::Html => write!(f, "html"),
OutputFormat::Json => write!(f, "json"),
OutputFormat::Structured => write!(f, "structured"),
OutputFormat::Custom(name) => write!(f, "{}", name),
}
}
}
impl FromStr for OutputFormat {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"plain" | "text" => Ok(OutputFormat::Plain),
"markdown" | "md" => Ok(OutputFormat::Markdown),
"djot" => Ok(OutputFormat::Djot),
"html" => Ok(OutputFormat::Html),
"json" => Ok(OutputFormat::Json),
"structured" | "structured-ocr" => Ok(OutputFormat::Structured),
other => Ok(OutputFormat::Custom(other.to_string())),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_output_format_from_str_plain() {
assert_eq!("plain".parse::<OutputFormat>().unwrap(), OutputFormat::Plain);
assert_eq!("PLAIN".parse::<OutputFormat>().unwrap(), OutputFormat::Plain);
assert_eq!("text".parse::<OutputFormat>().unwrap(), OutputFormat::Plain);
assert_eq!("TEXT".parse::<OutputFormat>().unwrap(), OutputFormat::Plain);
}
#[test]
fn test_output_format_from_str_markdown() {
assert_eq!("markdown".parse::<OutputFormat>().unwrap(), OutputFormat::Markdown);
assert_eq!("MARKDOWN".parse::<OutputFormat>().unwrap(), OutputFormat::Markdown);
assert_eq!("md".parse::<OutputFormat>().unwrap(), OutputFormat::Markdown);
assert_eq!("MD".parse::<OutputFormat>().unwrap(), OutputFormat::Markdown);
}
#[test]
fn test_output_format_from_str_djot() {
assert_eq!("djot".parse::<OutputFormat>().unwrap(), OutputFormat::Djot);
assert_eq!("DJOT".parse::<OutputFormat>().unwrap(), OutputFormat::Djot);
assert_eq!("Djot".parse::<OutputFormat>().unwrap(), OutputFormat::Djot);
}
#[test]
fn test_output_format_from_str_html() {
assert_eq!("html".parse::<OutputFormat>().unwrap(), OutputFormat::Html);
assert_eq!("HTML".parse::<OutputFormat>().unwrap(), OutputFormat::Html);
assert_eq!("Html".parse::<OutputFormat>().unwrap(), OutputFormat::Html);
}
#[test]
fn test_output_format_from_str_json() {
assert_eq!("json".parse::<OutputFormat>().unwrap(), OutputFormat::Json);
assert_eq!("JSON".parse::<OutputFormat>().unwrap(), OutputFormat::Json);
}
#[test]
fn test_output_format_from_str_structured() {
assert_eq!("structured".parse::<OutputFormat>().unwrap(), OutputFormat::Structured);
assert_eq!("STRUCTURED".parse::<OutputFormat>().unwrap(), OutputFormat::Structured);
assert_eq!(
"structured-ocr".parse::<OutputFormat>().unwrap(),
OutputFormat::Structured
);
assert_eq!(
"STRUCTURED-OCR".parse::<OutputFormat>().unwrap(),
OutputFormat::Structured
);
}
#[test]
fn test_output_format_from_str_custom() {
let result = "docx".parse::<OutputFormat>().unwrap();
assert_eq!(result, OutputFormat::Custom("docx".to_string()));
}
#[test]
fn test_output_format_to_string() {
assert_eq!(OutputFormat::Plain.to_string(), "plain");
assert_eq!(OutputFormat::Markdown.to_string(), "markdown");
assert_eq!(OutputFormat::Djot.to_string(), "djot");
assert_eq!(OutputFormat::Html.to_string(), "html");
assert_eq!(OutputFormat::Json.to_string(), "json");
assert_eq!(OutputFormat::Structured.to_string(), "structured");
assert_eq!(OutputFormat::Custom("docx".to_string()).to_string(), "docx");
}
#[test]
fn test_output_format_default() {
let format = OutputFormat::default();
assert_eq!(format, OutputFormat::Plain);
}
#[test]
fn test_output_format_serde_roundtrip() {
for format in [
OutputFormat::Plain,
OutputFormat::Markdown,
OutputFormat::Djot,
OutputFormat::Html,
OutputFormat::Json,
OutputFormat::Structured,
] {
let json = serde_json::to_string(&format).unwrap();
let deserialized: OutputFormat = serde_json::from_str(&json).unwrap();
assert_eq!(format, deserialized);
}
}
#[test]
fn test_output_format_serde_values() {
assert_eq!(serde_json::to_string(&OutputFormat::Plain).unwrap(), "\"plain\"");
assert_eq!(serde_json::to_string(&OutputFormat::Markdown).unwrap(), "\"markdown\"");
assert_eq!(serde_json::to_string(&OutputFormat::Djot).unwrap(), "\"djot\"");
assert_eq!(serde_json::to_string(&OutputFormat::Html).unwrap(), "\"html\"");
assert_eq!(serde_json::to_string(&OutputFormat::Json).unwrap(), "\"json\"");
assert_eq!(
serde_json::to_string(&OutputFormat::Structured).unwrap(),
"\"structured\""
);
}
#[test]
fn test_output_format_renderer_name() {
assert_eq!(OutputFormat::Plain.renderer_name(), None);
assert_eq!(OutputFormat::Markdown.renderer_name(), Some("markdown"));
assert_eq!(OutputFormat::Html.renderer_name(), Some("html"));
assert_eq!(OutputFormat::Djot.renderer_name(), Some("djot"));
assert_eq!(OutputFormat::Json.renderer_name(), None);
assert_eq!(OutputFormat::Structured.renderer_name(), None);
assert_eq!(OutputFormat::Custom("docx".to_string()).renderer_name(), Some("docx"));
}
}

View File

@@ -0,0 +1,136 @@
//! HTML output configuration.
//!
//! Controls how `OutputFormat::Html` renders an `InternalDocument`:
//! which built-in theme to use, whether to embed the CSS in a `<style>`
//! block, and optional user-supplied CSS (inline string or file path).
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
fn default_class_prefix() -> String {
"kb-".to_string()
}
fn default_true() -> bool {
true
}
/// Configuration for styled HTML output.
///
/// When set on [`ExtractionConfig::html_output`] alongside
/// `output_format = OutputFormat::Html`, the pipeline builds a
/// [`StyledHtmlRenderer`](crate::rendering::StyledHtmlRenderer) instead of
/// the plain comrak-based renderer.
///
/// # Example
///
/// ```rust
/// use kreuzberg::core::config::{HtmlOutputConfig, HtmlTheme};
///
/// let config = HtmlOutputConfig {
/// theme: HtmlTheme::GitHub,
/// css: Some(".kb-p { font-size: 1.1rem; }".to_string()),
/// ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HtmlOutputConfig {
/// Inline CSS string injected into the output after the theme stylesheet.
/// Concatenated after `css_file` content when both are set.
#[serde(default)]
pub css: Option<String>,
/// Path to a CSS file loaded once at renderer construction time.
/// Concatenated before `css` when both are set.
#[serde(default)]
pub css_file: Option<PathBuf>,
/// Built-in colour/typography theme. Default: [`HtmlTheme::Unstyled`].
#[serde(default)]
pub theme: HtmlTheme,
/// CSS class prefix applied to every emitted class name.
///
/// Default: `"kb-"`. Change this if your host application already uses
/// classes that start with `kb-`.
#[serde(default = "default_class_prefix")]
pub class_prefix: String,
/// When `true` (default), write the resolved CSS into a `<style>` block
/// immediately after the opening `<div class="{prefix}doc">`.
///
/// Set to `false` to emit only the structural markup and wire up your
/// own stylesheet targeting the `kb-*` class names.
#[serde(default = "default_true")]
pub embed_css: bool,
}
impl Default for HtmlOutputConfig {
fn default() -> Self {
Self {
css: None,
css_file: None,
theme: HtmlTheme::Unstyled,
class_prefix: default_class_prefix(),
embed_css: true,
}
}
}
/// Built-in HTML theme selection.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HtmlTheme {
/// Sensible defaults: system font stack, neutral colours, readable line
/// measure. CSS custom properties (`--kb-*`) are all defined so user CSS
/// can override individual values.
Default,
/// GitHub Markdown-inspired palette and spacing.
GitHub,
/// Dark background, light text.
Dark,
/// Minimal light theme with generous whitespace.
Light,
/// No built-in stylesheet emitted. CSS custom properties are still defined
/// on `:root` so user stylesheets can reference `var(--kb-*)` tokens.
#[default]
Unstyled,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_config_values() {
let cfg = HtmlOutputConfig::default();
assert_eq!(cfg.class_prefix, "kb-");
assert!(cfg.embed_css);
assert!(cfg.css.is_none());
assert!(cfg.css_file.is_none());
assert_eq!(cfg.theme, HtmlTheme::Unstyled);
}
#[test]
fn serde_roundtrip() {
let cfg = HtmlOutputConfig {
css: Some(".kb-p { color: red; }".to_string()),
theme: HtmlTheme::GitHub,
embed_css: false,
..Default::default()
};
let json = serde_json::to_string(&cfg).unwrap();
let back: HtmlOutputConfig = serde_json::from_str(&json).unwrap();
assert_eq!(back.css, cfg.css);
assert_eq!(back.theme, HtmlTheme::GitHub);
assert!(!back.embed_css);
}
#[test]
fn theme_serde() {
assert_eq!(serde_json::to_string(&HtmlTheme::GitHub).unwrap(), "\"github\"");
let t: HtmlTheme = serde_json::from_str("\"dark\"").unwrap();
assert_eq!(t, HtmlTheme::Dark);
}
}

View File

@@ -0,0 +1,180 @@
//! Layout detection configuration.
use std::fmt;
use serde::{Deserialize, Serialize};
/// Which table structure recognition model to use.
///
/// Controls the model used for table cell detection within layout-detected
/// table regions. Wire format is snake_case in all serializers (JSON, TOML,
/// YAML).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TableModel {
/// TATR (Table Transformer) -- default, 30MB, DETR-based row/column detection.
#[default]
Tatr,
/// SLANeXT wired variant -- 365MB, optimized for bordered tables.
SlanetWired,
/// SLANeXT wireless variant -- 365MB, optimized for borderless tables.
SlanetWireless,
/// SLANet-plus -- 7.78MB, lightweight general-purpose.
SlanetPlus,
/// Classifier-routed SLANeXT: auto-select wired/wireless per table.
/// Uses PP-LCNet classifier (6.78MB) + both SLANeXT variants (730MB total).
SlanetAuto,
/// Disable table structure model inference entirely; use heuristic path only.
Disabled,
}
impl std::str::FromStr for TableModel {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"tatr" => Ok(Self::Tatr),
"slanet_wired" => Ok(Self::SlanetWired),
"slanet_wireless" => Ok(Self::SlanetWireless),
"slanet_plus" => Ok(Self::SlanetPlus),
"slanet_auto" => Ok(Self::SlanetAuto),
"disabled" => Ok(Self::Disabled),
other => Err(format!(
"unknown table model: '{other}'. Valid: tatr, slanet_wired, slanet_wireless, slanet_plus, slanet_auto, disabled"
)),
}
}
}
impl fmt::Display for TableModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TableModel::Tatr => write!(f, "tatr"),
TableModel::SlanetWired => write!(f, "slanet_wired"),
TableModel::SlanetWireless => write!(f, "slanet_wireless"),
TableModel::SlanetPlus => write!(f, "slanet_plus"),
TableModel::SlanetAuto => write!(f, "slanet_auto"),
TableModel::Disabled => write!(f, "disabled"),
}
}
}
/// Layout detection configuration.
///
/// Controls layout detection behavior in the extraction pipeline.
/// When set on [`ExtractionConfig`](super::ExtractionConfig), layout detection
/// is enabled for PDF extraction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayoutDetectionConfig {
/// Confidence threshold override (None = use model default).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub confidence_threshold: Option<f32>,
/// Whether to apply postprocessing heuristics (default: true).
#[serde(default = "default_true")]
pub apply_heuristics: bool,
/// Table structure recognition model.
///
/// Controls which model is used for table cell detection within layout-detected
/// table regions. Defaults to [`TableModel::Tatr`].
#[serde(default)]
pub table_model: TableModel,
/// Hardware acceleration for ONNX models (layout detection + table structure).
///
/// When set, controls which execution provider (CPU, CUDA, CoreML, TensorRT)
/// is used for inference. Defaults to `None` (auto-select per platform).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub acceleration: Option<super::acceleration::AccelerationConfig>,
}
impl Default for LayoutDetectionConfig {
fn default() -> Self {
Self {
confidence_threshold: None,
apply_heuristics: true,
table_model: TableModel::default(),
acceleration: None,
}
}
}
fn default_true() -> bool {
true
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = LayoutDetectionConfig::default();
assert_eq!(config.table_model, TableModel::Tatr);
assert!(config.apply_heuristics);
assert!(config.confidence_threshold.is_none());
}
#[test]
fn test_table_model_deserialize() {
let json = r#""tatr""#;
let model: TableModel = serde_json::from_str(json).unwrap();
assert_eq!(model, TableModel::Tatr);
let json = r#""slanet_auto""#;
let model: TableModel = serde_json::from_str(json).unwrap();
assert_eq!(model, TableModel::SlanetAuto);
let json = r#""disabled""#;
let model: TableModel = serde_json::from_str(json).unwrap();
assert_eq!(model, TableModel::Disabled);
}
#[test]
fn test_table_model_serialize() {
let json = serde_json::to_string(&TableModel::SlanetWired).unwrap();
assert_eq!(json, r#""slanet_wired""#);
}
#[test]
fn test_table_model_round_trip() {
for model in [
TableModel::Tatr,
TableModel::SlanetWired,
TableModel::SlanetWireless,
TableModel::SlanetPlus,
TableModel::SlanetAuto,
TableModel::Disabled,
] {
let serialized = serde_json::to_string(&model).unwrap();
let parsed: TableModel = serde_json::from_str(&serialized).unwrap();
assert_eq!(parsed, model, "round-trip failed for {model:?}");
}
}
#[test]
fn test_backward_compat_unknown_fields_ignored() {
// Old configs with "preset" field should still deserialize because
// serde ignores unknown fields by default.
let json = r#"{"preset": "accurate", "apply_heuristics": true}"#;
let config: LayoutDetectionConfig = serde_json::from_str(json).unwrap();
assert!(config.apply_heuristics);
assert_eq!(config.table_model, TableModel::Tatr);
}
#[test]
fn test_backward_compat_old_table_model_field() {
// Old configs with table_model as a string should still work
let json = r#"{"table_model": "slanet_wired"}"#;
let config: LayoutDetectionConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.table_model, TableModel::SlanetWired);
}
#[test]
fn test_table_model_display() {
assert_eq!(TableModel::Tatr.to_string(), "tatr");
assert_eq!(TableModel::SlanetWired.to_string(), "slanet_wired");
assert_eq!(TableModel::Disabled.to_string(), "disabled");
}
}

View File

@@ -0,0 +1,154 @@
//! LLM configuration types for liter-llm integration.
//!
//! These types are always available (not feature-gated) since they are
//! pure configuration data with no runtime dependency on liter-llm.
use serde::{Deserialize, Serialize};
/// Configuration for an LLM provider/model via liter-llm.
///
/// Each feature (VLM OCR, VLM embeddings, structured extraction) carries
/// its own `LlmConfig`, allowing different providers per feature.
///
/// # Example
///
/// ```toml
/// [structured_extraction.llm]
/// model = "openai/gpt-4o"
/// api_key = "sk-..." # or use KREUZBERG_LLM_API_KEY env var
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmConfig {
/// Provider/model string using liter-llm routing format.
///
/// Examples: `"openai/gpt-4o"`, `"anthropic/claude-sonnet-4-20250514"`,
/// `"groq/llama-3.1-70b-versatile"`.
pub model: String,
/// API key for the provider. When `None`, liter-llm falls back to
/// the provider's standard environment variable (e.g., `OPENAI_API_KEY`).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub api_key: Option<String>,
/// Custom base URL override for the provider endpoint.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub base_url: Option<String>,
/// Request timeout in seconds (default: 60).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub timeout_secs: Option<u64>,
/// Maximum retry attempts (default: 3).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_retries: Option<u32>,
/// Sampling temperature for generation tasks.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Maximum tokens to generate.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
}
/// Configuration for LLM-based structured data extraction.
///
/// Sends extracted document content to a VLM with a JSON schema,
/// returning structured data that conforms to the schema.
///
/// # Example
///
/// ```toml
/// [structured_extraction]
/// schema_name = "invoice_data"
/// strict = true
///
/// [structured_extraction.schema]
/// type = "object"
/// properties.vendor = { type = "string" }
/// properties.total = { type = "number" }
/// required = ["vendor", "total"]
///
/// [structured_extraction.llm]
/// model = "openai/gpt-4o"
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructuredExtractionConfig {
/// JSON Schema defining the desired output structure.
pub schema: serde_json::Value,
/// Schema name passed to the LLM's structured output mode.
#[serde(default = "default_schema_name")]
pub schema_name: String,
/// Optional schema description for the LLM.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub schema_description: Option<String>,
/// Enable strict mode — output must exactly match the schema.
#[serde(default)]
pub strict: bool,
/// Custom Jinja2 extraction prompt template. When `None`, a default template is used.
///
/// Available template variables:
/// - `{{ content }}` — The extracted document text.
/// - `{{ schema }}` — The JSON schema as a formatted string.
/// - `{{ schema_name }}` — The schema name.
/// - `{{ schema_description }}` — The schema description (may be empty).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt: Option<String>,
/// LLM configuration for the extraction.
pub llm: LlmConfig,
}
fn default_schema_name() -> String {
"extraction".to_string()
}
#[cfg(test)]
mod tests {
use super::*;
/// Regression test for https://github.com/kreuzberg-dev/kreuzberg/issues/716
///
/// `LlmConfig` must implement `Default` so callers can use the struct-update
/// syntax documented in the VLM OCR guide:
///
/// ```rust
/// use kreuzberg::core::config::LlmConfig;
/// let cfg = LlmConfig {
/// model: "openai/gpt-4o-mini".to_string(),
/// ..Default::default()
/// };
/// ```
#[test]
fn test_llm_config_default_trait_is_satisfied() {
let cfg = LlmConfig::default();
assert!(cfg.model.is_empty(), "default model should be empty string");
assert!(cfg.api_key.is_none());
assert!(cfg.base_url.is_none());
assert!(cfg.timeout_secs.is_none());
assert!(cfg.max_retries.is_none());
assert!(cfg.temperature.is_none());
assert!(cfg.max_tokens.is_none());
}
/// Verify the struct-update pattern from the issue compiles and produces
/// only the explicitly set field.
#[test]
fn test_llm_config_struct_update_syntax() {
let cfg = LlmConfig {
model: "openai/gpt-4o-mini".to_string(),
..Default::default()
};
assert_eq!(cfg.model, "openai/gpt-4o-mini");
assert!(cfg.api_key.is_none());
assert!(cfg.base_url.is_none());
assert!(cfg.timeout_secs.is_none());
assert!(cfg.max_retries.is_none());
assert!(cfg.temperature.is_none());
assert!(cfg.max_tokens.is_none());
}
}

View File

@@ -0,0 +1,150 @@
//! JSON-level configuration merging.
//!
//! Provides a unified merge function for combining a base `ExtractionConfig` with
//! JSON overrides. Used by both the CLI (`--config-json`) and MCP server to apply
//! partial configuration overrides without losing unspecified fields.
use super::ExtractionConfig;
/// Merge extraction configuration using JSON-level field override.
///
/// Serializes the base config to JSON, merges each field from the override JSON
/// (top-level only), and deserializes back. This correctly handles boolean fields
/// explicitly set to their default values — the override always wins for any field
/// present in `override_json`.
///
/// Fields **not** present in `override_json` are preserved from `base`.
///
/// # Errors
///
/// Returns `Err` if the base config cannot be serialized, or if the merged JSON
/// cannot be deserialized back into `ExtractionConfig` (e.g., wrong field types).
///
/// # Examples
///
/// ```rust,ignore
/// use kreuzberg::ExtractionConfig;
/// use serde_json::json;
///
/// let mut base = ExtractionConfig::default();
/// base.use_cache = true;
///
/// let overrides = r#"{"force_ocr": true}"#;
/// let merged = kreuzberg::core::config::merge::merge_config_json(&base, overrides).unwrap();
/// assert!(merged.use_cache); // preserved from base
/// assert!(merged.force_ocr); // applied from override
/// ```
#[cfg_attr(alef, alef(skip))]
pub fn merge_config_json(base: &ExtractionConfig, override_json: &str) -> Result<ExtractionConfig, String> {
let override_value: serde_json::Value =
serde_json::from_str(override_json).map_err(|e| format!("Failed to parse override JSON: {e}"))?;
let mut config_json =
serde_json::to_value(base).map_err(|e| format!("Failed to serialize base config to JSON: {e}"))?;
if let serde_json::Value::Object(json_obj) = override_value
&& let Some(config_obj) = config_json.as_object_mut()
{
for (key, value) in json_obj {
config_obj.insert(key, value);
}
}
serde_json::from_value(config_json).map_err(|e| format!("Failed to deserialize merged config: {e}"))
}
/// Build extraction config by optionally merging JSON overrides into a base config.
///
/// If `override_json` is `None`, returns a clone of `base`. Otherwise delegates
/// to [`merge_config_json`].
#[cfg_attr(alef, alef(skip))]
pub fn build_config_from_json(
base: &ExtractionConfig,
override_json: Option<&str>,
) -> Result<ExtractionConfig, String> {
match override_json {
Some(json) => merge_config_json(base, json),
None => Ok(base.clone()),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_merge_preserves_unspecified_fields() {
let base = ExtractionConfig {
use_cache: false,
enable_quality_processing: true,
force_ocr: false,
..Default::default()
};
let merged = merge_config_json(&base, r#"{"force_ocr": true}"#).unwrap();
assert!(!merged.use_cache, "use_cache should be preserved from base");
assert!(
merged.enable_quality_processing,
"enable_quality_processing should be preserved"
);
assert!(merged.force_ocr, "force_ocr should be overridden");
}
#[test]
fn test_merge_override_to_default_value() {
let base = ExtractionConfig {
use_cache: false,
..Default::default()
};
let merged = merge_config_json(&base, r#"{"use_cache": true}"#).unwrap();
assert!(
merged.use_cache,
"Should use explicit override even if it matches the struct default"
);
}
#[test]
fn test_merge_multiple_fields() {
let base = ExtractionConfig {
use_cache: true,
force_ocr: true,
..Default::default()
};
let merged = merge_config_json(&base, r#"{"use_cache": false, "output_format": "markdown"}"#).unwrap();
assert!(!merged.use_cache);
assert!(merged.force_ocr, "force_ocr should be preserved");
assert_eq!(
merged.output_format,
crate::core::config::formats::OutputFormat::Markdown,
);
}
#[test]
fn test_merge_invalid_field_type_returns_error() {
let base = ExtractionConfig::default();
let result = merge_config_json(&base, r#"{"use_cache": "not_a_boolean"}"#);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Failed to deserialize"));
}
#[test]
fn test_build_config_from_json_none_returns_clone() {
let base = ExtractionConfig {
use_cache: false,
..Default::default()
};
let result = build_config_from_json(&base, None).unwrap();
assert!(!result.use_cache);
}
#[test]
fn test_build_config_from_json_some_merges() {
let base = ExtractionConfig::default();
let result = build_config_from_json(&base, Some(r#"{"force_ocr": true}"#)).unwrap();
assert!(result.force_ocr);
}
}

View File

@@ -0,0 +1,47 @@
//! Configuration loading and management.
//!
//! This module provides utilities for loading extraction configuration from various
//! sources (TOML, YAML, JSON) and discovering configuration files in the project hierarchy.
pub mod acceleration;
pub mod concurrency;
pub mod content_filter;
pub mod email;
pub mod extraction;
pub mod formats;
#[cfg(feature = "html")]
pub mod html_output;
pub mod layout;
pub mod llm;
pub mod merge;
pub mod ocr;
pub mod page;
pub mod pdf;
pub mod processing;
#[cfg(feature = "tree-sitter")]
pub mod tree_sitter;
// Re-export main types for backward compatibility
pub use acceleration::{AccelerationConfig, ExecutionProviderType};
pub use concurrency::ConcurrencyConfig;
pub use content_filter::ContentFilterConfig;
pub use email::EmailConfig;
pub use extraction::{
BatchBytesItem, BatchFileItem, ExtractionConfig, FileExtractionConfig, ImageExtractionConfig,
LanguageDetectionConfig, TokenReductionOptions,
};
pub use formats::OutputFormat;
#[cfg(feature = "html")]
pub use html_output::{HtmlOutputConfig, HtmlTheme};
#[cfg(feature = "layout-types")]
pub use layout::{LayoutDetectionConfig, TableModel};
pub use llm::{LlmConfig, StructuredExtractionConfig};
pub use ocr::{OcrConfig, OcrPipelineConfig, OcrPipelineStage, OcrQualityThresholds};
pub use page::PageConfig;
#[cfg(feature = "pdf")]
pub use pdf::{HierarchyConfig, PdfConfig};
pub use processing::{
ChunkSizing, ChunkerType, ChunkingConfig, EmbeddingConfig, EmbeddingModelType, PostProcessorConfig,
};
#[cfg(feature = "tree-sitter")]
pub use tree_sitter::{CodeContentMode, TreeSitterConfig, TreeSitterProcessConfig};

View File

@@ -0,0 +1,932 @@
//! OCR configuration.
//!
//! Defines OCR-specific configuration including backend selection, language settings,
//! Tesseract-specific parameters, quality thresholds, and multi-backend pipeline config.
use serde::{Deserialize, Serialize};
use super::formats::OutputFormat;
#[cfg(test)]
use crate::core::config_validation::validate_ocr_backend;
#[cfg(test)]
use crate::error::KreuzbergError;
use crate::types::OcrElementConfig;
/// Quality thresholds for OCR fallback decisions and pipeline quality gating.
///
/// All fields default to the values that match the previous hardcoded behavior,
/// so `OcrQualityThresholds::default()` preserves existing semantics exactly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrQualityThresholds {
/// Minimum total non-whitespace characters to consider text substantive.
#[serde(default = "default_min_total_non_whitespace")]
pub min_total_non_whitespace: usize,
/// Minimum non-whitespace characters per page on average.
#[serde(default = "default_min_non_whitespace_per_page")]
pub min_non_whitespace_per_page: f64,
/// Minimum character count for a word to be "meaningful".
#[serde(default = "default_min_meaningful_word_len")]
pub min_meaningful_word_len: usize,
/// Minimum count of meaningful words before text is accepted.
#[serde(default = "default_min_meaningful_words")]
pub min_meaningful_words: usize,
/// Minimum alphanumeric ratio (non-whitespace chars that are alphanumeric).
#[serde(default = "default_min_alnum_ratio")]
pub min_alnum_ratio: f64,
/// Minimum Unicode replacement characters (U+FFFD) to trigger OCR fallback.
#[serde(default = "default_min_garbage_chars")]
pub min_garbage_chars: usize,
/// Maximum fraction of short (1-2 char) words before text is considered fragmented.
#[serde(default = "default_max_fragmented_word_ratio")]
pub max_fragmented_word_ratio: f64,
/// Critical fragmentation threshold — triggers OCR regardless of meaningful words.
/// Normal English text has ~20-30% short words. 80%+ is definitive garbage.
#[serde(default = "default_critical_fragmented_word_ratio")]
pub critical_fragmented_word_ratio: f64,
/// Minimum average word length. Below this with enough words indicates garbled extraction.
#[serde(default = "default_min_avg_word_length")]
pub min_avg_word_length: f64,
/// Minimum word count before average word length check applies.
#[serde(default = "default_min_words_for_avg_length_check")]
pub min_words_for_avg_length_check: usize,
/// Minimum consecutive word repetition ratio to detect column scrambling.
#[serde(default = "default_min_consecutive_repeat_ratio")]
pub min_consecutive_repeat_ratio: f64,
/// Minimum word count before consecutive repetition check is applied.
#[serde(default = "default_min_words_for_repeat_check")]
pub min_words_for_repeat_check: usize,
/// Minimum character count for "substantive markdown" OCR skip gate.
#[serde(default = "default_substantive_min_chars")]
pub substantive_min_chars: usize,
/// Minimum character count for "non-text content" OCR skip gate.
#[serde(default = "default_non_text_min_chars")]
pub non_text_min_chars: usize,
/// Alphanumeric+whitespace ratio threshold for skip decisions.
#[serde(default = "default_alnum_ws_ratio_threshold")]
pub alnum_ws_ratio_threshold: f64,
/// Minimum quality score (0.0-1.0) for a pipeline stage result to be accepted.
/// If the result from a backend scores below this, try the next backend.
#[serde(default = "default_pipeline_min_quality")]
pub pipeline_min_quality: f64,
}
impl Default for OcrQualityThresholds {
fn default() -> Self {
Self {
min_total_non_whitespace: 64,
min_non_whitespace_per_page: 32.0,
min_meaningful_word_len: 4,
min_meaningful_words: 3,
min_alnum_ratio: 0.3,
min_garbage_chars: 5,
max_fragmented_word_ratio: 0.6,
critical_fragmented_word_ratio: 0.80,
min_avg_word_length: 2.0,
min_words_for_avg_length_check: 50,
min_consecutive_repeat_ratio: 0.08,
min_words_for_repeat_check: 50,
substantive_min_chars: 100,
non_text_min_chars: 20,
alnum_ws_ratio_threshold: 0.4,
pipeline_min_quality: 0.5,
}
}
}
fn default_min_total_non_whitespace() -> usize {
64
}
fn default_min_non_whitespace_per_page() -> f64 {
32.0
}
fn default_min_meaningful_word_len() -> usize {
4
}
fn default_min_meaningful_words() -> usize {
3
}
fn default_min_alnum_ratio() -> f64 {
0.3
}
fn default_min_garbage_chars() -> usize {
5
}
fn default_max_fragmented_word_ratio() -> f64 {
0.6
}
fn default_critical_fragmented_word_ratio() -> f64 {
0.80
}
fn default_min_avg_word_length() -> f64 {
2.0
}
fn default_min_words_for_avg_length_check() -> usize {
50
}
fn default_min_consecutive_repeat_ratio() -> f64 {
0.08
}
fn default_min_words_for_repeat_check() -> usize {
50
}
fn default_substantive_min_chars() -> usize {
100
}
fn default_non_text_min_chars() -> usize {
20
}
fn default_alnum_ws_ratio_threshold() -> f64 {
0.4
}
fn default_pipeline_min_quality() -> f64 {
0.5
}
/// A single backend stage in the OCR pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrPipelineStage {
/// Backend name: "tesseract", "paddleocr", "easyocr", or a custom registered name.
pub backend: String,
/// Priority weight (higher = tried first). Stages are sorted by priority descending.
#[serde(default = "default_priority")]
pub priority: u32,
/// Language override for this stage (None = use parent OcrConfig.language).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
/// Tesseract-specific config override for this stage.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tesseract_config: Option<crate::types::TesseractConfig>,
/// PaddleOCR-specific config for this stage.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub paddle_ocr_config: Option<serde_json::Value>,
/// VLM config override for this pipeline stage.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub vlm_config: Option<super::llm::LlmConfig>,
/// Arbitrary per-call options passed through to the backend unchanged.
///
/// Backends that support runtime tuning (mode switching, preprocessing
/// flags, inference parameters, etc.) read this value and deserialize
/// the keys they care about. Keys unknown to the backend are silently
/// ignored, so options from different backends can coexist in the same
/// config without conflict.
///
/// Example (custom backend):
/// ```json
/// { "mode": "fast", "enable_layout": true }
/// ```
#[serde(default, skip_serializing_if = "Option::is_none")]
pub backend_options: Option<serde_json::Value>,
}
fn default_priority() -> u32 {
100
}
/// Multi-backend OCR pipeline with quality-based fallback.
///
/// Backends are tried in priority order (highest first). After each backend
/// produces output, quality is evaluated. If it meets `quality_thresholds.pipeline_min_quality`,
/// the result is accepted. Otherwise the next backend is tried.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrPipelineConfig {
/// Ordered list of backends to try. Sorted by priority (descending) at runtime.
pub stages: Vec<OcrPipelineStage>,
/// Quality thresholds for deciding whether to accept a result or try the next backend.
#[serde(default)]
pub quality_thresholds: OcrQualityThresholds,
}
/// OCR configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrConfig {
/// Whether OCR is enabled.
///
/// Setting `enabled: false` is a shorthand for `disable_ocr: true` on the parent
/// [`ExtractionConfig`](crate::core::config::ExtractionConfig). Images return
/// metadata only; PDFs use native text extraction without OCR fallback.
///
/// Defaults to `true`. When `false`, all other OCR settings are ignored.
#[serde(default = "default_ocr_enabled")]
pub enabled: bool,
/// OCR backend: tesseract, easyocr, paddleocr
#[serde(default = "default_tesseract_backend")]
pub backend: String,
/// Language code (e.g., "eng", "deu")
#[serde(default = "default_eng")]
pub language: String,
/// Tesseract-specific configuration (optional)
#[serde(default)]
pub tesseract_config: Option<crate::types::TesseractConfig>,
/// Output format for OCR results (optional, for format conversion)
#[serde(default)]
pub output_format: Option<OutputFormat>,
/// PaddleOCR-specific configuration (optional, JSON passthrough)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub paddle_ocr_config: Option<serde_json::Value>,
/// Arbitrary per-call options passed through to the backend unchanged.
///
/// Custom OCR backends and built-in backends that support runtime tuning
/// can read this value and deserialize the keys they care about. Keys
/// unknown to the backend are silently ignored.
///
/// This is the recommended extension point for per-call parameters that
/// are not covered by the typed fields above (e.g. mode switching,
/// preprocessing flags, inference batch size).
///
/// **Scope:** when `pipeline` is `None`, this value is propagated to the
/// primary stage of the auto-constructed pipeline. When `pipeline` is
/// explicitly set, this field has **no effect** — the caller must set
/// `OcrPipelineStage.backend_options` directly on the relevant stage(s)
/// instead.
///
/// Example:
/// ```json
/// { "mode": "fast", "enable_layout": true, "timeout_ms": 5000 }
/// ```
#[serde(default, skip_serializing_if = "Option::is_none")]
pub backend_options: Option<serde_json::Value>,
/// OCR element extraction configuration
#[serde(default, skip_serializing_if = "Option::is_none")]
pub element_config: Option<OcrElementConfig>,
/// Quality thresholds for the native-text-to-OCR fallback decision.
/// When None, uses compiled defaults (matching previous hardcoded behavior).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub quality_thresholds: Option<OcrQualityThresholds>,
/// Multi-backend OCR pipeline configuration. When set, enables weighted
/// fallback across multiple OCR backends based on output quality.
/// When None, uses the single `backend` field (same as today).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub pipeline: Option<OcrPipelineConfig>,
/// Enable automatic page rotation based on orientation detection.
///
/// When enabled, uses Tesseract's `DetectOrientationScript()` to detect
/// page orientation (0/90/180/270 degrees) before OCR. If the page is
/// rotated with high confidence, the image is corrected before recognition.
/// This is critical for handling rotated scanned documents.
#[serde(default)]
pub auto_rotate: bool,
/// VLM (Vision Language Model) OCR configuration.
///
/// Required when `backend` is `"vlm"`. Uses liter-llm to send page
/// images to a vision model for text extraction.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub vlm_config: Option<super::llm::LlmConfig>,
/// Custom Jinja2 prompt template for VLM OCR.
///
/// When `None`, uses the default template. Available variables:
/// - `{{ language }}` — The document language code (e.g., "eng", "deu").
#[serde(default, skip_serializing_if = "Option::is_none")]
pub vlm_prompt: Option<String>,
/// Hardware acceleration for ONNX Runtime models (e.g. PaddleOCR, layout detection).
///
/// Not user-configurable via config files — injected at runtime from
/// `ExtractionConfig::acceleration` before each `process_image` call.
#[serde(skip)]
pub acceleration: Option<super::acceleration::AccelerationConfig>,
/// Caller-supplied Tesseract `traineddata` bytes per language code.
///
/// Primary use case is the WASM build, which has no filesystem and cannot
/// download tessdata at runtime. Native builds typically rely on
/// `TessdataManager` and ignore this field. When present, the WASM
/// Tesseract backend prefers these bytes over its compile-time-bundled
/// English data.
///
/// Skipped by serde to keep config files small — supply via the typed API
/// at runtime.
#[serde(skip)]
pub tessdata_bytes: Option<std::collections::HashMap<String, Vec<u8>>>,
}
impl Default for OcrConfig {
fn default() -> Self {
Self {
enabled: true,
backend: default_tesseract_backend(),
language: default_eng(),
tesseract_config: None,
output_format: None,
paddle_ocr_config: None,
backend_options: None,
element_config: None,
quality_thresholds: None,
pipeline: None,
auto_rotate: false,
vlm_config: None,
vlm_prompt: None,
acceleration: None,
tessdata_bytes: None,
}
}
}
impl OcrConfig {
/// Validates that the configured backend is supported.
///
/// This method checks that the backend name is one of the supported OCR backends:
/// - tesseract
/// - easyocr
/// - paddleocr
///
/// Typos in backend names are caught at configuration validation time, not at runtime.
/// Also validates pipeline stage backends when a pipeline is configured.
#[cfg(test)]
pub(crate) fn validate(&self) -> Result<(), KreuzbergError> {
validate_ocr_backend(&self.backend)?;
// When backend is "vlm", vlm_config must be present.
crate::core::config_validation::validate_vlm_backend_config(&self.backend, self.vlm_config.as_ref())?;
if let Some(ref pipeline) = self.pipeline {
for stage in &pipeline.stages {
validate_ocr_backend(&stage.backend)?;
crate::core::config_validation::validate_vlm_backend_config(&stage.backend, stage.vlm_config.as_ref())?;
}
}
Ok(())
}
/// Returns the effective quality thresholds, using configured values or defaults.
#[cfg(feature = "ocr")]
pub(crate) fn effective_thresholds(&self) -> OcrQualityThresholds {
self.quality_thresholds.clone().unwrap_or_default()
}
/// Returns the effective pipeline config.
///
/// - If `pipeline` is explicitly set, returns it.
/// - If `paddle-ocr` is compiled in and the backend is the default
/// (tesseract), auto-constructs `[tesseract @ 100, paddleocr @ 50]`.
/// - Otherwise returns `None` (single-backend mode).
///
/// Explicit non-default backend selections are honored as-is — a silent
/// paddleocr fallback would mask errors from the chosen backend.
#[cfg(feature = "ocr")]
pub(crate) fn effective_pipeline(&self) -> Option<OcrPipelineConfig> {
if self.pipeline.is_some() {
return self.pipeline.clone();
}
#[cfg(feature = "paddle-ocr")]
{
if self.backend != default_tesseract_backend() {
return None;
}
let stages = vec![
OcrPipelineStage {
backend: self.backend.clone(),
priority: 100,
language: None,
tesseract_config: self.tesseract_config.clone(),
paddle_ocr_config: None,
vlm_config: self.vlm_config.clone(),
backend_options: self.backend_options.clone(),
},
OcrPipelineStage {
backend: "paddleocr".to_string(),
priority: 50,
language: None,
tesseract_config: None,
paddle_ocr_config: self.paddle_ocr_config.clone(),
vlm_config: None,
backend_options: None,
},
];
Some(OcrPipelineConfig {
stages,
quality_thresholds: self.effective_thresholds(),
})
}
#[cfg(not(feature = "paddle-ocr"))]
{
None
}
}
}
fn default_ocr_enabled() -> bool {
true
}
fn default_tesseract_backend() -> String {
"tesseract".to_string()
}
fn default_eng() -> String {
"eng".to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ocr_config_default() {
let config = OcrConfig::default();
assert_eq!(config.backend, "tesseract");
assert_eq!(config.language, "eng");
assert!(config.tesseract_config.is_none());
assert!(config.output_format.is_none());
}
#[test]
fn test_ocr_config_with_tesseract() {
let config = OcrConfig {
backend: "tesseract".to_string(),
language: "fra".to_string(),
..Default::default()
};
assert_eq!(config.backend, "tesseract");
assert_eq!(config.language, "fra");
}
#[test]
fn test_validate_tesseract_backend() {
let config = OcrConfig {
backend: "tesseract".to_string(),
..Default::default()
};
assert!(config.validate().is_ok());
}
#[test]
fn test_validate_easyocr_backend() {
let config = OcrConfig {
backend: "easyocr".to_string(),
..Default::default()
};
assert!(config.validate().is_ok());
}
#[test]
fn test_validate_paddleocr_backend() {
let config = OcrConfig {
backend: "paddleocr".to_string(),
..Default::default()
};
assert!(config.validate().is_ok());
}
#[test]
fn test_validate_invalid_backend_typo() {
let config = OcrConfig {
backend: "tesseract_typo".to_string(),
..Default::default()
};
let result = config.validate();
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Invalid OCR backend"));
}
#[test]
fn test_validate_invalid_backend_completely_wrong() {
let config = OcrConfig {
backend: "ocr_lib".to_string(),
..Default::default()
};
let result = config.validate();
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Invalid OCR backend") || err_msg.contains("Valid options are"));
}
#[test]
fn test_validate_default_backend() {
let config = OcrConfig::default();
assert!(config.validate().is_ok());
}
// ── effective_pipeline tests ──
#[cfg(feature = "ocr")]
#[test]
fn test_effective_pipeline_explicit_pipeline_returned_unchanged() {
let explicit_pipeline = OcrPipelineConfig {
stages: vec![OcrPipelineStage {
backend: "easyocr".to_string(),
priority: 200,
language: Some("fra".to_string()),
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: None,
}],
quality_thresholds: OcrQualityThresholds::default(),
};
let config = OcrConfig {
pipeline: Some(explicit_pipeline.clone()),
..Default::default()
};
let result = config.effective_pipeline().unwrap();
assert_eq!(result.stages.len(), 1);
assert_eq!(result.stages[0].backend, "easyocr");
assert_eq!(result.stages[0].priority, 200);
assert_eq!(result.stages[0].language, Some("fra".to_string()));
}
#[cfg(feature = "ocr")]
#[test]
fn test_effective_pipeline_explicit_paddleocr_no_autofallback() {
let config = OcrConfig {
backend: "paddleocr".to_string(),
..Default::default()
};
assert!(config.effective_pipeline().is_none());
}
#[cfg(feature = "ocr")]
#[test]
fn test_effective_pipeline_explicit_easyocr_no_autofallback() {
let config = OcrConfig {
backend: "easyocr".to_string(),
..Default::default()
};
assert!(config.effective_pipeline().is_none());
}
#[cfg(feature = "ocr")]
#[test]
fn test_effective_pipeline_default_tesseract_backend() {
let config = OcrConfig::default();
let result = config.effective_pipeline();
#[cfg(feature = "paddle-ocr")]
{
let pipeline = result.unwrap();
assert_eq!(pipeline.stages.len(), 2);
assert_eq!(pipeline.stages[0].backend, "tesseract");
assert_eq!(pipeline.stages[0].priority, 100);
assert_eq!(pipeline.stages[1].backend, "paddleocr");
assert_eq!(pipeline.stages[1].priority, 50);
}
#[cfg(not(feature = "paddle-ocr"))]
{
assert!(result.is_none());
}
}
#[cfg(feature = "ocr")]
#[test]
fn test_effective_thresholds_custom_vs_default() {
// With custom thresholds
let custom = OcrQualityThresholds {
min_total_non_whitespace: 128,
min_meaningful_words: 10,
..Default::default()
};
let config_custom = OcrConfig {
quality_thresholds: Some(custom.clone()),
..Default::default()
};
let eff = config_custom.effective_thresholds();
assert_eq!(eff.min_total_non_whitespace, 128);
assert_eq!(eff.min_meaningful_words, 10);
// Without custom thresholds (should return defaults)
let config_default = OcrConfig::default();
let eff_default = config_default.effective_thresholds();
assert_eq!(eff_default.min_total_non_whitespace, 64);
assert_eq!(eff_default.min_meaningful_words, 3);
}
// ── Serde tests ──
#[test]
fn test_pipeline_config_serde_roundtrip() {
let pipeline = OcrPipelineConfig {
stages: vec![
OcrPipelineStage {
backend: "tesseract".to_string(),
priority: 100,
language: Some("eng".to_string()),
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: None,
},
OcrPipelineStage {
backend: "paddleocr".to_string(),
priority: 50,
language: None,
tesseract_config: None,
paddle_ocr_config: Some(serde_json::json!({"use_gpu": false})),
vlm_config: None,
backend_options: None,
},
],
quality_thresholds: OcrQualityThresholds::default(),
};
let json = serde_json::to_string(&pipeline).unwrap();
let deserialized: OcrPipelineConfig = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.stages.len(), 2);
assert_eq!(deserialized.stages[0].backend, "tesseract");
assert_eq!(deserialized.stages[0].priority, 100);
assert_eq!(deserialized.stages[1].backend, "paddleocr");
assert_eq!(deserialized.stages[1].priority, 50);
assert!(deserialized.stages[1].paddle_ocr_config.is_some());
}
#[test]
fn test_pipeline_stage_deserialization_missing_optional_fields() {
// Only backend is required; everything else should use defaults
let json = r#"{"backend": "tesseract"}"#;
let stage: OcrPipelineStage = serde_json::from_str(json).unwrap();
assert_eq!(stage.backend, "tesseract");
assert_eq!(stage.priority, 100); // default_priority
assert!(stage.language.is_none());
assert!(stage.tesseract_config.is_none());
assert!(stage.paddle_ocr_config.is_none());
}
#[test]
fn test_pipeline_stage_default_priority_is_100() {
let json = r#"{"backend": "easyocr"}"#;
let stage: OcrPipelineStage = serde_json::from_str(json).unwrap();
assert_eq!(stage.priority, 100);
}
#[test]
fn test_ocr_config_deserialization_missing_optional_fields() {
let json = r#"{}"#;
let config: OcrConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.backend, "tesseract");
assert_eq!(config.language, "eng");
assert!(config.pipeline.is_none());
assert!(config.quality_thresholds.is_none());
assert!(config.element_config.is_none());
}
#[test]
fn test_quality_thresholds_deserialization_partial() {
let json = r#"{"min_total_non_whitespace": 256}"#;
let thresholds: OcrQualityThresholds = serde_json::from_str(json).unwrap();
assert_eq!(thresholds.min_total_non_whitespace, 256);
// All other fields should be defaults
assert_eq!(thresholds.min_meaningful_words, 3);
assert_eq!(thresholds.min_garbage_chars, 5);
assert!((thresholds.pipeline_min_quality - 0.5).abs() < f64::EPSILON);
}
// ── Validation tests ──
#[test]
fn test_validate_catches_invalid_pipeline_stage_backend() {
let config = OcrConfig {
pipeline: Some(OcrPipelineConfig {
stages: vec![
OcrPipelineStage {
backend: "tesseract".to_string(),
priority: 100,
language: None,
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: None,
},
OcrPipelineStage {
backend: "invalid_backend".to_string(),
priority: 50,
language: None,
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: None,
},
],
quality_thresholds: OcrQualityThresholds::default(),
}),
..Default::default()
};
let result = config.validate();
assert!(result.is_err(), "Should catch invalid backend in pipeline stages");
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Invalid OCR backend") || err_msg.contains("invalid_backend"));
}
// ── backend_options tests ──
#[test]
fn test_ocr_config_backend_options_default_is_none() {
let config = OcrConfig::default();
assert!(config.backend_options.is_none());
}
#[test]
fn test_ocr_config_backend_options_serde_roundtrip() {
let config = OcrConfig {
backend_options: Some(serde_json::json!({"mode": "fast", "threshold": 0.8, "enable_layout": true})),
..Default::default()
};
let json = serde_json::to_string(&config).unwrap();
let deserialized: OcrConfig = serde_json::from_str(&json).unwrap();
let opts = deserialized.backend_options.unwrap();
assert_eq!(opts["mode"], "fast");
assert!((opts["threshold"].as_f64().unwrap() - 0.8).abs() < f64::EPSILON);
assert_eq!(opts["enable_layout"], true);
}
#[test]
fn test_ocr_config_backend_options_omitted_when_none() {
let config = OcrConfig::default();
let json = serde_json::to_string(&config).unwrap();
assert!(
!json.contains("backend_options"),
"backend_options must be omitted when None"
);
}
#[test]
fn test_pipeline_stage_backend_options_serde_roundtrip() {
let stage = OcrPipelineStage {
backend: "custom".to_string(),
priority: 80,
language: None,
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: Some(serde_json::json!({"batch_size": 4, "device": "cpu"})),
};
let json = serde_json::to_string(&stage).unwrap();
let deserialized: OcrPipelineStage = serde_json::from_str(&json).unwrap();
let opts = deserialized.backend_options.unwrap();
assert_eq!(opts["batch_size"], 4);
assert_eq!(opts["device"], "cpu");
}
#[test]
fn test_pipeline_stage_backend_options_omitted_when_none() {
let stage = OcrPipelineStage {
backend: "tesseract".to_string(),
priority: 100,
language: None,
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: None,
};
let json = serde_json::to_string(&stage).unwrap();
assert!(
!json.contains("backend_options"),
"backend_options must be omitted when None"
);
}
#[cfg(all(feature = "ocr", feature = "paddle-ocr"))]
#[test]
fn test_effective_pipeline_propagates_backend_options_to_primary_stage() {
let config = OcrConfig {
backend_options: Some(serde_json::json!({"mode": "fast"})),
..Default::default()
};
let pipeline = config
.effective_pipeline()
.expect("paddle-ocr feature must produce a pipeline");
assert_eq!(pipeline.stages.len(), 2);
let primary = &pipeline.stages[0];
assert_eq!(primary.backend, "tesseract");
let opts = primary
.backend_options
.as_ref()
.expect("primary stage must carry backend_options");
assert_eq!(opts["mode"], "fast");
// PaddleOCR stage should not inherit backend_options from the top-level config.
let fallback = &pipeline.stages[1];
assert_eq!(fallback.backend, "paddleocr");
assert!(
fallback.backend_options.is_none(),
"paddleocr stage must not inherit backend_options"
);
}
#[cfg(feature = "ocr")]
#[test]
fn test_explicit_pipeline_ignores_top_level_backend_options() {
// When the caller provides an explicit pipeline, OcrConfig.backend_options
// must NOT be injected into the returned stages — the stage owns its own value.
let config = OcrConfig {
backend_options: Some(serde_json::json!({"mode": "fast"})),
pipeline: Some(OcrPipelineConfig {
stages: vec![OcrPipelineStage {
backend: "tesseract".to_string(),
priority: 100,
language: None,
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: None,
}],
quality_thresholds: OcrQualityThresholds::default(),
}),
..Default::default()
};
let pipeline = config
.effective_pipeline()
.expect("explicit pipeline must be returned as-is");
assert_eq!(pipeline.stages.len(), 1);
assert!(
pipeline.stages[0].backend_options.is_none(),
"top-level backend_options must not be injected into an explicit pipeline stage"
);
}
#[cfg(feature = "ocr")]
#[test]
fn test_stage_level_backend_options_preserved_in_explicit_pipeline() {
// Stage-level backend_options in an explicit pipeline are returned unchanged —
// neither cleared nor overridden by any top-level value.
let stage_opts = serde_json::json!({"device": "gpu", "batch": 8});
let config = OcrConfig {
backend_options: Some(serde_json::json!({"mode": "fast"})),
pipeline: Some(OcrPipelineConfig {
stages: vec![OcrPipelineStage {
backend: "custom".to_string(),
priority: 100,
language: None,
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: Some(stage_opts.clone()),
}],
quality_thresholds: OcrQualityThresholds::default(),
}),
..Default::default()
};
let pipeline = config
.effective_pipeline()
.expect("explicit pipeline must be returned as-is");
let returned_opts = pipeline.stages[0]
.backend_options
.as_ref()
.expect("stage-level backend_options must be preserved");
assert_eq!(returned_opts["device"], "gpu");
assert_eq!(returned_opts["batch"], 8);
}
#[test]
fn test_validate_passes_with_valid_pipeline_stages() {
let config = OcrConfig {
pipeline: Some(OcrPipelineConfig {
stages: vec![
OcrPipelineStage {
backend: "tesseract".to_string(),
priority: 100,
language: None,
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: None,
},
OcrPipelineStage {
backend: "paddleocr".to_string(),
priority: 50,
language: None,
tesseract_config: None,
paddle_ocr_config: None,
vlm_config: None,
backend_options: None,
},
],
quality_thresholds: OcrQualityThresholds::default(),
}),
..Default::default()
};
assert!(config.validate().is_ok());
}
}

View File

@@ -0,0 +1,57 @@
//! Page extraction and tracking configuration.
//!
//! Controls how pages are extracted, tracked, and represented in extraction results.
//! When `None`, page tracking is disabled.
use serde::{Deserialize, Serialize};
/// Page extraction and tracking configuration.
///
/// Controls how pages are extracted, tracked, and represented in the extraction results.
/// When `None`, page tracking is disabled.
///
/// Page range tracking in chunk metadata (first_page/last_page) is automatically enabled
/// when page boundaries are available and chunking is configured.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct PageConfig {
/// Extract pages as separate array (ExtractionResult.pages)
#[serde(default)]
pub extract_pages: bool,
/// Insert page markers in main content string
#[serde(default)]
pub insert_page_markers: bool,
/// Page marker format (use {page_num} placeholder)
/// Default: "\n\n<!-- PAGE {page_num} -->\n\n"
#[serde(default = "default_page_marker_format")]
pub marker_format: String,
}
impl Default for PageConfig {
fn default() -> Self {
Self {
extract_pages: false,
insert_page_markers: false,
marker_format: "\n\n<!-- PAGE {page_num} -->\n\n".to_string(),
}
}
}
fn default_page_marker_format() -> String {
"\n\n<!-- PAGE {page_num} -->\n\n".to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_page_config_default() {
let config = PageConfig::default();
assert!(!config.extract_pages);
assert!(!config.insert_page_markers);
assert_eq!(config.marker_format, "\n\n<!-- PAGE {page_num} -->\n\n");
}
}

View File

@@ -0,0 +1,192 @@
//! PDF-specific configuration.
//!
//! Defines PDF extraction options including metadata handling, image extraction,
//! password management, and hierarchy extraction for document structure analysis.
use serde::{Deserialize, Serialize};
/// PDF-specific configuration.
#[cfg(feature = "pdf")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PdfConfig {
/// Extract images from PDF
#[serde(default)]
pub extract_images: bool,
/// Extract tables from PDF.
///
/// When `true` (default), runs pdf_oxide's native grid detector and, if it
/// finds nothing, falls back to the heuristic text-layer reconstruction in
/// `pdf::oxide::table::extract_tables_heuristic`. Set to `false` to skip
/// both passes — `tables` will then be empty in the result.
#[serde(default = "default_true")]
pub extract_tables: bool,
/// List of passwords to try when opening encrypted PDFs
#[serde(default)]
pub passwords: Option<Vec<String>>,
/// Extract PDF metadata
#[serde(default = "default_true")]
pub extract_metadata: bool,
/// Hierarchy extraction configuration (None = hierarchy extraction disabled)
#[serde(default)]
pub hierarchy: Option<HierarchyConfig>,
/// Extract PDF annotations (text notes, highlights, links, stamps).
/// Default: false
#[serde(default)]
pub extract_annotations: bool,
/// Top margin fraction (0.01.0) of page height to exclude headers/running heads.
/// Default: 0.06 (6%)
#[serde(default)]
pub top_margin_fraction: Option<f32>,
/// Bottom margin fraction (0.01.0) of page height to exclude footers/page numbers.
/// Default: 0.05 (5%)
#[serde(default)]
pub bottom_margin_fraction: Option<f32>,
/// Allow single-column pseudo tables in extraction results.
///
/// By default, tables with fewer than 2 columns (layout-guided) or 3 columns
/// (heuristic) are rejected. When `true`, the minimum column count is relaxed
/// to 1, allowing single-column structured data (glossaries, itemized lists)
/// to be emitted as tables. Other quality filters (density, sparsity, prose
/// detection) still apply.
#[serde(default)]
pub allow_single_column_tables: bool,
/// Perform OCR on inline images extracted from PDF pages and attach the
/// recognized text to each `ExtractedImage.ocr_result`. Requires Tesseract
/// to be available; if `ExtractionConfig.ocr` is `None` the extractor
/// falls back to `TesseractConfig::default()`. Per-image failures degrade
/// gracefully (the image is returned without OCR text rather than failing
/// the whole extraction). Default: `false`.
#[serde(default)]
pub ocr_inline_images: bool,
}
/// Hierarchy extraction configuration for PDF text structure analysis.
///
/// Enables extraction of document hierarchy levels (H1-H6) based on font size
/// clustering and semantic analysis. When enabled, hierarchical blocks are
/// included in page content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchyConfig {
/// Enable hierarchy extraction
#[serde(default = "default_true")]
pub enabled: bool,
/// Number of font size clusters to use for hierarchy levels (1-7)
///
/// Default: 6, which provides H1-H6 heading levels with body text.
/// Larger values create more fine-grained hierarchy levels.
#[serde(default = "default_k_clusters")]
pub k_clusters: usize,
/// Include bounding box information in hierarchy blocks
#[serde(default = "default_true")]
pub include_bbox: bool,
/// OCR coverage threshold for smart OCR triggering (0.0-1.0)
///
/// Determines when OCR should be triggered based on text block coverage.
/// OCR is triggered when text blocks cover less than this fraction of the page.
/// Default: 0.5 (trigger OCR if less than 50% of page has text)
#[serde(default = "default_ocr_coverage_threshold")]
pub ocr_coverage_threshold: Option<f32>,
}
#[cfg(feature = "pdf")]
impl Default for PdfConfig {
fn default() -> Self {
Self {
extract_images: false,
extract_tables: true,
passwords: None,
extract_metadata: true,
hierarchy: None,
extract_annotations: false,
top_margin_fraction: None,
bottom_margin_fraction: None,
allow_single_column_tables: false,
ocr_inline_images: false,
}
}
}
impl Default for HierarchyConfig {
fn default() -> Self {
Self {
enabled: true,
k_clusters: 3,
include_bbox: true,
ocr_coverage_threshold: None,
}
}
}
fn default_true() -> bool {
true
}
fn default_k_clusters() -> usize {
3
}
fn default_ocr_coverage_threshold() -> Option<f32> {
None
}
#[cfg(test)]
mod tests {
#[test]
#[cfg(feature = "pdf")]
fn test_hierarchy_config_default() {
use super::*;
let config = HierarchyConfig::default();
assert!(config.enabled);
assert_eq!(config.k_clusters, 3);
assert!(config.include_bbox);
assert!(config.ocr_coverage_threshold.is_none());
}
#[test]
#[cfg(feature = "pdf")]
fn test_hierarchy_config_disabled() {
use super::*;
let config = HierarchyConfig {
enabled: false,
k_clusters: 3,
include_bbox: false,
ocr_coverage_threshold: Some(0.7),
};
assert!(!config.enabled);
assert_eq!(config.k_clusters, 3);
assert!(!config.include_bbox);
assert_eq!(config.ocr_coverage_threshold, Some(0.7));
}
#[test]
#[cfg(feature = "pdf")]
fn test_pdf_config_custom_margins() {
use super::*;
let config = PdfConfig {
extract_images: false,
extract_tables: true,
passwords: None,
extract_metadata: true,
hierarchy: None,
extract_annotations: false,
top_margin_fraction: Some(0.10),
bottom_margin_fraction: Some(0.08),
allow_single_column_tables: false,
ocr_inline_images: false,
};
assert_eq!(config.top_margin_fraction, Some(0.10));
assert_eq!(config.bottom_margin_fraction, Some(0.08));
}
}

View File

@@ -0,0 +1,838 @@
//! Post-processing and chunking configuration.
//!
//! Defines configuration for post-processing pipelines, text chunking,
//! and embedding generation.
use ahash::AHashSet;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Type of text chunker to use.
///
/// # Variants
///
/// * `Text` - Generic text splitter, splits on whitespace and punctuation
/// * `Markdown` - Markdown-aware splitter, preserves formatting and structure
/// * `Yaml` - YAML-aware splitter, creates one chunk per top-level key
/// * `Semantic` - Topic-aware chunker. With an `EmbeddingConfig`, splits at
/// embedding-based topic shifts tuned by `topic_threshold` (default 0.75,
/// lower = more splits). Without an embedding, falls back to a
/// structural-boundary heuristic (ALL-CAPS headers, numbered sections,
/// blank-line paragraphs) and merges groups into chunks capped at
/// `max_characters` (default 1000). `topic_threshold` has no effect in the
/// fallback path. For best results, pair with an embedding model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ChunkerType {
#[default]
Text,
Markdown,
Yaml,
Semantic,
}
/// How chunk size is measured.
///
/// Defaults to `Characters` (Unicode character count). When using token-based sizing,
/// chunks are sized by token count according to the specified tokenizer.
///
/// Token-based sizing uses HuggingFace tokenizers loaded at runtime. Any tokenizer
/// available on HuggingFace Hub can be used, including OpenAI-compatible tokenizers
/// (e.g., `Xenova/gpt-4o`, `Xenova/cl100k_base`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChunkSizing {
/// Size measured in Unicode characters (default).
#[default]
Characters,
/// Size measured in tokens from a HuggingFace tokenizer.
#[cfg(feature = "chunking-tokenizers")]
Tokenizer {
/// HuggingFace model ID or path, e.g. "Xenova/gpt-4o", "bert-base-uncased".
model: String,
/// Optional cache directory override for tokenizer files.
/// Defaults to hf-hub's standard cache (`~/.cache/huggingface/`).
/// Can also be set via `KREUZBERG_TOKENIZER_CACHE_DIR` environment variable.
#[serde(default, skip_serializing_if = "Option::is_none")]
cache_dir: Option<std::path::PathBuf>,
},
}
/// Post-processor configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostProcessorConfig {
/// Enable post-processors
#[serde(default = "default_true")]
pub enabled: bool,
/// Whitelist of processor names to run (None = all enabled)
#[serde(default)]
pub enabled_processors: Option<Vec<String>>,
/// Blacklist of processor names to skip (None = none disabled)
#[serde(default)]
pub disabled_processors: Option<Vec<String>>,
/// Pre-computed AHashSet for O(1) enabled processor lookup
#[serde(skip)]
pub enabled_set: Option<AHashSet<String>>,
/// Pre-computed AHashSet for O(1) disabled processor lookup
#[serde(skip)]
pub disabled_set: Option<AHashSet<String>>,
}
impl PostProcessorConfig {
/// Pre-compute HashSets for O(1) processor name lookups.
///
/// This method converts the enabled/disabled processor Vec to HashSet
/// for constant-time lookups in the pipeline.
#[cfg(test)]
pub(crate) fn build_lookup_sets(&mut self) {
if let Some(ref enabled) = self.enabled_processors {
self.enabled_set = Some(enabled.iter().cloned().collect());
}
if let Some(ref disabled) = self.disabled_processors {
self.disabled_set = Some(disabled.iter().cloned().collect());
}
}
}
impl Default for PostProcessorConfig {
fn default() -> Self {
Self {
enabled: true,
enabled_processors: None,
disabled_processors: None,
enabled_set: None,
disabled_set: None,
}
}
}
/// Chunking configuration.
///
/// Configures text chunking for document content, including chunk size,
/// overlap, trimming behavior, and optional embeddings.
///
/// Use `..Default::default()` when constructing to allow for future field additions:
/// ```rust
/// # use kreuzberg::ChunkingConfig;
/// let config = ChunkingConfig {
/// max_characters: 500,
/// ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
/// Maximum size per chunk (in units determined by `sizing`).
///
/// When `sizing` is `Characters` (default), this is the max character count.
/// When using token-based sizing, this is the max token count.
///
/// Default: 1000
#[serde(default = "default_chunk_size", rename = "max_chars", alias = "max_characters")]
pub max_characters: usize,
/// Overlap between chunks (in units determined by `sizing`).
///
/// Default: 200
#[serde(default = "default_chunk_overlap", rename = "max_overlap", alias = "overlap")]
pub overlap: usize,
/// Whether to trim whitespace from chunk boundaries.
///
/// Default: true
#[serde(default = "default_trim")]
pub trim: bool,
/// Type of chunker to use (Text or Markdown).
///
/// Default: Text
#[serde(default = "default_chunker_type")]
pub chunker_type: ChunkerType,
/// Optional embedding configuration for chunk embeddings.
#[serde(skip_serializing_if = "Option::is_none")]
pub embedding: Option<EmbeddingConfig>,
/// Use a preset configuration (overrides individual settings if provided).
#[serde(skip_serializing_if = "Option::is_none")]
pub preset: Option<String>,
/// How to measure chunk size.
///
/// Default: `Characters` (Unicode character count).
/// Enable `chunking-tiktoken` or `chunking-tokenizers` features for token-based sizing.
#[serde(default, deserialize_with = "deserialize_null_default")]
pub sizing: ChunkSizing,
/// When `true` and `chunker_type` is `Markdown`, prepend the heading hierarchy
/// path (e.g. `"# Title > ## Section\n\n"`) to each chunk's content string.
///
/// This is useful for RAG pipelines where each chunk needs self-contained
/// context about its position in the document structure.
///
/// Default: `false`
#[serde(default)]
pub prepend_heading_context: bool,
/// Optional cosine similarity threshold for semantic topic boundary detection.
///
/// Only used when `chunker_type` is `Semantic` and an `EmbeddingConfig` is
/// provided. You almost never need to set this. When omitted, defaults to
/// `0.75` which works well for most documents. Lower values detect more
/// topic boundaries (more, smaller chunks); higher values detect fewer.
/// Range: `0.0..=1.0`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub topic_threshold: Option<f32>,
}
impl ChunkingConfig {
/// Set the cosine similarity threshold for semantic topic boundary detection.
///
/// # Panics
///
/// Panics if `threshold` is outside `[0.0, 1.0]`.
#[cfg(test)]
pub(crate) fn with_topic_threshold(mut self, threshold: f32) -> Self {
assert!(
(0.0..=1.0).contains(&threshold),
"topic_threshold must be in [0.0, 1.0], got {threshold}"
);
self.topic_threshold = Some(threshold);
self
}
/// Resolve a preset name into concrete chunking and embedding configuration.
///
/// When `preset` is set (e.g., `"balanced"`), this overrides `max_characters` and
/// `overlap` from the preset definition, and configures the embedding model if
/// no embedding config was explicitly provided.
///
/// If the preset name is not recognized, a warning is logged and the config
/// is returned unchanged.
///
/// Requires the `embeddings` feature. Without it, this is a no-op that returns
/// the config unchanged.
#[cfg(feature = "embeddings")]
pub(crate) fn resolve_preset(&self) -> Self {
let preset_name = match &self.preset {
Some(name) => name,
None => return self.clone(),
};
let preset = match crate::embeddings::get_preset(preset_name) {
Some(p) => p,
None => {
tracing::warn!(
"Unknown chunking preset '{}', using manual config. Available: {:?}",
preset_name,
crate::embeddings::list_presets()
);
return self.clone();
}
};
// Preserve the caller's embedding choice, including None.
// Presets configure chunking parameters only; users must explicitly
// provide an EmbeddingConfig to opt into embedding generation.
let embedding = self.embedding.clone();
Self {
max_characters: preset.chunk_size,
overlap: preset.overlap,
embedding,
// Preserve caller's other settings
trim: self.trim,
chunker_type: self.chunker_type,
preset: self.preset.clone(),
sizing: self.sizing.clone(),
prepend_heading_context: self.prepend_heading_context,
topic_threshold: self.topic_threshold,
}
}
/// Resolve a preset name (no-op without the `embeddings` feature).
#[cfg(all(feature = "chunking", not(feature = "embeddings")))]
pub(crate) fn resolve_preset(&self) -> Self {
if self.preset.is_some() {
tracing::warn!("Chunking presets require the 'embeddings' feature");
}
self.clone()
}
}
impl Default for ChunkingConfig {
fn default() -> Self {
Self {
max_characters: 1000,
overlap: 200,
trim: true,
chunker_type: ChunkerType::Text,
embedding: None,
preset: None,
sizing: ChunkSizing::default(),
prepend_heading_context: false,
topic_threshold: None,
}
}
}
/// Embedding configuration for text chunks.
///
/// Configures embedding generation using ONNX models via the vendored embedding engine.
/// Requires the `embeddings` feature to be enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
/// The embedding model to use (defaults to "balanced" preset if not specified)
#[serde(default = "default_model", deserialize_with = "deserialize_null_model")]
pub model: EmbeddingModelType,
/// Whether to normalize embedding vectors (recommended for cosine similarity)
#[serde(default = "default_normalize")]
pub normalize: bool,
/// Batch size for embedding generation
#[serde(default = "default_batch_size")]
pub batch_size: usize,
/// Show model download progress
#[serde(default)]
pub show_download_progress: bool,
/// Custom cache directory for model files
///
/// Defaults to `~/.cache/kreuzberg/embeddings/` if not specified.
/// Allows full customization of model download location.
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_dir: Option<PathBuf>,
/// Hardware acceleration for the embedding ONNX model.
///
/// When set, controls which execution provider (CPU, CUDA, CoreML, TensorRT)
/// is used for inference. Defaults to `None` (auto-select per platform).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub acceleration: Option<super::acceleration::AccelerationConfig>,
/// Maximum wall-clock duration (in seconds) for a single `embed()` call when
/// using [`EmbeddingModelType::Plugin`].
///
/// Applies only to the in-process plugin path — protects against hung
/// host-language backends (e.g. a Python callback deadlocked on the GIL,
/// a model stuck on CUDA OOM retries, etc.). On timeout, the dispatcher
/// returns [`crate::KreuzbergError::Plugin`] instead of blocking forever.
///
/// `None` disables the timeout. The default (60 seconds) is conservative
/// for common in-process inference; increase for large batches on slow
/// hardware.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_embed_duration_secs: Option<u64>,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model: EmbeddingModelType::Preset {
name: "balanced".to_string(),
},
normalize: true,
batch_size: 32,
show_download_progress: false,
cache_dir: None,
acceleration: None,
max_embed_duration_secs: Some(60),
}
}
}
/// Embedding model types supported by Kreuzberg.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EmbeddingModelType {
/// Use a preset model configuration (recommended)
Preset { name: String },
/// Use a custom ONNX model from HuggingFace
Custom { model_id: String, dimensions: usize },
/// Provider-hosted embedding model via liter-llm.
///
/// Uses the model specified in the nested `LlmConfig` (e.g.,
/// `"openai/text-embedding-3-small"`).
Llm { llm: super::llm::LlmConfig },
/// In-process embedding backend registered via the plugin system.
///
/// The caller registers an [`EmbeddingBackend`](crate::plugins::EmbeddingBackend) once
/// (e.g. a wrapper around an already-loaded `llama-cpp-python`, `sentence-transformers`,
/// or tuned ONNX model), then references it by name in config. Kreuzberg calls back
/// into the registered backend during chunking and standalone embed requests —
/// no HuggingFace download, no ONNX Runtime requirement, no HTTP sidecar.
///
/// When this variant is selected, only the following [`EmbeddingConfig`] fields
/// apply: `normalize` (post-call L2 normalization) and `max_embed_duration_secs`
/// (dispatcher timeout). Model-loading fields (`batch_size`, `cache_dir`,
/// `show_download_progress`, `acceleration`) are ignored — the host owns the
/// model lifecycle.
///
/// Semantic chunking falls back to [`ChunkingConfig::max_characters`] when this variant
/// is used, since there is no preset to look a chunk-size ceiling up against — size your
/// context window via `max_characters` directly.
///
/// See [`crate::plugins::register_embedding_backend`].
Plugin { name: String },
}
impl Default for EmbeddingModelType {
/// Returns the "balanced" preset as the default model.
///
/// Previously returned `Preset { name: "" }` (empty string) which caused
/// "Unknown embedding preset: " errors in every language binding that calls
/// `EmbeddingModelType::default()` — including generated bindings that
/// use struct-level `#[serde(default)]` instead of `default_model()`.
/// All defaults across the codebase converge on "balanced".
fn default() -> Self {
Self::Preset {
name: "balanced".to_string(),
}
}
}
fn default_true() -> bool {
true
}
fn default_chunk_size() -> usize {
1000
}
/// Deserialize a value that may be explicitly `null` into its `Default` value.
///
/// Internally-tagged serde enums (e.g. `#[serde(tag = "type")]`) reject `null`
/// even when the containing field has `#[serde(default)]`, because that attribute
/// only covers the *missing* case. Polyglot bindings frequently emit explicit
/// `"field": null` from zero-valued mirror structs, so this helper accepts either
/// `null` or a present value and falls back to `T::default()` for null.
fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: serde::Deserializer<'de>,
T: Default + serde::Deserialize<'de>,
{
let opt = Option::<T>::deserialize(deserializer)?;
Ok(opt.unwrap_or_default())
}
fn default_chunk_overlap() -> usize {
200
}
fn default_trim() -> bool {
true
}
fn default_chunker_type() -> ChunkerType {
ChunkerType::Text
}
fn default_normalize() -> bool {
true
}
fn default_batch_size() -> usize {
32
}
fn default_model() -> EmbeddingModelType {
EmbeddingModelType::Preset {
name: "balanced".to_string(),
}
}
/// `deserialize_with` companion for `EmbeddingModelType` fields that may be
/// explicitly `null` in polyglot binding payloads. Treats null as the configured
/// `default_model()` (the "balanced" preset) rather than the trait `Default` impl
/// (which is an empty-name placeholder unsuitable for live use).
fn deserialize_null_model<'de, D>(deserializer: D) -> Result<EmbeddingModelType, D::Error>
where
D: serde::Deserializer<'de>,
{
let opt = Option::<EmbeddingModelType>::deserialize(deserializer)?;
Ok(opt.unwrap_or_else(default_model))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_postprocessor_config_default() {
let config = PostProcessorConfig::default();
assert!(config.enabled);
assert!(config.enabled_processors.is_none());
assert!(config.disabled_processors.is_none());
}
#[test]
fn test_postprocessor_config_build_lookup_sets() {
let mut config = PostProcessorConfig {
enabled: true,
enabled_processors: Some(vec!["a".to_string(), "b".to_string()]),
disabled_processors: Some(vec!["c".to_string()]),
enabled_set: None,
disabled_set: None,
};
config.build_lookup_sets();
assert!(config.enabled_set.is_some());
assert!(config.disabled_set.is_some());
assert!(config.enabled_set.unwrap().contains("a"));
assert!(config.disabled_set.unwrap().contains("c"));
}
#[test]
fn test_chunking_config_defaults() {
let config = ChunkingConfig::default();
assert_eq!(config.max_characters, 1000);
assert_eq!(config.overlap, 200);
assert!(config.trim);
assert_eq!(config.chunker_type, ChunkerType::Text);
assert!(matches!(config.sizing, ChunkSizing::Characters));
}
#[test]
fn test_embedding_config_default() {
let config = EmbeddingConfig::default();
assert!(config.normalize);
assert_eq!(config.batch_size, 32);
assert!(config.cache_dir.is_none());
}
/// Tests that `EmbeddingModelType::default()` returns the "balanced" preset.
///
/// Language bindings that use struct-level `#[serde(default)]` resolve absent
/// `model` fields via this impl. An empty-string name caused "Unknown embedding
/// preset: " panics in `get_preset()`; the default must be a valid preset.
#[test]
fn test_embedding_model_type_default_is_balanced() {
match EmbeddingModelType::default() {
EmbeddingModelType::Preset { name } => {
assert_eq!(name, "balanced", "Default model should be the balanced preset");
}
other => panic!("Expected Preset variant, got {:?}", other),
}
}
/// Tests that EmbeddingModelType::Preset serializes with "type" field (internally-tagged).
/// This validates the API schema matches the documented format:
/// `{"type": "preset", "name": "fast"}` NOT `{"preset": {"name": "fast"}}`
#[test]
fn test_embedding_model_type_preset_serialization() {
let model = EmbeddingModelType::Preset {
name: "fast".to_string(),
};
let json = serde_json::to_string(&model).unwrap();
// Should use internally-tagged format with "type" discriminator
assert!(json.contains(r#""type":"preset""#), "Should contain type:preset field");
assert!(json.contains(r#""name":"fast""#), "Should contain name:fast field");
// Should NOT use adjacently-tagged format
assert!(
!json.contains(r#"{"preset":"#),
"Should NOT use adjacently-tagged format"
);
}
/// Tests that EmbeddingModelType::Preset deserializes from the documented API format.
/// API documentation shows: `{"type": "preset", "name": "fast"}`
#[test]
fn test_embedding_model_type_preset_deserialization() {
// This is the documented API format that users should send
let json = r#"{"type": "preset", "name": "fast"}"#;
let model: EmbeddingModelType = serde_json::from_str(json).unwrap();
match model {
EmbeddingModelType::Preset { name } => {
assert_eq!(name, "fast");
}
_ => panic!("Expected Preset variant"),
}
}
/// Tests that the wrong format (adjacently-tagged) is rejected.
/// This ensures the API doesn't accept the old/wrong documentation format.
#[test]
fn test_embedding_model_type_rejects_wrong_format() {
// This is the WRONG format that was in the old documentation
let wrong_json = r#"{"preset": {"name": "fast"}}"#;
let result: Result<EmbeddingModelType, _> = serde_json::from_str(wrong_json);
// Should fail to parse - the wrong format should be rejected
assert!(result.is_err(), "Should reject adjacently-tagged format");
}
/// Tests round-trip serialization/deserialization of EmbeddingConfig.
#[test]
fn test_embedding_config_roundtrip() {
let config = EmbeddingConfig {
model: EmbeddingModelType::Preset {
name: "balanced".to_string(),
},
normalize: true,
batch_size: 64,
show_download_progress: false,
cache_dir: None,
acceleration: None,
max_embed_duration_secs: Some(60),
};
let json = serde_json::to_string(&config).unwrap();
let deserialized: EmbeddingConfig = serde_json::from_str(&json).unwrap();
match deserialized.model {
EmbeddingModelType::Preset { name } => {
assert_eq!(name, "balanced");
}
_ => panic!("Expected Preset variant"),
}
assert!(deserialized.normalize);
assert_eq!(deserialized.batch_size, 64);
}
/// Tests Custom model type serialization format.
#[test]
fn test_embedding_model_type_custom_serialization() {
let model = EmbeddingModelType::Custom {
model_id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
dimensions: 384,
};
let json = serde_json::to_string(&model).unwrap();
assert!(json.contains(r#""type":"custom""#), "Should contain type:custom field");
assert!(json.contains(r#""model_id":"#), "Should contain model_id field");
assert!(json.contains(r#""dimensions":384"#), "Should contain dimensions field");
}
#[test]
#[cfg(feature = "embeddings")]
fn test_resolve_preset_balanced() {
let config = ChunkingConfig {
preset: Some("balanced".to_string()),
..Default::default()
};
let resolved = config.resolve_preset();
assert_eq!(resolved.max_characters, 1024);
assert_eq!(resolved.overlap, 100);
// Preset configures chunking parameters only; embedding stays None unless
// the caller explicitly provided one (#797).
assert!(resolved.embedding.is_none());
}
#[test]
#[cfg(feature = "embeddings")]
fn test_resolve_preset_preserves_explicit_embedding() {
let explicit_embedding = EmbeddingConfig {
model: EmbeddingModelType::Custom {
model_id: "custom/model".to_string(),
dimensions: 512,
},
batch_size: 64,
..Default::default()
};
let config = ChunkingConfig {
preset: Some("fast".to_string()),
embedding: Some(explicit_embedding),
..Default::default()
};
let resolved = config.resolve_preset();
assert_eq!(resolved.max_characters, 512);
assert_eq!(resolved.overlap, 50);
// Explicit embedding config preserved
match &resolved.embedding.unwrap().model {
EmbeddingModelType::Custom { model_id, .. } => assert_eq!(model_id, "custom/model"),
_ => panic!("Expected Custom model type to be preserved"),
}
}
#[cfg(any(feature = "embeddings", feature = "chunking"))]
#[test]
fn test_resolve_preset_no_preset_returns_unchanged() {
let config = ChunkingConfig {
max_characters: 500,
overlap: 50,
..Default::default()
};
let resolved = config.resolve_preset();
assert_eq!(resolved.max_characters, 500);
assert_eq!(resolved.overlap, 50);
assert!(resolved.embedding.is_none());
}
#[cfg(any(feature = "embeddings", feature = "chunking"))]
#[test]
fn test_resolve_preset_unknown_name_returns_unchanged() {
let config = ChunkingConfig {
max_characters: 500,
preset: Some("nonexistent".to_string()),
..Default::default()
};
let resolved = config.resolve_preset();
assert_eq!(resolved.max_characters, 500);
}
#[test]
fn test_embedding_model_type_llm_roundtrip() {
let model_type = EmbeddingModelType::Llm {
llm: crate::core::config::llm::LlmConfig {
model: "openai/text-embedding-3-small".to_string(),
api_key: None,
base_url: None,
timeout_secs: None,
max_retries: None,
temperature: None,
max_tokens: None,
},
};
let json = serde_json::to_string(&model_type).unwrap();
assert!(json.contains("\"type\":\"llm\""));
assert!(json.contains("openai/text-embedding-3-small"));
let deserialized: EmbeddingModelType = serde_json::from_str(&json).unwrap();
match deserialized {
EmbeddingModelType::Llm { llm } => {
assert_eq!(llm.model, "openai/text-embedding-3-small");
}
_ => panic!("Expected Llm variant"),
}
}
#[test]
#[should_panic(expected = "topic_threshold must be in [0.0, 1.0]")]
fn test_with_topic_threshold_panics_above_one() {
ChunkingConfig::default().with_topic_threshold(1.1);
}
#[test]
#[should_panic(expected = "topic_threshold must be in [0.0, 1.0]")]
fn test_with_topic_threshold_panics_below_zero() {
ChunkingConfig::default().with_topic_threshold(-0.1);
}
#[test]
fn test_with_topic_threshold_accepts_boundary_values() {
let config = ChunkingConfig::default().with_topic_threshold(0.0);
assert_eq!(config.topic_threshold, Some(0.0));
let config = ChunkingConfig::default().with_topic_threshold(1.0);
assert_eq!(config.topic_threshold, Some(1.0));
}
/// Tests Custom model type deserialization.
#[test]
fn test_embedding_model_type_custom_deserialization() {
let json = r#"{"type": "custom", "model_id": "test/model", "dimensions": 512}"#;
let model: EmbeddingModelType = serde_json::from_str(json).unwrap();
match model {
EmbeddingModelType::Custom { model_id, dimensions } => {
assert_eq!(model_id, "test/model");
assert_eq!(dimensions, 512);
}
_ => panic!("Expected Custom variant"),
}
}
#[test]
fn test_embedding_model_type_plugin_roundtrip() {
let model = EmbeddingModelType::Plugin {
name: "lilbee-llamacpp".to_string(),
};
let json = serde_json::to_string(&model).unwrap();
assert!(json.contains("\"type\":\"plugin\""));
assert!(json.contains("lilbee-llamacpp"));
let deserialized: EmbeddingModelType = serde_json::from_str(&json).unwrap();
match deserialized {
EmbeddingModelType::Plugin { name } => assert_eq!(name, "lilbee-llamacpp"),
_ => panic!("Expected Plugin variant"),
}
}
#[test]
fn test_embedding_model_type_plugin_deserialization() {
let json = r#"{"type": "plugin", "name": "my-embedder"}"#;
let model: EmbeddingModelType = serde_json::from_str(json).unwrap();
match model {
EmbeddingModelType::Plugin { name } => assert_eq!(name, "my-embedder"),
_ => panic!("Expected Plugin variant"),
}
}
// --- Issue #797 regression tests ---
/// Preset with no explicit embedding: embedding must remain None.
///
/// Before the fix, `resolve_preset()` would silently inject an
/// `EmbeddingConfig` whenever a preset was configured, causing every
/// chunk to have an unexpected `.embedding` field populated.
#[test]
#[cfg(feature = "embeddings")]
fn test_resolve_preset_does_not_inject_embedding_when_none() {
let config = ChunkingConfig {
preset: Some("multilingual".to_string()),
embedding: None,
..Default::default()
};
let resolved = config.resolve_preset();
assert!(
resolved.embedding.is_none(),
"preset alone must not inject an EmbeddingConfig (#797)"
);
}
/// Preset with an explicit embedding: the embedding must be preserved unchanged.
#[test]
#[cfg(feature = "embeddings")]
fn test_resolve_preset_preserves_explicit_embedding_config() {
let explicit = EmbeddingConfig {
model: EmbeddingModelType::Custom {
model_id: "my-org/model".to_string(),
dimensions: 768,
},
batch_size: 16,
..Default::default()
};
let config = ChunkingConfig {
preset: Some("multilingual".to_string()),
embedding: Some(explicit),
..Default::default()
};
let resolved = config.resolve_preset();
let emb = resolved
.embedding
.expect("explicit embedding must survive resolve_preset");
assert_eq!(emb.batch_size, 16);
match emb.model {
EmbeddingModelType::Custom { model_id, dimensions } => {
assert_eq!(model_id, "my-org/model");
assert_eq!(dimensions, 768);
}
other => panic!("expected Custom model type, got {other:?}"),
}
}
/// No preset, no embedding: embedding must stay None (regression guard).
#[cfg(any(feature = "embeddings", feature = "chunking"))]
#[test]
fn test_resolve_preset_no_preset_no_embedding_stays_none() {
let config = ChunkingConfig {
preset: None,
embedding: None,
..Default::default()
};
let resolved = config.resolve_preset();
assert!(resolved.embedding.is_none(), "no-preset path must not touch embedding");
}
}

View File

@@ -0,0 +1,160 @@
//! Tree-sitter language pack configuration.
//!
//! This module contains configuration types for the tree-sitter integration,
//! including grammar download settings and code analysis processing options.
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Content rendering mode for code extraction.
///
/// Controls how extracted code content is represented in the `content` field
/// of `ExtractionResult`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CodeContentMode {
/// Use TSLP semantic chunks as content (default).
#[default]
Chunks,
/// Use raw source code as content.
Raw,
/// Emit function/class headings + docstrings (no code bodies).
Structure,
}
/// Configuration for tree-sitter language pack integration.
///
/// Controls grammar download behavior and code analysis options.
///
/// # Example (TOML)
///
/// ```toml
/// [tree_sitter]
/// languages = ["python", "rust"]
/// groups = ["web"]
///
/// [tree_sitter.process]
/// structure = true
/// comments = true
/// docstrings = true
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeSitterConfig {
/// Enable code intelligence processing (default: true).
///
/// When `false`, tree-sitter analysis is completely skipped even if
/// the config section is present.
#[serde(default = "default_true")]
pub enabled: bool,
/// Custom cache directory for downloaded grammars.
///
/// When `None`, uses the default: `~/.cache/tree-sitter-language-pack/v{version}/libs/`.
#[serde(default)]
pub cache_dir: Option<PathBuf>,
/// Languages to pre-download on init (e.g., `["python", "rust"]`).
#[serde(default)]
pub languages: Option<Vec<String>>,
/// Language groups to pre-download (e.g., `["web", "systems", "scripting"]`).
#[serde(default)]
pub groups: Option<Vec<String>>,
/// Processing options for code analysis.
#[serde(default)]
pub process: TreeSitterProcessConfig,
}
/// Processing options for tree-sitter code analysis.
///
/// Controls which analysis features are enabled when extracting code files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeSitterProcessConfig {
/// Extract structural items (functions, classes, structs, etc.). Default: true.
#[serde(default = "default_true")]
pub structure: bool,
/// Extract import statements. Default: true.
#[serde(default = "default_true")]
pub imports: bool,
/// Extract export statements. Default: true.
#[serde(default = "default_true")]
pub exports: bool,
/// Extract comments. Default: false.
#[serde(default)]
pub comments: bool,
/// Extract docstrings. Default: false.
#[serde(default)]
pub docstrings: bool,
/// Extract symbol definitions. Default: false.
#[serde(default)]
pub symbols: bool,
/// Include parse diagnostics. Default: false.
#[serde(default)]
pub diagnostics: bool,
/// Maximum chunk size in bytes. `None` disables chunking.
#[serde(default)]
pub chunk_max_size: Option<usize>,
/// Content rendering mode for code extraction.
#[serde(default)]
pub content_mode: CodeContentMode,
}
impl Default for TreeSitterConfig {
fn default() -> Self {
Self {
enabled: true,
cache_dir: None,
languages: None,
groups: None,
process: TreeSitterProcessConfig::default(),
}
}
}
impl Default for TreeSitterProcessConfig {
fn default() -> Self {
Self {
structure: true,
imports: true,
exports: true,
comments: false,
docstrings: false,
symbols: false,
diagnostics: false,
chunk_max_size: None,
content_mode: CodeContentMode::default(),
}
}
}
fn default_true() -> bool {
true
}
/// Convert kreuzberg's process config to TSLP's `ProcessConfig`.
///
/// The language field is left empty — callers must set it before use.
impl From<&TreeSitterProcessConfig> for tree_sitter_language_pack::ProcessConfig {
fn from(p: &TreeSitterProcessConfig) -> Self {
Self {
language: std::borrow::Cow::Borrowed(""),
structure: p.structure,
imports: p.imports,
exports: p.exports,
comments: p.comments,
docstrings: p.docstrings,
symbols: p.symbols,
diagnostics: p.diagnostics,
chunk_max_size: p.chunk_max_size,
}
}
}

View File

@@ -0,0 +1,97 @@
//! Cross-section dependency validation.
//!
//! This module contains validation functions that check dependencies and relationships
//! between different configuration sections. These validators ensure that related
//! configuration values are consistent and compatible with each other.
#[cfg(test)]
use crate::{KreuzbergError, Result};
#[cfg(test)]
pub(crate) fn validate_port(port: u32) -> Result<()> {
if port == 0 || port > 65535 {
Err(KreuzbergError::Validation {
message: format!("Port must be 1-65535, got {}", port),
source: None,
})
} else {
Ok(())
}
}
#[cfg(test)]
pub(crate) fn validate_host(host: &str) -> Result<()> {
let host = host.trim();
if host.is_empty() {
return Err(KreuzbergError::Validation {
message: "Invalid host '': must be a valid IP address or hostname".to_string(),
source: None,
});
}
// Check if it's a valid IPv4 address
if host.parse::<std::net::Ipv4Addr>().is_ok() {
return Ok(());
}
// Check if it's a valid IPv6 address
if host.parse::<std::net::Ipv6Addr>().is_ok() {
return Ok(());
}
// Check if it's a valid hostname (basic validation)
// Hostnames must contain only alphanumeric characters, dots, and hyphens
// Must not look like an invalid IPv4 address (all numeric with dots)
let looks_like_ipv4 = host
.split('.')
.all(|part| !part.is_empty() && part.chars().all(|c| c.is_numeric()));
if !looks_like_ipv4
&& host.chars().all(|c| c.is_alphanumeric() || c == '.' || c == '-')
&& !host.starts_with('-')
&& !host.ends_with('-')
{
return Ok(());
}
Err(KreuzbergError::Validation {
message: format!("Invalid host '{}': must be a valid IP address or hostname", host),
source: None,
})
}
#[cfg(test)]
pub(crate) fn validate_cors_origin(origin: &str) -> Result<()> {
let origin = origin.trim();
if origin == "*" {
return Ok(());
}
if origin.starts_with("http://") || origin.starts_with("https://") {
// Basic validation: ensure there's something after the protocol
if origin.len() > 8 && (origin.starts_with("http://") && origin.len() > 7 || origin.starts_with("https://")) {
return Ok(());
}
}
Err(KreuzbergError::Validation {
message: format!(
"Invalid CORS origin '{}': must be a valid HTTP/HTTPS URL or '*'",
origin
),
source: None,
})
}
#[cfg(test)]
pub(crate) fn validate_upload_size(size: usize) -> Result<()> {
if size > 0 {
Ok(())
} else {
Err(KreuzbergError::Validation {
message: format!("Upload size must be greater than 0, got {}", size),
source: None,
})
}
}

View File

@@ -0,0 +1,399 @@
//! Configuration validation module.
//!
//! Provides centralized validation for configuration values across all bindings.
//! This eliminates duplication of validation logic in Python, TypeScript, Java, Go, and other language bindings.
//!
//! All validation functions return `Result<()>` and produce detailed error messages
//! suitable for user-facing error handling.
//!
//! # Examples
//!
//! ```ignore
//! use kreuzberg::core::config_validation::{
//! validate_binarization_method,
//! validate_token_reduction_level,
//! validate_language_code,
//! };
//!
//! // Valid values
//! assert!(validate_binarization_method("otsu").is_ok());
//! assert!(validate_token_reduction_level("moderate").is_ok());
//! assert!(validate_language_code("en").is_ok());
//!
//! // Invalid values
//! assert!(validate_binarization_method("invalid").is_err());
//! assert!(validate_token_reduction_level("extreme").is_err());
//! ```
mod dependencies;
mod sections;
// Re-export validation functions used in production code
pub(crate) use sections::{
validate_chunking_params, validate_language_code, validate_ocr_backend, validate_token_reduction_level,
};
// Re-export validation functions used only in tests
#[cfg(test)]
pub(crate) use dependencies::{validate_cors_origin, validate_host, validate_port, validate_upload_size};
#[cfg(test)]
pub(crate) use sections::{
validate_binarization_method, validate_confidence, validate_dpi, validate_output_format, validate_tesseract_oem,
validate_tesseract_psm, validate_vlm_backend_config,
};
#[cfg(test)]
mod tests {
use super::*;
// Tests for section validation functions
#[test]
fn test_validate_binarization_method_valid() {
assert!(validate_binarization_method("otsu").is_ok());
assert!(validate_binarization_method("adaptive").is_ok());
assert!(validate_binarization_method("sauvola").is_ok());
}
#[test]
fn test_validate_binarization_method_case_insensitive() {
assert!(validate_binarization_method("OTSU").is_ok());
assert!(validate_binarization_method("Adaptive").is_ok());
assert!(validate_binarization_method("SAUVOLA").is_ok());
}
#[test]
fn test_validate_binarization_method_invalid() {
let result = validate_binarization_method("invalid");
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Invalid binarization method"));
assert!(msg.contains("otsu"));
}
#[test]
fn test_validate_token_reduction_level_valid() {
assert!(validate_token_reduction_level("off").is_ok());
assert!(validate_token_reduction_level("light").is_ok());
assert!(validate_token_reduction_level("moderate").is_ok());
assert!(validate_token_reduction_level("aggressive").is_ok());
assert!(validate_token_reduction_level("maximum").is_ok());
}
#[test]
fn test_validate_token_reduction_level_case_insensitive() {
assert!(validate_token_reduction_level("OFF").is_ok());
assert!(validate_token_reduction_level("Moderate").is_ok());
assert!(validate_token_reduction_level("MAXIMUM").is_ok());
}
#[test]
fn test_validate_token_reduction_level_invalid() {
let result = validate_token_reduction_level("extreme");
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Invalid token reduction level"));
}
#[test]
fn test_validate_ocr_backend_valid() {
assert!(validate_ocr_backend("tesseract").is_ok());
assert!(validate_ocr_backend("easyocr").is_ok());
assert!(validate_ocr_backend("paddleocr").is_ok());
}
#[test]
fn test_validate_ocr_backend_case_insensitive() {
assert!(validate_ocr_backend("TESSERACT").is_ok());
assert!(validate_ocr_backend("EasyOCR").is_ok());
assert!(validate_ocr_backend("PADDLEOCR").is_ok());
}
#[test]
fn test_validate_ocr_backend_invalid() {
let result = validate_ocr_backend("invalid_backend");
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Invalid OCR backend"));
}
#[test]
fn test_validate_language_code_valid_iso639_1() {
assert!(validate_language_code("en").is_ok());
assert!(validate_language_code("de").is_ok());
assert!(validate_language_code("fr").is_ok());
assert!(validate_language_code("es").is_ok());
assert!(validate_language_code("zh").is_ok());
assert!(validate_language_code("ja").is_ok());
assert!(validate_language_code("ko").is_ok());
}
#[test]
fn test_validate_language_code_valid_iso639_3() {
assert!(validate_language_code("eng").is_ok());
assert!(validate_language_code("deu").is_ok());
assert!(validate_language_code("fra").is_ok());
assert!(validate_language_code("spa").is_ok());
assert!(validate_language_code("zho").is_ok());
assert!(validate_language_code("jpn").is_ok());
assert!(validate_language_code("kor").is_ok());
}
#[test]
fn test_validate_language_code_case_insensitive() {
assert!(validate_language_code("EN").is_ok());
assert!(validate_language_code("ENG").is_ok());
assert!(validate_language_code("De").is_ok());
assert!(validate_language_code("DEU").is_ok());
}
#[test]
fn test_validate_language_code_all_keyword() {
assert!(validate_language_code("all").is_ok());
assert!(validate_language_code("ALL").is_ok());
assert!(validate_language_code("All").is_ok());
assert!(validate_language_code("*").is_ok());
}
#[test]
fn test_validate_language_code_invalid() {
let result = validate_language_code("invalid");
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Invalid language code"));
assert!(msg.contains("ISO 639"));
}
#[test]
fn test_validate_tesseract_psm_valid() {
for psm in 0..=13 {
assert!(validate_tesseract_psm(psm).is_ok(), "PSM {} should be valid", psm);
}
}
#[test]
fn test_validate_tesseract_psm_invalid() {
assert!(validate_tesseract_psm(-1).is_err());
assert!(validate_tesseract_psm(14).is_err());
assert!(validate_tesseract_psm(100).is_err());
}
#[test]
fn test_validate_tesseract_oem_valid() {
for oem in 0..=3 {
assert!(validate_tesseract_oem(oem).is_ok(), "OEM {} should be valid", oem);
}
}
#[test]
fn test_validate_tesseract_oem_invalid() {
assert!(validate_tesseract_oem(-1).is_err());
assert!(validate_tesseract_oem(4).is_err());
assert!(validate_tesseract_oem(10).is_err());
}
#[test]
fn test_validate_output_format_valid() {
assert!(validate_output_format("text").is_ok());
assert!(validate_output_format("markdown").is_ok());
}
#[test]
fn test_validate_output_format_case_insensitive() {
assert!(validate_output_format("TEXT").is_ok());
assert!(validate_output_format("Markdown").is_ok());
}
#[test]
fn test_validate_output_format_invalid() {
let result = validate_output_format("xml");
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Invalid output format"));
}
#[test]
fn test_validate_confidence_valid() {
assert!(validate_confidence(0.0).is_ok());
assert!(validate_confidence(0.5).is_ok());
assert!(validate_confidence(1.0).is_ok());
assert!(validate_confidence(0.75).is_ok());
}
#[test]
fn test_validate_confidence_invalid() {
assert!(validate_confidence(-0.1).is_err());
assert!(validate_confidence(1.1).is_err());
assert!(validate_confidence(2.0).is_err());
}
#[test]
fn test_validate_dpi_valid() {
assert!(validate_dpi(72).is_ok());
assert!(validate_dpi(96).is_ok());
assert!(validate_dpi(300).is_ok());
assert!(validate_dpi(600).is_ok());
assert!(validate_dpi(1).is_ok());
}
#[test]
fn test_validate_dpi_invalid() {
assert!(validate_dpi(0).is_err());
assert!(validate_dpi(-1).is_err());
assert!(validate_dpi(2401).is_err());
}
#[test]
fn test_validate_chunking_params_valid() {
assert!(validate_chunking_params(1000, 200).is_ok());
assert!(validate_chunking_params(500, 50).is_ok());
assert!(validate_chunking_params(1, 0).is_ok());
}
#[test]
fn test_validate_chunking_params_zero_chars() {
let result = validate_chunking_params(0, 100);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("max_chars"));
}
#[test]
fn test_validate_chunking_params_overlap_too_large() {
let result = validate_chunking_params(100, 100);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("overlap"));
let result = validate_chunking_params(100, 150);
assert!(result.is_err());
}
#[test]
fn test_error_messages_are_helpful() {
let err = validate_binarization_method("bad").unwrap_err().to_string();
assert!(err.contains("otsu"));
assert!(err.contains("adaptive"));
assert!(err.contains("sauvola"));
let err = validate_token_reduction_level("bad").unwrap_err().to_string();
assert!(err.contains("off"));
assert!(err.contains("moderate"));
let err = validate_language_code("bad").unwrap_err().to_string();
assert!(err.contains("ISO 639"));
assert!(err.contains("en"));
}
// Tests for dependency validation functions
#[test]
fn test_validate_port_valid() {
assert!(validate_port(1).is_ok());
assert!(validate_port(80).is_ok());
assert!(validate_port(443).is_ok());
assert!(validate_port(8000).is_ok());
assert!(validate_port(65535).is_ok());
}
#[test]
fn test_validate_port_invalid() {
let result = validate_port(0);
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Port must be 1-65535"));
assert!(msg.contains("0"));
}
#[test]
fn test_validate_host_ipv4() {
assert!(validate_host("127.0.0.1").is_ok());
assert!(validate_host("0.0.0.0").is_ok());
assert!(validate_host("192.168.1.1").is_ok());
assert!(validate_host("10.0.0.1").is_ok());
assert!(validate_host("255.255.255.255").is_ok());
}
#[test]
fn test_validate_host_ipv6() {
assert!(validate_host("::1").is_ok());
assert!(validate_host("::").is_ok());
assert!(validate_host("2001:db8::1").is_ok());
assert!(validate_host("fe80::1").is_ok());
}
#[test]
fn test_validate_host_hostname() {
assert!(validate_host("localhost").is_ok());
assert!(validate_host("example.com").is_ok());
assert!(validate_host("sub.example.com").is_ok());
assert!(validate_host("api-server").is_ok());
assert!(validate_host("app123").is_ok());
}
#[test]
fn test_validate_host_invalid() {
let result = validate_host("");
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Invalid host"));
let result = validate_host("not a valid host");
assert!(result.is_err());
let result = validate_host("256.256.256.256");
assert!(result.is_err());
}
#[test]
fn test_validate_cors_origin_https() {
assert!(validate_cors_origin("https://example.com").is_ok());
assert!(validate_cors_origin("https://localhost:3000").is_ok());
assert!(validate_cors_origin("https://sub.example.com").is_ok());
assert!(validate_cors_origin("https://192.168.1.1").is_ok());
assert!(validate_cors_origin("https://example.com/path").is_ok());
}
#[test]
fn test_validate_cors_origin_http() {
assert!(validate_cors_origin("http://example.com").is_ok());
assert!(validate_cors_origin("http://localhost:3000").is_ok());
assert!(validate_cors_origin("http://127.0.0.1:8000").is_ok());
}
#[test]
fn test_validate_cors_origin_wildcard() {
assert!(validate_cors_origin("*").is_ok());
}
#[test]
fn test_validate_cors_origin_invalid() {
let result = validate_cors_origin("not-a-url");
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Invalid CORS origin"));
let result = validate_cors_origin("ftp://example.com");
assert!(result.is_err());
let result = validate_cors_origin("example.com");
assert!(result.is_err());
let result = validate_cors_origin("http://");
assert!(result.is_err());
}
#[test]
fn test_validate_upload_size_valid() {
assert!(validate_upload_size(1).is_ok());
assert!(validate_upload_size(1024).is_ok());
assert!(validate_upload_size(1_000_000).is_ok());
assert!(validate_upload_size(1_000_000_000).is_ok());
assert!(validate_upload_size(usize::MAX).is_ok());
}
#[test]
fn test_validate_upload_size_invalid() {
let result = validate_upload_size(0);
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Upload size must be greater than 0"));
assert!(msg.contains("0"));
}
}

View File

@@ -0,0 +1,574 @@
//! Per-section validation functions.
//!
//! This module contains validation functions for individual configuration sections
//! and their specific parameters. Each function validates a specific aspect of
//! the configuration and returns detailed error messages when validation fails.
use crate::{KreuzbergError, Result};
/// Valid binarization methods for image preprocessing.
#[cfg(test)]
const VALID_BINARIZATION_METHODS: &[&str] = &["otsu", "adaptive", "sauvola"];
/// Valid token reduction levels.
const VALID_TOKEN_REDUCTION_LEVELS: &[&str] = &["off", "light", "moderate", "aggressive", "maximum"];
/// Valid OCR backends.
const VALID_OCR_BACKENDS: &[&str] = &["tesseract", "easyocr", "paddleocr", "paddle-ocr", "vlm"];
/// Common ISO 639-1 language codes (extended list).
/// Covers most major languages and variants used in document processing.
const VALID_LANGUAGE_CODES: &[&str] = &[
"en",
"de",
"fr",
"es",
"it",
"pt",
"nl",
"pl",
"ru",
"zh",
"ja",
"ko",
"bg",
"cs",
"da",
"el",
"et",
"fi",
"hu",
"lt",
"lv",
"ro",
"sk",
"sl",
"sv",
"uk",
"ar",
"hi",
"th",
"tr",
"vi",
"eng",
"deu",
"fra",
"spa",
"ita",
"por",
"nld",
"pol",
"rus",
"zho",
"jpn",
"kor",
"ces",
"dan",
"ell",
"est",
"fin",
"hun",
"lit",
"lav",
"ron",
"slk",
"slv",
"swe",
"tur",
// PaddleOCR-specific language codes (non-ISO but widely used)
"ch",
"chinese_cht",
"latin",
"cyrillic",
"devanagari",
"arabic",
];
/// Valid tesseract PSM (Page Segmentation Mode) values.
#[cfg(test)]
const VALID_TESSERACT_PSM: &[i32] = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13];
/// Valid tesseract OEM (OCR Engine Mode) values.
#[cfg(test)]
const VALID_TESSERACT_OEM: &[i32] = &[0, 1, 2, 3];
/// Valid output formats for document extraction.
/// Supports plain text, markdown, djot, HTML, and structured (JSON) output formats.
/// Also accepts aliases: "text" for "plain", "md" for "markdown", "json" for "structured".
#[cfg(test)]
const VALID_OUTPUT_FORMATS: &[&str] = &["plain", "text", "markdown", "md", "djot", "html", "structured", "json"];
/// Validate a binarization method string.
///
/// # Arguments
///
/// * `method` - The binarization method to validate (e.g., "otsu", "adaptive", "sauvola")
///
/// # Returns
///
/// `Ok(())` if the method is valid, or a `ValidationError` with details about valid options.
///
/// # Examples
///
/// ```rust
/// use kreuzberg::core::config_validation::validate_binarization_method;
///
/// assert!(validate_binarization_method("otsu").is_ok());
/// assert!(validate_binarization_method("adaptive").is_ok());
/// assert!(validate_binarization_method("invalid").is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_binarization_method(method: &str) -> Result<()> {
let method = method.to_lowercase();
if VALID_BINARIZATION_METHODS.contains(&method.as_str()) {
Ok(())
} else {
Err(KreuzbergError::Validation {
message: format!(
"Invalid binarization method '{}'. Valid options are: {}",
method,
VALID_BINARIZATION_METHODS.join(", ")
),
source: None,
})
}
}
/// Validate a token reduction level string.
///
/// # Arguments
///
/// * `level` - The token reduction level to validate (e.g., "off", "light", "moderate")
///
/// # Returns
///
/// `Ok(())` if the level is valid, or a `ValidationError` with details about valid options.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_token_reduction_level;
///
/// assert!(validate_token_reduction_level("off").is_ok());
/// assert!(validate_token_reduction_level("moderate").is_ok());
/// assert!(validate_token_reduction_level("extreme").is_err());
/// ```
pub(crate) fn validate_token_reduction_level(level: &str) -> Result<()> {
let level = level.to_lowercase();
if VALID_TOKEN_REDUCTION_LEVELS.contains(&level.as_str()) {
Ok(())
} else {
Err(KreuzbergError::Validation {
message: format!(
"Invalid token reduction level '{}'. Valid options are: {}",
level,
VALID_TOKEN_REDUCTION_LEVELS.join(", ")
),
source: None,
})
}
}
/// Validate an OCR backend string.
///
/// # Arguments
///
/// * `backend` - The OCR backend to validate (e.g., "tesseract", "easyocr", "paddleocr")
///
/// # Returns
///
/// `Ok(())` if the backend is valid, or a `ValidationError` with details about valid options.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_ocr_backend;
///
/// assert!(validate_ocr_backend("tesseract").is_ok());
/// assert!(validate_ocr_backend("easyocr").is_ok());
/// assert!(validate_ocr_backend("invalid").is_err());
/// ```
pub(crate) fn validate_ocr_backend(backend: &str) -> Result<()> {
let backend = backend.to_lowercase();
if VALID_OCR_BACKENDS.contains(&backend.as_str()) {
Ok(())
} else {
Err(KreuzbergError::Validation {
message: format!(
"Invalid OCR backend '{}'. Valid options are: {}",
backend,
VALID_OCR_BACKENDS.join(", ")
),
source: None,
})
}
}
/// Validate a language code (ISO 639-1 or 639-3 format).
///
/// Accepts both 2-letter ISO 639-1 codes (e.g., "en", "de") and
/// 3-letter ISO 639-3 codes (e.g., "eng", "deu") for broader compatibility.
///
/// # Arguments
///
/// * `code` - The language code to validate
///
/// # Returns
///
/// `Ok(())` if the code is valid, or a `ValidationError` indicating an invalid language code.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_language_code;
///
/// assert!(validate_language_code("en").is_ok());
/// assert!(validate_language_code("eng").is_ok());
/// assert!(validate_language_code("de").is_ok());
/// assert!(validate_language_code("deu").is_ok());
/// assert!(validate_language_code("invalid").is_err());
/// ```
#[cfg_attr(alef, alef(skip))]
pub(crate) fn validate_language_code(code: &str) -> Result<()> {
let code_lower = code.to_lowercase();
// Accept "all" and "*" as special values to auto-detect installed languages
if code_lower == "all" || code_lower == "*" {
return Ok(());
}
if VALID_LANGUAGE_CODES.contains(&code_lower.as_str()) {
return Ok(());
}
Err(KreuzbergError::Validation {
message: format!(
"Invalid language code '{}'. Use ISO 639-1 (2-letter, e.g., 'en', 'de') \
or ISO 639-3 (3-letter, e.g., 'eng', 'deu') codes. \
Common codes: en, de, fr, es, it, pt, nl, pl, ru, zh, ja, ko, ar, hi, th.",
code
),
source: None,
})
}
/// Validate a tesseract Page Segmentation Mode (PSM).
///
/// # Arguments
///
/// * `psm` - The PSM value to validate (0-13)
///
/// # Returns
///
/// `Ok(())` if the PSM is valid, or a `ValidationError` with details about valid ranges.
///
/// # Examples
///
/// ```rust
/// use kreuzberg::core::config_validation::validate_tesseract_psm;
///
/// assert!(validate_tesseract_psm(3).is_ok()); // Fully automatic
/// assert!(validate_tesseract_psm(6).is_ok()); // Single block of text
/// assert!(validate_tesseract_psm(14).is_err()); // Out of range
/// ```
#[cfg(test)]
pub(crate) fn validate_tesseract_psm(psm: i32) -> Result<()> {
if VALID_TESSERACT_PSM.contains(&psm) {
Ok(())
} else {
Err(KreuzbergError::Validation {
message: format!(
"Invalid tesseract PSM value '{}'. Valid range is 0-13. \
Common values: 3 (auto), 6 (single block), 11 (sparse text).",
psm
),
source: None,
})
}
}
/// Validate a tesseract OCR Engine Mode (OEM).
///
/// # Arguments
///
/// * `oem` - The OEM value to validate (0-3)
///
/// # Returns
///
/// `Ok(())` if the OEM is valid, or a `ValidationError` with details about valid options.
///
/// # Examples
///
/// ```rust
/// use kreuzberg::core::config_validation::validate_tesseract_oem;
///
/// assert!(validate_tesseract_oem(1).is_ok()); // Neural nets (LSTM)
/// assert!(validate_tesseract_oem(2).is_ok()); // Legacy + LSTM
/// assert!(validate_tesseract_oem(4).is_err()); // Out of range
/// ```
#[cfg(test)]
pub(crate) fn validate_tesseract_oem(oem: i32) -> Result<()> {
if VALID_TESSERACT_OEM.contains(&oem) {
Ok(())
} else {
Err(KreuzbergError::Validation {
message: format!(
"Invalid tesseract OEM value '{}'. Valid range is 0-3. \
0=Legacy, 1=LSTM, 2=Legacy+LSTM, 3=Default",
oem
),
source: None,
})
}
}
/// Validate a document extraction output format.
///
/// Accepts the following formats and aliases:
/// - "plain" or "text" for plain text output
/// - "markdown" or "md" for Markdown output
/// - "djot" for Djot markup format
/// - "html" for HTML output
///
/// # Arguments
///
/// * `format` - The output format to validate
///
/// # Returns
///
/// `Ok(())` if the format is valid, or a `ValidationError` with details about valid options.
///
/// # Examples
///
/// ```rust
/// use kreuzberg::core::config_validation::validate_output_format;
///
/// assert!(validate_output_format("text").is_ok());
/// assert!(validate_output_format("plain").is_ok());
/// assert!(validate_output_format("markdown").is_ok());
/// assert!(validate_output_format("md").is_ok());
/// assert!(validate_output_format("djot").is_ok());
/// assert!(validate_output_format("html").is_ok());
/// assert!(validate_output_format("json").is_ok());
/// ```
#[cfg(test)]
pub(crate) fn validate_output_format(format: &str) -> Result<()> {
let format = format.to_lowercase();
if VALID_OUTPUT_FORMATS.contains(&format.as_str()) {
Ok(())
} else {
Err(KreuzbergError::Validation {
message: format!(
"Invalid output format '{}'. Valid options are: {}",
format,
VALID_OUTPUT_FORMATS.join(", ")
),
source: None,
})
}
}
/// Validate a confidence threshold value.
///
/// Confidence thresholds should be between 0.0 and 1.0 inclusive.
///
/// # Arguments
///
/// * `confidence` - The confidence threshold to validate
///
/// # Returns
///
/// `Ok(())` if the confidence is valid, or a `ValidationError` with details about valid ranges.
///
/// # Examples
///
/// ```rust
/// use kreuzberg::core::config_validation::validate_confidence;
///
/// assert!(validate_confidence(0.5).is_ok());
/// assert!(validate_confidence(0.0).is_ok());
/// assert!(validate_confidence(1.0).is_ok());
/// assert!(validate_confidence(1.5).is_err());
/// assert!(validate_confidence(-0.1).is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_confidence(confidence: f64) -> Result<()> {
if (0.0..=1.0).contains(&confidence) {
Ok(())
} else {
Err(KreuzbergError::Validation {
message: format!(
"Invalid confidence threshold '{}'. Must be between 0.0 and 1.0.",
confidence
),
source: None,
})
}
}
/// Validate a DPI (dots per inch) value.
///
/// DPI should be a positive integer, typically 72-600.
///
/// # Arguments
///
/// * `dpi` - The DPI value to validate
///
/// # Returns
///
/// `Ok(())` if the DPI is valid, or a `ValidationError` with details about valid ranges.
///
/// # Examples
///
/// ```rust
/// use kreuzberg::core::config_validation::validate_dpi;
///
/// assert!(validate_dpi(96).is_ok());
/// assert!(validate_dpi(300).is_ok());
/// assert!(validate_dpi(0).is_err());
/// assert!(validate_dpi(-1).is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_dpi(dpi: i32) -> Result<()> {
if dpi > 0 && dpi <= 2400 {
Ok(())
} else {
Err(KreuzbergError::Validation {
message: format!(
"Invalid DPI value '{}'. Must be a positive integer, typically 72-600.",
dpi
),
source: None,
})
}
}
/// Validate chunk size parameters.
///
/// Checks that max_chars > 0 and max_overlap < max_chars.
///
/// # Arguments
///
/// * `max_chars` - The maximum characters per chunk
/// * `max_overlap` - The maximum overlap between chunks
///
/// # Returns
///
/// `Ok(())` if the parameters are valid, or a `ValidationError` with details about constraints.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_chunking_params;
///
/// assert!(validate_chunking_params(1000, 200).is_ok());
/// assert!(validate_chunking_params(500, 50).is_ok());
/// assert!(validate_chunking_params(0, 100).is_err()); // max_chars must be > 0
/// assert!(validate_chunking_params(100, 150).is_err()); // overlap >= max_chars
/// ```
pub(crate) fn validate_chunking_params(max_chars: usize, max_overlap: usize) -> Result<()> {
if max_chars == 0 {
return Err(KreuzbergError::Validation {
message: "max_chars must be greater than 0".to_string(),
source: None,
});
}
if max_overlap >= max_chars {
return Err(KreuzbergError::Validation {
message: format!(
"max_overlap ({}) must be less than max_chars ({})",
max_overlap, max_chars
),
source: None,
});
}
Ok(())
}
/// Validate that an [`LlmConfig`](crate::core::config::LlmConfig) has a non-empty model string.
///
/// # Arguments
///
/// * `model` - The model string to validate
///
/// # Returns
///
/// `Ok(())` if the model is non-empty, or a `ValidationError` otherwise.
///
/// # Examples
///
/// ```rust
/// use kreuzberg::core::config_validation::validate_llm_config_model;
///
/// assert!(validate_llm_config_model("openai/gpt-4o").is_ok());
/// assert!(validate_llm_config_model("").is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_llm_config_model(model: &str) -> Result<()> {
if model.trim().is_empty() {
return Err(KreuzbergError::Validation {
message: "LLM config 'model' must not be empty. Provide a model identifier (e.g., 'openai/gpt-4o')."
.to_string(),
source: None,
});
}
Ok(())
}
/// Validate that a VLM OCR backend has the required `vlm_config`.
///
/// When the OCR backend is set to `"vlm"`, the `vlm_config` field must be present
/// to provide the model endpoint configuration, and the model string must be non-empty.
///
/// # Arguments
///
/// * `backend` - The OCR backend name
/// * `vlm_config` - The optional VLM config to validate
///
/// # Returns
///
/// `Ok(())` if the backend is not `"vlm"` or `vlm_config` is present with a valid model,
/// or a `ValidationError` if `"vlm"` backend is used without `vlm_config` or with an empty model.
///
/// # Examples
///
/// ```rust
/// use kreuzberg::core::config_validation::validate_vlm_backend_config;
/// use kreuzberg::core::config::LlmConfig;
///
/// assert!(validate_vlm_backend_config("tesseract", None).is_ok());
/// let config = LlmConfig {
/// model: "openai/gpt-4o".to_string(),
/// api_key: None,
/// base_url: None,
/// timeout_secs: None,
/// max_retries: None,
/// temperature: None,
/// max_tokens: None,
/// };
/// assert!(validate_vlm_backend_config("vlm", Some(&config)).is_ok());
/// assert!(validate_vlm_backend_config("vlm", None).is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_vlm_backend_config(
backend: &str,
vlm_config: Option<&crate::core::config::LlmConfig>,
) -> Result<()> {
if backend.to_lowercase() == "vlm" {
match vlm_config {
None => {
return Err(KreuzbergError::Validation {
message: "OCR backend 'vlm' requires 'vlm_config' to be set with model endpoint configuration."
.to_string(),
source: None,
});
}
Some(config) => {
validate_llm_config_model(&config.model)?;
}
}
}
Ok(())
}

View File

@@ -0,0 +1,331 @@
//! Batch extraction operations for concurrent processing.
//!
//! This module provides parallel extraction capabilities for processing
//! multiple files or byte arrays concurrently with automatic resource management.
#[cfg(feature = "tokio-runtime")]
use crate::core::config::BatchBytesItem;
#[cfg(feature = "tokio-runtime")]
use crate::core::config::BatchFileItem;
use crate::core::config::ExtractionConfig;
use crate::core::config::extraction::FileExtractionConfig;
use crate::types::ExtractionResult;
use crate::{KreuzbergError, Result};
use std::future::Future;
use std::sync::Arc;
use std::time::Instant;
use super::bytes::extract_bytes;
use super::file::extract_file;
use super::helpers::error_extraction_result;
/// Shared batch result collection: spawns tasks via callback, collects ordered results.
#[cfg(feature = "tokio-runtime")]
async fn collect_batch<F, Fut>(count: usize, config: &ExtractionConfig, spawn_task: F) -> Result<Vec<ExtractionResult>>
where
F: Fn(usize, Arc<tokio::sync::Semaphore>) -> Fut,
Fut: Future<Output = (usize, Result<ExtractionResult>, u64)> + Send + 'static,
{
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
if count == 0 {
return Ok(vec![]);
}
let max_concurrent = config
.max_concurrent_extractions
.or_else(|| config.concurrency.as_ref().and_then(|c| c.max_threads))
.unwrap_or_else(|| crate::core::config::concurrency::resolve_thread_budget(config.concurrency.as_ref()));
let semaphore = Arc::new(Semaphore::new(max_concurrent));
let mut tasks = JoinSet::new();
for index in 0..count {
let sem = Arc::clone(&semaphore);
tasks.spawn(spawn_task(index, sem));
}
let mut results: Vec<Option<ExtractionResult>> = vec![None; count];
while let Some(task_result) = tasks.join_next().await {
match task_result {
Ok((index, Ok(result), _elapsed_ms)) => {
results[index] = Some(result);
}
Ok((index, Err(e), elapsed_ms)) => {
results[index] = Some(error_extraction_result(&e, Some(elapsed_ms)));
}
Err(join_err) => {
return Err(KreuzbergError::Other(format!("Task panicked: {}", join_err)));
}
}
}
#[allow(clippy::unwrap_used)]
Ok(results.into_iter().map(|r| r.unwrap()).collect())
}
/// Run a single extraction task with semaphore gating, timing, optional timeout, and batch mode.
///
/// When `cancel_token` is provided and the timeout fires, the token is signalled so that
/// any blocking PDF operations in progress can observe the cancellation at the next
/// inter-page checkpoint and stop early.
#[cfg(feature = "tokio-runtime")]
async fn run_timed_extraction<F, Fut>(
index: usize,
semaphore: Arc<tokio::sync::Semaphore>,
timeout_secs: Option<u64>,
cancel_token: Option<crate::cancellation::CancellationToken>,
extract_fn: F,
) -> (usize, Result<ExtractionResult>, u64)
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<ExtractionResult>>,
{
let _permit = semaphore.acquire().await.unwrap();
let start = Instant::now();
let extraction_future = crate::core::batch_mode::with_batch_mode(extract_fn());
let mut result = match timeout_secs {
Some(secs) => match tokio::time::timeout(std::time::Duration::from_secs(secs), extraction_future).await {
Ok(inner) => inner,
Err(_elapsed) => {
// Signal the cancellation token so that any blocking PDF thread can
// detect it at the next inter-page checkpoint and stop processing.
if let Some(ref token) = cancel_token {
token.cancel();
}
let elapsed_ms = start.elapsed().as_millis() as u64;
Err(KreuzbergError::Timeout {
elapsed_ms,
limit_ms: secs * 1000,
})
}
},
None => extraction_future.await,
};
let elapsed_ms = start.elapsed().as_millis() as u64;
if let Ok(ref mut r) = result {
r.metadata.extraction_duration_ms = Some(elapsed_ms);
}
(index, result, elapsed_ms)
}
/// Resolve a per-file config against a base config. Returns owned config.
fn resolve_config(base: &ExtractionConfig, file_config: &Option<FileExtractionConfig>) -> ExtractionConfig {
match file_config {
Some(fc) => base.with_file_overrides(fc),
None => base.clone(),
}
}
/// Extract content from multiple files concurrently.
///
/// This function processes multiple files in parallel, automatically managing
/// concurrency to prevent resource exhaustion. The concurrency limit can be
/// configured via `ExtractionConfig::max_concurrent_extractions` or defaults
/// to `(num_cpus * 1.5).ceil()`.
///
/// Each file can optionally specify a [`FileExtractionConfig`] that overrides specific
/// fields from the batch-level `config`. Pass `None` for a file to use the batch defaults.
/// Batch-level settings like `max_concurrent_extractions` and `use_cache` are always
/// taken from the batch-level `config`.
///
/// # Arguments
///
/// * `items` - Vector of `BatchFileItem` structs, each containing a path and optional
/// per-file configuration overrides.
/// * `config` - Batch-level extraction configuration (provides defaults and batch settings)
///
/// # Returns
///
/// A vector of `ExtractionResult` in the same order as the input items.
///
/// # Errors
///
/// Individual file errors are captured in the result metadata. System errors
/// (IO, RuntimeError equivalents) will bubble up and fail the entire batch.
///
/// # Examples
///
/// Simple usage with no per-file overrides:
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_files;
/// use kreuzberg::core::config::{ExtractionConfig, BatchFileItem};
/// use std::path::PathBuf;
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let items = vec![
/// BatchFileItem { path: "doc1.pdf".into(), config: None },
/// BatchFileItem { path: "doc2.pdf".into(), config: None },
/// ];
/// let results = batch_extract_files(items, &config).await?;
/// println!("Processed {} files", results.len());
/// # Ok(())
/// # }
/// ```
///
/// Per-file configuration overrides:
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_files;
/// use kreuzberg::core::config::{ExtractionConfig, BatchFileItem, FileExtractionConfig};
/// use std::path::PathBuf;
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let items = vec![
/// BatchFileItem {
/// path: "scan.pdf".into(),
/// config: Some(FileExtractionConfig { force_ocr: Some(true), ..Default::default() }),
/// },
/// BatchFileItem { path: "notes.txt".into(), config: None },
/// ];
/// let results = batch_extract_files(items, &config).await?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(feature = "otel", tracing::instrument(
skip(config, items),
fields(
extraction.batch_size = items.len(),
)
))]
pub async fn batch_extract_files(
items: Vec<BatchFileItem>,
config: &ExtractionConfig,
) -> Result<Vec<ExtractionResult>> {
let config_arc = Arc::new(config.clone());
// Use Arc<Vec> for file items — paths are small, so keeping them all alive is fine.
let items_arc = Arc::new(items);
let count = items_arc.len();
collect_batch(count, config, |index, sem| {
let cfg = Arc::clone(&config_arc);
let items = Arc::clone(&items_arc);
async move {
let item = &items[index];
let resolved = resolve_config(&cfg, &item.config);
let timeout = resolved.extraction_timeout_secs;
let cancel_token = resolved.cancel_token.clone();
run_timed_extraction(index, sem, timeout, cancel_token, || {
let path = item.path.clone();
async move { extract_file(&path, None, &resolved).await }
})
.await
}
})
.await
}
/// Extract content from multiple byte arrays concurrently.
///
/// This function processes multiple byte arrays in parallel, automatically managing
/// concurrency to prevent resource exhaustion. The concurrency limit can be
/// configured via `ExtractionConfig::max_concurrent_extractions` or defaults
/// to `(num_cpus * 1.5).ceil()`.
///
/// Each item can optionally specify a [`FileExtractionConfig`] that overrides specific
/// fields from the batch-level `config`. Pass `None` as the config to use
/// the batch-level defaults for that item.
///
/// # Arguments
///
/// * `items` - Vector of `BatchBytesItem` structs, each containing content bytes,
/// MIME type, and optional per-item configuration overrides.
/// * `config` - Batch-level extraction configuration
///
/// # Returns
///
/// A vector of `ExtractionResult` in the same order as the input items.
///
/// # Examples
///
/// Simple usage with no per-item overrides:
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_bytes;
/// use kreuzberg::core::config::{ExtractionConfig, BatchBytesItem};
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let items = vec![
/// BatchBytesItem { content: b"content 1".to_vec(), mime_type: "text/plain".to_string(), config: None },
/// BatchBytesItem { content: b"content 2".to_vec(), mime_type: "text/plain".to_string(), config: None },
/// ];
/// let results = batch_extract_bytes(items, &config).await?;
/// println!("Processed {} items", results.len());
/// # Ok(())
/// # }
/// ```
///
/// Per-item configuration overrides:
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_bytes;
/// use kreuzberg::core::config::{ExtractionConfig, BatchBytesItem, FileExtractionConfig};
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let items = vec![
/// BatchBytesItem { content: b"content".to_vec(), mime_type: "text/plain".to_string(), config: None },
/// BatchBytesItem {
/// content: b"<html>test</html>".to_vec(),
/// mime_type: "text/html".to_string(),
/// config: Some(FileExtractionConfig { force_ocr: Some(true), ..Default::default() }),
/// },
/// ];
/// let results = batch_extract_bytes(items, &config).await?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(feature = "otel", tracing::instrument(
skip(config, items),
fields(
extraction.batch_size = items.len(),
)
))]
pub async fn batch_extract_bytes(
items: Vec<BatchBytesItem>,
config: &ExtractionConfig,
) -> Result<Vec<ExtractionResult>> {
let config_arc = Arc::new(config.clone());
let count = items.len();
// Move items into individually-indexed slots so each task can take ownership
// of its bytes without cloning. This avoids the memory regression of
// Arc<Vec<BatchBytesItem>> which would keep all byte arrays alive for the
// entire batch duration.
type BytesSlot = parking_lot::Mutex<Option<BatchBytesItem>>;
let slots: Arc<Vec<BytesSlot>> = Arc::new(
items
.into_iter()
.map(|item| parking_lot::Mutex::new(Some(item)))
.collect(),
);
collect_batch(count, config, |index, sem| {
let cfg = Arc::clone(&config_arc);
let slots = Arc::clone(&slots);
async move {
let item = slots[index].lock().take().expect("batch item already consumed");
let resolved = resolve_config(&cfg, &item.config);
let timeout = resolved.extraction_timeout_secs;
let cancel_token = resolved.cancel_token.clone();
run_timed_extraction(index, sem, timeout, cancel_token, || async move {
extract_bytes(&item.content, &item.mime_type, &resolved).await
})
.await
}
})
.await
}

View File

@@ -0,0 +1,170 @@
//! Byte array extraction operations.
//!
//! This module handles extraction from in-memory byte arrays, including:
//! - MIME type validation
//! - Legacy format conversion (DOC, PPT)
//! - Extraction pipeline orchestration
#[cfg(not(feature = "office"))]
use crate::KreuzbergError;
use crate::Result;
use crate::core::config::ExtractionConfig;
use crate::core::mime::{LEGACY_POWERPOINT_MIME_TYPE, LEGACY_WORD_MIME_TYPE};
use crate::types::ExtractionResult;
use super::file::extract_bytes_with_extractor;
/// Extract content from a byte array.
///
/// This is the main entry point for in-memory extraction. It performs the following steps:
/// 1. Validate MIME type
/// 2. Handle legacy format conversion if needed
/// 3. Select appropriate extractor from registry
/// 4. Extract content
/// 5. Run post-processing pipeline
///
/// # Arguments
///
/// * `content` - The byte array to extract
/// * `mime_type` - MIME type of the content
/// * `config` - Extraction configuration
///
/// # Returns
///
/// An `ExtractionResult` containing the extracted content and metadata.
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if MIME type is invalid.
/// Returns `KreuzbergError::UnsupportedFormat` if MIME type is not supported.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::extract_bytes;
/// use kreuzberg::core::config::ExtractionConfig;
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let bytes = b"Hello, world!";
/// let result = extract_bytes(bytes, "text/plain", &config).await?;
/// println!("Content: {}", result.content);
/// # Ok(())
/// # }
/// ```
#[cfg_attr(feature = "otel", tracing::instrument(
skip(config, content),
fields(
{ crate::telemetry::conventions::OPERATION } = crate::telemetry::conventions::operations::EXTRACT_BYTES,
{ crate::telemetry::conventions::DOCUMENT_MIME_TYPE } = mime_type,
{ crate::telemetry::conventions::DOCUMENT_SIZE_BYTES } = content.len(),
{ crate::telemetry::conventions::OTEL_STATUS_CODE } = tracing::field::Empty,
{ crate::telemetry::conventions::ERROR_TYPE } = tracing::field::Empty,
{ crate::telemetry::conventions::ERROR_MESSAGE } = tracing::field::Empty,
)
))]
pub async fn extract_bytes(content: &[u8], mime_type: &str, config: &ExtractionConfig) -> Result<ExtractionResult> {
use crate::core::mime;
let extraction_future = async {
if config.force_ocr && config.effective_disable_ocr() {
return Err(crate::KreuzbergError::Validation {
message: "force_ocr and disable_ocr cannot both be true".to_string(),
source: None,
});
}
let validated_mime = if mime_type == "application/octet-stream" {
// When tree-sitter is configured, check if content is recognized source code.
// This allows octet-stream files with tree-sitter config to be detected as code.
#[cfg(feature = "tree-sitter")]
{
if config.tree_sitter.is_some() {
if let Ok(text) = std::str::from_utf8(content) {
let trimmed = text.trim_start();
if tree_sitter_language_pack::detect_language_from_content(trimmed).is_some() {
// Recognize as source code when tree-sitter can detect a language.
mime::SOURCE_CODE_MIME_TYPE.to_string()
} else {
mime::detect_mime_type_from_bytes(content)?
}
} else {
mime::detect_mime_type_from_bytes(content)?
}
} else {
mime::detect_mime_type_from_bytes(content)?
}
}
#[cfg(not(feature = "tree-sitter"))]
{
let _ = config;
mime::detect_mime_type_from_bytes(content)?
}
} else {
mime::validate_mime_type(mime_type)?
};
// Native DOC/PPT extractors are registered in the plugin registry.
// When the office feature is disabled, these MIME types are unsupported.
#[cfg(not(feature = "office"))]
match validated_mime.as_str() {
LEGACY_WORD_MIME_TYPE => {
return Err(KreuzbergError::UnsupportedFormat(
"Legacy Word extraction requires the `office` feature".to_string(),
));
}
LEGACY_POWERPOINT_MIME_TYPE => {
return Err(KreuzbergError::UnsupportedFormat(
"Legacy PowerPoint extraction requires the `office` feature".to_string(),
));
}
_ => {}
}
// Suppress unused import warnings when office feature is enabled
#[cfg(feature = "office")]
{
let _ = LEGACY_WORD_MIME_TYPE;
let _ = LEGACY_POWERPOINT_MIME_TYPE;
}
extract_bytes_with_extractor(content, &validated_mime, config).await
};
#[cfg(feature = "tokio-runtime")]
let result = if let Some(secs) = config.extraction_timeout_secs {
let start = std::time::Instant::now();
match tokio::time::timeout(std::time::Duration::from_secs(secs), extraction_future).await {
Ok(inner) => inner,
Err(_elapsed) => {
if let Some(ref token) = config.cancel_token {
token.cancel();
}
Err(crate::KreuzbergError::Timeout {
elapsed_ms: start.elapsed().as_millis() as u64,
limit_ms: secs * 1000,
})
}
}
} else {
extraction_future.await
};
#[cfg(not(feature = "tokio-runtime"))]
let result = {
if config.extraction_timeout_secs.is_some() {
return Err(crate::KreuzbergError::Validation {
message: "extraction_timeout_secs requires the 'tokio-runtime' feature to be enabled".to_string(),
source: None,
});
}
extraction_future.await
};
#[cfg(feature = "otel")]
if let Err(ref e) = result {
crate::telemetry::spans::record_error_on_current_span(e);
}
result
}

View File

@@ -0,0 +1,279 @@
//! File-based extraction operations.
//!
//! This module handles extraction from filesystem paths, including:
//! - MIME type detection and validation
//! - Legacy format conversion (DOC, PPT)
//! - File validation and reading
//! - Extraction pipeline orchestration
#[cfg(not(feature = "office"))]
use crate::KreuzbergError;
use crate::Result;
use crate::core::config::ExtractionConfig;
use crate::core::mime::{LEGACY_POWERPOINT_MIME_TYPE, LEGACY_WORD_MIME_TYPE};
use crate::types::ExtractionResult;
use std::path::Path;
use super::helpers::get_extractor;
/// Extract content from a file.
///
/// This is the main entry point for file-based extraction. It performs the following steps:
/// 1. Check cache for existing result (if caching enabled)
/// 2. Detect or validate MIME type
/// 3. Select appropriate extractor from registry
/// 4. Extract content
/// 5. Run post-processing pipeline
/// 6. Store result in cache (if caching enabled)
///
/// # Arguments
///
/// * `path` - Path to the file to extract
/// * `mime_type` - Optional MIME type override. If None, will be auto-detected
/// * `config` - Extraction configuration
///
/// # Returns
///
/// An `ExtractionResult` containing the extracted content and metadata.
///
/// # Errors
///
/// Returns `KreuzbergError::Io` if the file doesn't exist (NotFound) or for other file I/O errors.
/// Returns `KreuzbergError::UnsupportedFormat` if MIME type is not supported.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::extract_file;
/// use kreuzberg::core::config::ExtractionConfig;
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let result = extract_file("document.pdf", None, &config).await?;
/// println!("Content: {}", result.content);
/// # Ok(())
/// # }
/// ```
#[cfg_attr(feature = "otel", tracing::instrument(
skip(config, path),
fields(
{ crate::telemetry::conventions::OPERATION } = crate::telemetry::conventions::operations::EXTRACT_FILE,
{ crate::telemetry::conventions::DOCUMENT_FILENAME } = tracing::field::Empty,
{ crate::telemetry::conventions::OTEL_STATUS_CODE } = tracing::field::Empty,
{ crate::telemetry::conventions::ERROR_TYPE } = tracing::field::Empty,
{ crate::telemetry::conventions::ERROR_MESSAGE } = tracing::field::Empty,
)
))]
pub async fn extract_file(
path: impl AsRef<Path>,
mime_type: Option<&str>,
config: &ExtractionConfig,
) -> Result<ExtractionResult> {
use crate::core::{io, mime};
let path = path.as_ref();
#[cfg(feature = "otel")]
{
let span = tracing::Span::current();
span.record(
crate::telemetry::conventions::DOCUMENT_FILENAME,
crate::telemetry::spans::sanitize_path(path),
);
}
let extraction_future = async {
io::validate_file_exists(path)?;
if config.force_ocr && config.effective_disable_ocr() {
return Err(crate::KreuzbergError::Validation {
message: "force_ocr and disable_ocr cannot both be true".to_string(),
source: None,
});
}
let detected_mime = mime::detect_or_validate(path.to_str(), mime_type)?;
// Native DOC/PPT extractors are registered in the plugin registry.
// When the office feature is disabled, these MIME types are unsupported.
#[cfg(not(feature = "office"))]
match detected_mime.as_str() {
LEGACY_WORD_MIME_TYPE => {
return Err(KreuzbergError::UnsupportedFormat(
"Legacy Word extraction requires the `office` feature".to_string(),
));
}
LEGACY_POWERPOINT_MIME_TYPE => {
return Err(KreuzbergError::UnsupportedFormat(
"Legacy PowerPoint extraction requires the `office` feature".to_string(),
));
}
_ => {}
}
// Suppress unused import warnings when office feature is enabled
#[cfg(feature = "office")]
{
let _ = LEGACY_WORD_MIME_TYPE;
let _ = LEGACY_POWERPOINT_MIME_TYPE;
}
extract_file_with_extractor(path, &detected_mime, config).await
};
#[cfg(feature = "tokio-runtime")]
let result = if let Some(secs) = config.extraction_timeout_secs {
let start = std::time::Instant::now();
match tokio::time::timeout(std::time::Duration::from_secs(secs), extraction_future).await {
Ok(inner) => inner,
Err(_elapsed) => {
if let Some(ref token) = config.cancel_token {
token.cancel();
}
Err(crate::KreuzbergError::Timeout {
elapsed_ms: start.elapsed().as_millis() as u64,
limit_ms: secs * 1000,
})
}
}
} else {
extraction_future.await
};
#[cfg(not(feature = "tokio-runtime"))]
let result = {
if config.extraction_timeout_secs.is_some() {
return Err(crate::KreuzbergError::Validation {
message: "extraction_timeout_secs requires the 'tokio-runtime' feature to be enabled".to_string(),
source: None,
});
}
extraction_future.await
};
#[cfg(feature = "otel")]
if let Err(ref e) = result {
crate::telemetry::spans::record_error_on_current_span(e);
}
result
}
pub(in crate::core::extractor) async fn extract_file_with_extractor(
path: &Path,
mime_type: &str,
config: &ExtractionConfig,
) -> Result<ExtractionResult> {
// Normalize config so cache keys are consistent for ElementBased requests
// regardless of whether the caller explicitly set extract_pages.
let config = config.normalized();
let config = config.as_ref();
// Skip cache if disabled or TTL=0
if !config.use_cache || config.cache_ttl_secs == Some(0) {
return extract_file_uncached(path, mime_type, config).await;
}
// Generate cache key from file content hash + config fingerprint
let content_hash = crate::cache::blake3_hash_file(path)?;
let config_hash = hash_extraction_config(config, mime_type);
let cache_key = format!("{content_hash}_{config_hash}");
let namespace = config.cache_namespace.as_deref();
// Try cache read
if let Some(cache) = get_extraction_cache()
&& let Ok(Some(data)) = cache.get(&cache_key, path.to_str(), namespace, config.cache_ttl_secs)
&& let Ok(result) = rmp_serde::from_slice::<ExtractionResult>(&data)
{
tracing::debug!(cache_key = %cache_key, "Extraction cache hit");
return Ok(result);
}
// Cache miss — extract
let result = extract_file_uncached(path, mime_type, config).await?;
// Cache write (best-effort)
if let Some(cache) = get_extraction_cache()
&& let Ok(data) = rmp_serde::to_vec(&result)
{
let _ = cache.set(&cache_key, data, path.to_str(), namespace, config.cache_ttl_secs);
}
Ok(result)
}
/// Extract without caching logic.
async fn extract_file_uncached(path: &Path, mime_type: &str, config: &ExtractionConfig) -> Result<ExtractionResult> {
let budget = crate::core::config::concurrency::resolve_thread_budget(config.concurrency.as_ref());
crate::core::config::concurrency::init_thread_pools(budget);
crate::extractors::ensure_initialized()?;
let extractor = get_extractor(mime_type)?;
let doc = extractor.extract_file(path, mime_type, config).await?;
let result = crate::core::pipeline::run_pipeline(doc, config).await?;
Ok(result)
}
/// Hash ExtractionConfig fields that affect extraction output.
///
/// Excludes cache-control fields (use_cache, cache_namespace, cache_ttl_secs)
/// since they don't affect the extraction result. Uses a clone-and-normalize
/// approach to ensure determinism: cache fields are zeroed, then the struct
/// is serialized to canonical JSON via serde_json's sorted-keys representation.
fn hash_extraction_config(config: &ExtractionConfig, mime_type: &str) -> String {
let mut normalized = config.clone();
// Zero out cache-control fields so they don't affect the hash
normalized.use_cache = true;
normalized.cache_namespace = None;
normalized.cache_ttl_secs = None;
let mut hasher = blake3::Hasher::new();
hasher.update(mime_type.as_bytes());
// Use MessagePack for deterministic serialization (no float formatting issues,
// no HashMap key ordering issues — serde serializes struct fields in declaration order).
if let Ok(bytes) = rmp_serde::to_vec(&normalized) {
hasher.update(&bytes);
}
let hash = hasher.finalize();
hex::encode(&hash.as_bytes()[..16])
}
/// Get or initialize the global extraction cache.
fn get_extraction_cache() -> Option<&'static crate::cache::GenericCache> {
use std::sync::OnceLock;
static CACHE: OnceLock<Option<crate::cache::GenericCache>> = OnceLock::new();
CACHE
.get_or_init(|| {
crate::cache::GenericCache::new(
"extraction".to_string(),
None,
30.0, // 30-day default TTL
2000.0, // 2 GB max cache size
500.0, // 500 MB min free space
)
.ok()
})
.as_ref()
}
pub(in crate::core::extractor) async fn extract_bytes_with_extractor(
content: &[u8],
mime_type: &str,
config: &ExtractionConfig,
) -> Result<ExtractionResult> {
let config = config.normalized();
let config = config.as_ref();
let budget = crate::core::config::concurrency::resolve_thread_budget(config.concurrency.as_ref());
crate::core::config::concurrency::init_thread_pools(budget);
crate::extractors::ensure_initialized()?;
let extractor = get_extractor(mime_type)?;
let doc = extractor.extract_bytes(content, mime_type, config).await?;
let result = crate::core::pipeline::run_pipeline(doc, config).await?;
Ok(result)
}

View File

@@ -0,0 +1,85 @@
//! Helper functions and utilities for extraction operations.
//!
//! This module provides shared utilities used across extraction modules.
use crate::plugins::DocumentExtractor;
use crate::types::{ErrorMetadata, ExtractionResult, Metadata};
use crate::{KreuzbergError, Result};
use std::borrow::Cow;
use std::sync::Arc;
/// Get an extractor from the registry.
///
/// This function acquires the registry read lock and retrieves the appropriate
/// extractor for the given MIME type.
///
/// When the `otel` feature is enabled, the returned extractor is wrapped in an
/// [`InstrumentedExtractor`](crate::plugins::extractor::instrumented::InstrumentedExtractor)
/// that adds tracing spans and metrics automatically.
///
/// # Performance
///
/// RwLock read + HashMap lookup is ~100ns, fast enough without caching.
/// Removed thread-local cache to avoid Tokio work-stealing scheduler issues.
pub(in crate::core::extractor) fn get_extractor(mime_type: &str) -> Result<Arc<dyn DocumentExtractor>> {
let registry = crate::plugins::registry::get_document_extractor_registry();
let registry_read = registry.read();
let extractor = registry_read.get(mime_type)?;
#[cfg(feature = "otel")]
{
Ok(Arc::new(
crate::plugins::extractor::instrumented::InstrumentedExtractor::new(extractor),
))
}
#[cfg(not(feature = "otel"))]
{
Ok(extractor)
}
}
/// Get optimal pool sizing hint for a document.
///
/// This function calculates recommended pool sizes based on the document's
/// file size and MIME type. The hint can be used to create appropriately
/// sized thread pools for extraction, reducing memory waste from over-allocation.
///
/// # Arguments
///
/// * `file_size` - The size of the file in bytes
/// * `mime_type` - The MIME type of the document
///
/// # Returns
///
/// A `PoolSizeHint` with recommended pool configurations
///
/// # Example
///
/// ```rust,ignore
/// use kreuzberg::core::extractor::get_pool_sizing_hint;
///
/// let hint = get_pool_sizing_hint(5_000_000, "application/pdf");
/// println!("Recommended string buffers: {}", hint.string_buffer_count);
/// ```
/// Build an error `ExtractionResult` for failed batch items.
///
/// Used by both tokio-based batch functions and WASM synchronous fallbacks
/// to construct a uniform error result.
pub(crate) fn error_extraction_result(e: &KreuzbergError, elapsed_ms: Option<u64>) -> ExtractionResult {
let metadata = Metadata {
error: Some(ErrorMetadata {
error_type: format!("{:?}", e),
message: e.to_string(),
}),
extraction_duration_ms: elapsed_ms,
..Default::default()
};
ExtractionResult {
content: format!("Error: {}", e),
mime_type: Cow::Borrowed("text/plain"),
metadata,
..Default::default()
}
}

View File

@@ -0,0 +1,67 @@
//! Legacy synchronous extraction for WASM compatibility.
//!
//! This module provides truly synchronous extraction implementations
//! for environments where Tokio runtime is not available (e.g., WASM).
/// Synchronous extraction implementation for WASM compatibility.
///
/// This function performs extraction without requiring a tokio runtime.
/// It calls the sync extractor methods directly.
///
/// # Arguments
///
/// * `content` - The byte content to extract
/// * `mime_type` - Optional MIME type to validate/use
/// * `config` - Optional extraction configuration
///
/// # Returns
///
/// An `ExtractionResult` or a `KreuzbergError`
///
/// # Implementation Notes
///
/// This is called when the `tokio-runtime` feature is disabled.
/// It replicates the logic of `extract_bytes` but uses synchronous extractor methods.
#[cfg(not(feature = "tokio-runtime"))]
pub(super) fn extract_bytes_sync_impl(
content: &[u8],
mime_type: Option<&str>,
config: Option<&crate::core::config::ExtractionConfig>,
) -> crate::Result<crate::types::ExtractionResult> {
use crate::KreuzbergError;
use crate::core::extractor::helpers::get_extractor;
use crate::core::mime;
let cfg = config.cloned().unwrap_or_default();
let cfg = cfg.normalized().into_owned();
let validated_mime = if let Some(mime) = mime_type {
if mime == "application/octet-stream" {
mime::detect_mime_type_from_bytes(content)?
} else {
mime::validate_mime_type(mime)?
}
} else {
return Err(KreuzbergError::Validation {
message: "MIME type is required for synchronous extraction".to_string(),
source: None,
});
};
crate::extractors::ensure_initialized()?;
let extractor = get_extractor(&validated_mime)?;
let sync_extractor = extractor.as_sync_extractor().ok_or_else(|| {
KreuzbergError::UnsupportedFormat(format!(
"Extractor for '{}' does not support synchronous extraction",
validated_mime
))
})?;
let doc = sync_extractor.extract_sync(content, &validated_mime, &cfg)?;
let result = crate::core::pipeline::run_pipeline_sync(doc, &cfg)?;
Ok(result)
}

View File

@@ -0,0 +1,681 @@
//! Main extraction entry points.
//!
//! This module provides the primary API for extracting content from files and byte arrays.
//! It orchestrates the entire extraction pipeline: cache checking, MIME detection,
//! extractor selection, extraction, post-processing, and cache storage.
//!
//! # Functions
//!
//! - [`extract_file`] - Extract content from a file path
//! - [`extract_bytes`] - Extract content from a byte array
//! - [`batch_extract_files`] - Extract content from multiple files concurrently
//! - [`batch_extract_bytes`] - Extract content from multiple byte arrays concurrently
mod bytes;
mod file;
mod helpers;
mod legacy;
mod sync;
#[cfg(feature = "tokio-runtime")]
mod batch;
// Re-export public API
pub use bytes::extract_bytes;
pub use file::extract_file;
pub use sync::{batch_extract_bytes_sync, extract_bytes_sync};
#[cfg(feature = "tokio-runtime")]
pub use sync::extract_file_sync;
#[cfg(feature = "tokio-runtime")]
pub use batch::{batch_extract_bytes, batch_extract_files};
#[cfg(feature = "tokio-runtime")]
pub use sync::batch_extract_files_sync;
#[cfg(test)]
mod tests {
use super::*;
use crate::core::config::{BatchBytesItem, BatchFileItem, ExtractionConfig};
use serial_test::serial;
use std::fs::File;
use std::io::Write;
use std::sync::Arc;
use tempfile::tempdir;
fn assert_text_content(actual: &str, expected: &str) {
assert_eq!(actual.trim_end_matches('\n'), expected);
}
#[tokio::test]
async fn test_extract_file_basic() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
let mut file = File::create(&file_path).unwrap();
file.write_all(b"Hello, world!").unwrap();
let config = ExtractionConfig::default();
let result = extract_file(&file_path, None, &config).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_text_content(&result.content, "Hello, world!");
assert_eq!(result.mime_type, "text/plain");
}
#[tokio::test]
async fn test_extract_file_with_mime_override() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.dat");
let mut file = File::create(&file_path).unwrap();
file.write_all(b"test content").unwrap();
let config = ExtractionConfig::default();
let result = extract_file(&file_path, Some("text/plain"), &config).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.mime_type, "text/plain");
}
#[tokio::test]
async fn test_extract_file_nonexistent() {
let config = ExtractionConfig::default();
let result = extract_file("/nonexistent/file.txt", None, &config).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_extract_bytes_basic() {
let config = ExtractionConfig::default();
let result = extract_bytes(b"test content", "text/plain", &config).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_text_content(&result.content, "test content");
assert_eq!(result.mime_type, "text/plain");
}
#[tokio::test]
async fn test_extract_bytes_invalid_mime() {
let config = ExtractionConfig::default();
let result = extract_bytes(b"test", "invalid/mime", &config).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_batch_extract_file() {
let dir = tempdir().unwrap();
let file1 = dir.path().join("test1.txt");
let file2 = dir.path().join("test2.txt");
File::create(&file1).unwrap().write_all(b"content 1").unwrap();
File::create(&file2).unwrap().write_all(b"content 2").unwrap();
let config = ExtractionConfig::default();
let items = vec![
BatchFileItem {
path: file1,
config: None,
},
BatchFileItem {
path: file2,
config: None,
},
];
let results = batch_extract_files(items, &config).await;
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 2);
assert_text_content(&results[0].content, "content 1");
assert_text_content(&results[1].content, "content 2");
}
#[tokio::test]
async fn test_batch_extract_file_empty() {
let config = ExtractionConfig::default();
let items: Vec<BatchFileItem> = vec![];
let results = batch_extract_files(items, &config).await;
assert!(results.is_ok());
assert_eq!(results.unwrap().len(), 0);
}
#[tokio::test]
async fn test_batch_extract_bytes() {
let config = ExtractionConfig::default();
let items = vec![
BatchBytesItem {
content: b"content 1".to_vec(),
mime_type: "text/plain".to_string(),
config: None,
},
BatchBytesItem {
content: b"content 2".to_vec(),
mime_type: "text/plain".to_string(),
config: None,
},
];
let results = batch_extract_bytes(items, &config).await;
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 2);
assert_text_content(&results[0].content, "content 1");
assert_text_content(&results[1].content, "content 2");
}
#[test]
fn test_sync_wrappers() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
File::create(&file_path).unwrap().write_all(b"sync test").unwrap();
let config = ExtractionConfig::default();
let result = extract_file_sync(&file_path, None, &config);
assert!(result.is_ok());
let result = result.unwrap();
assert_text_content(&result.content, "sync test");
let result = extract_bytes_sync(b"test", "text/plain", &config);
assert!(result.is_ok());
}
#[tokio::test]
async fn test_extractor_cache() {
let config = ExtractionConfig::default();
let result1 = extract_bytes(b"test 1", "text/plain", &config).await;
assert!(result1.is_ok());
let result1 = result1.unwrap();
let result2 = extract_bytes(b"test 2", "text/plain", &config).await;
assert!(result2.is_ok());
let result2 = result2.unwrap();
assert_text_content(&result1.content, "test 1");
assert_text_content(&result2.content, "test 2");
let result3 = extract_bytes(b"# test 3", "text/markdown", &config).await;
assert!(result3.is_ok());
}
#[tokio::test]
async fn test_extract_file_empty() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("empty.txt");
File::create(&file_path).unwrap();
let config = ExtractionConfig::default();
let result = extract_file(&file_path, None, &config).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.content, "");
}
#[tokio::test]
async fn test_extract_bytes_empty() {
let config = ExtractionConfig::default();
let result = extract_bytes(b"", "text/plain", &config).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.content, "");
}
#[tokio::test]
async fn test_extract_file_whitespace_only() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("whitespace.txt");
File::create(&file_path).unwrap().write_all(b" \n\t \n ").unwrap();
let config = ExtractionConfig::default();
let result = extract_file(&file_path, None, &config).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_extract_file_very_long_path() {
let dir = tempdir().unwrap();
let long_name = "a".repeat(200);
let file_path = dir.path().join(format!("{}.txt", long_name));
if let Ok(mut f) = File::create(&file_path) {
f.write_all(b"content").unwrap();
let config = ExtractionConfig::default();
let result = extract_file(&file_path, None, &config).await;
assert!(result.is_ok() || result.is_err());
}
}
#[tokio::test]
async fn test_extract_file_special_characters_in_path() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test with spaces & symbols!.txt");
File::create(&file_path).unwrap().write_all(b"content").unwrap();
let config = ExtractionConfig::default();
let result = extract_file(&file_path, None, &config).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_text_content(&result.content, "content");
}
#[tokio::test]
async fn test_extract_file_unicode_filename() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("测试文件名.txt");
File::create(&file_path).unwrap().write_all(b"content").unwrap();
let config = ExtractionConfig::default();
let result = extract_file(&file_path, None, &config).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_extract_bytes_unsupported_mime() {
let config = ExtractionConfig::default();
let result = extract_bytes(b"test", "application/x-unknown-format", &config).await;
assert!(result.is_err());
use crate::KreuzbergError;
assert!(matches!(result.unwrap_err(), KreuzbergError::UnsupportedFormat(_)));
}
#[tokio::test]
async fn test_batch_extract_file_with_errors() {
let dir = tempdir().unwrap();
let valid_file = dir.path().join("valid.txt");
File::create(&valid_file).unwrap().write_all(b"valid content").unwrap();
let invalid_file = dir.path().join("nonexistent.txt");
let config = ExtractionConfig::default();
let items = vec![
BatchFileItem {
path: valid_file,
config: None,
},
BatchFileItem {
path: invalid_file,
config: None,
},
];
let results = batch_extract_files(items, &config).await;
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 2);
assert_text_content(&results[0].content, "valid content");
assert!(results[1].metadata.error.is_some());
}
#[tokio::test]
async fn test_batch_extract_bytes_mixed_valid_invalid() {
let config = ExtractionConfig::default();
let items = vec![
BatchBytesItem {
content: b"valid 1".to_vec(),
mime_type: "text/plain".to_string(),
config: None,
},
BatchBytesItem {
content: b"invalid".to_vec(),
mime_type: "invalid/mime".to_string(),
config: None,
},
BatchBytesItem {
content: b"valid 2".to_vec(),
mime_type: "text/plain".to_string(),
config: None,
},
];
let results = batch_extract_bytes(items, &config).await;
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 3);
assert_text_content(&results[0].content, "valid 1");
assert!(results[1].metadata.error.is_some());
assert_text_content(&results[2].content, "valid 2");
}
#[tokio::test]
async fn test_batch_extract_bytes_all_invalid() {
let config = ExtractionConfig::default();
let items = vec![
BatchBytesItem {
content: b"test 1".to_vec(),
mime_type: "invalid/mime1".to_string(),
config: None,
},
BatchBytesItem {
content: b"test 2".to_vec(),
mime_type: "invalid/mime2".to_string(),
config: None,
},
];
let results = batch_extract_bytes(items, &config).await;
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 2);
assert!(results[0].metadata.error.is_some());
assert!(results[1].metadata.error.is_some());
}
#[tokio::test]
async fn test_extract_bytes_very_large() {
let large_content = vec![b'a'; 10_000_000];
let config = ExtractionConfig::default();
let result = extract_bytes(&large_content, "text/plain", &config).await;
assert!(result.is_ok());
let result = result.unwrap();
let trimmed_len = result.content.trim_end_matches('\n').len();
assert_eq!(trimmed_len, 10_000_000);
}
#[tokio::test]
async fn test_batch_extract_large_count() {
let dir = tempdir().unwrap();
let mut items = Vec::new();
for i in 0..100 {
let file_path = dir.path().join(format!("file{}.txt", i));
File::create(&file_path)
.unwrap()
.write_all(format!("content {}", i).as_bytes())
.unwrap();
items.push(BatchFileItem {
path: file_path,
config: None,
});
}
let config = ExtractionConfig::default();
let results = batch_extract_files(items, &config).await;
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 100);
for (i, result) in results.iter().enumerate() {
assert_text_content(&result.content, &format!("content {}", i));
}
}
#[tokio::test]
async fn test_extract_file_mime_detection_fallback() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("testfile");
File::create(&file_path)
.unwrap()
.write_all(b"plain text content")
.unwrap();
let config = ExtractionConfig::default();
let result = extract_file(&file_path, None, &config).await;
assert!(result.is_ok() || result.is_err());
}
#[tokio::test]
async fn test_extract_file_wrong_mime_override() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
File::create(&file_path).unwrap().write_all(b"plain text").unwrap();
let config = ExtractionConfig::default();
let result = extract_file(&file_path, Some("application/pdf"), &config).await;
assert!(result.is_err() || result.is_ok());
}
#[test]
fn test_sync_wrapper_nonexistent_file() {
let config = ExtractionConfig::default();
let result = extract_file_sync("/nonexistent/path.txt", None, &config);
assert!(result.is_err());
use crate::KreuzbergError;
// File validation returns Io error, not Validation error
assert!(matches!(result.unwrap_err(), KreuzbergError::Io { .. }));
}
#[test]
fn test_sync_wrapper_batch_empty() {
let config = ExtractionConfig::default();
let items: Vec<BatchFileItem> = vec![];
let results = batch_extract_files_sync(items, &config);
assert!(results.is_ok());
assert_eq!(results.unwrap().len(), 0);
}
#[test]
fn test_sync_wrapper_batch_bytes_empty() {
let config = ExtractionConfig::default();
let items: Vec<BatchBytesItem> = vec![];
let results = batch_extract_bytes_sync(items, &config);
assert!(results.is_ok());
assert_eq!(results.unwrap().len(), 0);
}
#[tokio::test]
async fn test_concurrent_extractions_same_mime() {
use tokio::task::JoinSet;
let config = Arc::new(ExtractionConfig::default());
let mut tasks = JoinSet::new();
for i in 0..50 {
let config_clone = Arc::clone(&config);
tasks.spawn(async move {
let content = format!("test content {}", i);
extract_bytes(content.as_bytes(), "text/plain", &config_clone).await
});
}
let mut success_count = 0;
while let Some(task_result) = tasks.join_next().await {
if let Ok(Ok(_)) = task_result {
success_count += 1;
}
}
assert_eq!(success_count, 50);
}
#[serial]
#[tokio::test]
async fn test_concurrent_extractions_different_mimes() {
use tokio::task::JoinSet;
let config = Arc::new(ExtractionConfig::default());
let mut tasks = JoinSet::new();
let mime_types = ["text/plain", "text/markdown"];
for i in 0..30 {
let config_clone = Arc::clone(&config);
let mime = mime_types[i % mime_types.len()];
tasks.spawn(async move {
let content = format!("test {}", i);
extract_bytes(content.as_bytes(), mime, &config_clone).await
});
}
let mut success_count = 0;
while let Some(task_result) = tasks.join_next().await {
if let Ok(Ok(_)) = task_result {
success_count += 1;
}
}
assert_eq!(success_count, 30);
}
#[tokio::test]
async fn test_batch_extract_file_with_per_file_configs() {
let dir = tempdir().unwrap();
let file1 = dir.path().join("test1.txt");
let file2 = dir.path().join("test2.txt");
File::create(&file1).unwrap().write_all(b"content 1").unwrap();
File::create(&file2).unwrap().write_all(b"content 2").unwrap();
let config = ExtractionConfig::default();
let items = vec![
BatchFileItem {
path: file1,
config: Some(crate::FileExtractionConfig::default()),
},
BatchFileItem {
path: file2,
config: None,
},
];
let results = batch_extract_files(items, &config).await;
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 2);
assert_text_content(&results[0].content, "content 1");
assert_text_content(&results[1].content, "content 2");
}
#[tokio::test]
async fn test_batch_extract_file_with_configs_empty() {
let config = ExtractionConfig::default();
let items: Vec<BatchFileItem> = vec![];
let results = batch_extract_files(items, &config).await;
assert!(results.is_ok());
assert_eq!(results.unwrap().len(), 0);
}
#[tokio::test]
async fn test_batch_extract_bytes_with_per_item_configs() {
let config = ExtractionConfig::default();
let items = vec![
BatchBytesItem {
content: b"hello".to_vec(),
mime_type: "text/plain".to_string(),
config: None,
},
BatchBytesItem {
content: b"world".to_vec(),
mime_type: "text/plain".to_string(),
config: Some(crate::FileExtractionConfig::default()),
},
];
let results = batch_extract_bytes(items, &config).await;
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 2);
assert_text_content(&results[0].content, "hello");
assert_text_content(&results[1].content, "world");
}
#[tokio::test]
async fn test_batch_extract_bytes_with_configs_error_handling() {
let config = ExtractionConfig::default();
let items = vec![
BatchBytesItem {
content: b"valid".to_vec(),
mime_type: "text/plain".to_string(),
config: None,
},
BatchBytesItem {
content: b"invalid".to_vec(),
mime_type: "invalid/mime".to_string(),
config: Some(crate::FileExtractionConfig::default()),
},
];
let results = batch_extract_bytes(items, &config).await;
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 2);
assert_text_content(&results[0].content, "valid");
assert!(results[1].metadata.error.is_some());
}
#[test]
fn test_batch_extract_file_sync_with_configs() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
File::create(&file_path).unwrap().write_all(b"sync test").unwrap();
let config = ExtractionConfig::default();
let items = vec![BatchFileItem {
path: file_path,
config: None,
}];
let results = batch_extract_files_sync(items, &config);
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 1);
assert_text_content(&results[0].content, "sync test");
}
#[test]
fn test_with_file_overrides_single_field() {
let base = ExtractionConfig::default();
assert!(!base.force_ocr);
let overrides = crate::FileExtractionConfig {
force_ocr: Some(true),
..Default::default()
};
let resolved = base.with_file_overrides(&overrides);
assert!(resolved.force_ocr);
// Other fields unchanged
assert_eq!(resolved.use_cache, base.use_cache);
assert_eq!(resolved.enable_quality_processing, base.enable_quality_processing);
}
#[test]
fn test_with_file_overrides_none_keeps_default() {
let base = ExtractionConfig::default();
let overrides = crate::FileExtractionConfig::default(); // all None
let resolved = base.with_file_overrides(&overrides);
// All fields should match base
assert_eq!(resolved.use_cache, base.use_cache);
assert_eq!(resolved.force_ocr, base.force_ocr);
assert_eq!(resolved.enable_quality_processing, base.enable_quality_processing);
assert_eq!(resolved.include_document_structure, base.include_document_structure);
}
#[test]
fn test_with_file_overrides_batch_fields_unaffected() {
let base = ExtractionConfig {
max_concurrent_extractions: Some(42),
use_cache: false,
..Default::default()
};
let overrides = crate::FileExtractionConfig {
force_ocr: Some(true),
..Default::default()
};
let resolved = base.with_file_overrides(&overrides);
// Batch-level fields must be preserved from base
assert_eq!(resolved.max_concurrent_extractions, Some(42));
assert!(!resolved.use_cache);
// Override applied
assert!(resolved.force_ocr);
}
}

View File

@@ -0,0 +1,200 @@
//! Synchronous wrappers for extraction operations.
//!
//! This module provides blocking synchronous wrappers around async extraction functions
//! for use in non-async contexts. Uses a global Tokio runtime for optimal performance.
use crate::Result;
use crate::core::config::BatchBytesItem;
#[cfg(feature = "tokio-runtime")]
use crate::core::config::BatchFileItem;
use crate::core::config::ExtractionConfig;
use crate::types::ExtractionResult;
#[cfg(feature = "tokio-runtime")]
use std::path::Path;
#[cfg(feature = "tokio-runtime")]
use once_cell::sync::OnceCell;
#[cfg(feature = "tokio-runtime")]
use super::batch::{batch_extract_bytes, batch_extract_files};
#[cfg(feature = "tokio-runtime")]
use super::bytes::extract_bytes;
#[cfg(feature = "tokio-runtime")]
use super::file::extract_file;
#[cfg(not(feature = "tokio-runtime"))]
use super::helpers::error_extraction_result;
/// Global Tokio runtime cell for synchronous operations.
///
/// Lazily initialized on first use and shared across all sync wrappers.
/// Using a global runtime instead of creating one per call provides 100x+ performance improvement.
///
/// # Availability
///
/// This static is only available when the `tokio-runtime` feature is enabled.
/// For WASM targets, use the truly synchronous extraction functions instead.
#[cfg(feature = "tokio-runtime")]
static GLOBAL_RUNTIME: OnceCell<tokio::runtime::Runtime> = OnceCell::new();
/// Returns a reference to the global Tokio runtime, initializing it on first call.
///
/// Returns an error if the runtime cannot be created (e.g. system resource exhaustion).
#[cfg(feature = "tokio-runtime")]
fn global_runtime() -> crate::Result<&'static tokio::runtime::Runtime> {
GLOBAL_RUNTIME.get_or_try_init(|| {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| crate::KreuzbergError::Plugin {
message: format!("Failed to create global Tokio runtime: {e}"),
plugin_name: "runtime".to_string(),
})
})
}
/// Synchronous wrapper for `extract_file`.
///
/// This is a convenience function that blocks the current thread until extraction completes.
/// For async code, use `extract_file` directly.
///
/// Uses the global Tokio runtime for 100x+ performance improvement over creating
/// a new runtime per call. Always uses the global runtime to avoid nested runtime issues.
///
/// This function is only available with the `tokio-runtime` feature. For WASM targets,
/// use a truly synchronous extraction approach instead.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::extract_file_sync;
/// use kreuzberg::core::config::ExtractionConfig;
///
/// let config = ExtractionConfig::default();
/// let result = extract_file_sync("document.pdf", None, &config)?;
/// println!("Content: {}", result.content);
/// # Ok::<(), kreuzberg::KreuzbergError>(())
/// ```
#[cfg(feature = "tokio-runtime")]
pub fn extract_file_sync(
path: impl AsRef<Path>,
mime_type: Option<&str>,
config: &ExtractionConfig,
) -> Result<ExtractionResult> {
global_runtime()?.block_on(extract_file(path, mime_type, config))
}
/// Synchronous wrapper for `extract_bytes`.
///
/// Uses the global Tokio runtime for 100x+ performance improvement over creating
/// a new runtime per call.
///
/// With the `tokio-runtime` feature, this blocks the current thread using the global
/// Tokio runtime. Without it (WASM), this calls a truly synchronous implementation.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::extract_bytes_sync;
/// use kreuzberg::core::config::ExtractionConfig;
///
/// let config = ExtractionConfig::default();
/// let bytes = b"Hello, world!";
/// let result = extract_bytes_sync(bytes, "text/plain", &config)?;
/// println!("Content: {}", result.content);
/// # Ok::<(), kreuzberg::KreuzbergError>(())
/// ```
#[cfg(feature = "tokio-runtime")]
pub fn extract_bytes_sync(content: &[u8], mime_type: &str, config: &ExtractionConfig) -> Result<ExtractionResult> {
global_runtime()?.block_on(extract_bytes(content, mime_type, config))
}
/// Synchronous wrapper for `extract_bytes` (WASM-compatible version).
///
/// This is a truly synchronous implementation without tokio runtime dependency.
/// It calls `extract_bytes_sync_impl()` to perform the extraction.
#[cfg(not(feature = "tokio-runtime"))]
pub fn extract_bytes_sync(content: &[u8], mime_type: &str, config: &ExtractionConfig) -> Result<ExtractionResult> {
super::legacy::extract_bytes_sync_impl(content, Some(mime_type), Some(config))
}
/// Synchronous wrapper for `batch_extract_files`.
///
/// Uses the global Tokio runtime for optimal performance.
/// Only available with `tokio-runtime` (WASM has no filesystem).
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_files_sync;
/// use kreuzberg::core::config::{ExtractionConfig, BatchFileItem, FileExtractionConfig};
///
/// let config = ExtractionConfig::default();
/// let items = vec![
/// BatchFileItem {
/// path: "doc1.pdf".into(),
/// config: Some(FileExtractionConfig { force_ocr: Some(true), ..Default::default() }),
/// },
/// BatchFileItem { path: "doc2.pdf".into(), config: None },
/// ];
/// let results = batch_extract_files_sync(items, &config)?;
/// # Ok::<(), kreuzberg::KreuzbergError>(())
/// ```
#[cfg(feature = "tokio-runtime")]
pub fn batch_extract_files_sync(items: Vec<BatchFileItem>, config: &ExtractionConfig) -> Result<Vec<ExtractionResult>> {
global_runtime()?.block_on(batch_extract_files(items, config))
}
/// Synchronous wrapper for `batch_extract_bytes`.
///
/// Uses the global Tokio runtime for optimal performance.
/// With the `tokio-runtime` feature, this blocks the current thread using the global
/// Tokio runtime. Without it (WASM), this calls a truly synchronous implementation
/// that iterates through items and calls `extract_bytes_sync()`.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_bytes_sync;
/// use kreuzberg::core::config::{ExtractionConfig, BatchBytesItem, FileExtractionConfig};
///
/// let config = ExtractionConfig::default();
/// let items = vec![
/// BatchBytesItem { content: b"content".to_vec(), mime_type: "text/plain".to_string(), config: None },
/// BatchBytesItem {
/// content: b"other".to_vec(),
/// mime_type: "text/plain".to_string(),
/// config: Some(FileExtractionConfig { force_ocr: Some(true), ..Default::default() }),
/// },
/// ];
/// let results = batch_extract_bytes_sync(items, &config)?;
/// # Ok::<(), kreuzberg::KreuzbergError>(())
/// ```
#[cfg(feature = "tokio-runtime")]
pub fn batch_extract_bytes_sync(
items: Vec<BatchBytesItem>,
config: &ExtractionConfig,
) -> Result<Vec<ExtractionResult>> {
global_runtime()?.block_on(batch_extract_bytes(items, config))
}
/// Synchronous wrapper for `batch_extract_bytes` (WASM-compatible version).
///
/// Iterates through items sequentially, applying per-file config overrides.
#[cfg(not(feature = "tokio-runtime"))]
pub fn batch_extract_bytes_sync(
items: Vec<BatchBytesItem>,
config: &ExtractionConfig,
) -> Result<Vec<ExtractionResult>> {
let mut results = Vec::with_capacity(items.len());
for item in items {
let resolved = match &item.config {
Some(fc) => config.with_file_overrides(fc),
None => config.clone(),
};
let result = extract_bytes_sync(&item.content, &item.mime_type, &resolved);
results.push(result.unwrap_or_else(|e| error_extraction_result(&e, None)));
}
Ok(results)
}

View File

@@ -0,0 +1,237 @@
//! Format field validation and metadata.
//!
//! This module provides a compile-time registry of known format fields used across
//! document extractors. It serves as the single source of truth for format validation
//! across all language bindings (Rust, Python, TypeScript, Ruby, Java, Go).
//!
//! # Known Format Fields
//!
//! The registry contains 55 standardized format fields organized by category:
//! - **Document Properties**: title, author, keywords, creator, producer, etc.
//! - **Dates**: creation_date, modification_date
//! - **Pagination**: page_count
//! - **Email Metadata**: from_email, from_name, to_emails, cc_emails, bcc_emails, message_id
//! - **Attachments**: attachments
//! - **Descriptions**: description, summary
//! - **Typography**: fonts
//! - **Archive/Compression**: format, file_count, file_list, total_size, compressed_size
//! - **Images**: width, height
//! - **Content Metrics**: element_count, unique_elements, line_count, word_count, character_count
//! - **HTML Structure**: headers, links, code_blocks
//! - **Meta Tags**: canonical, base_href, og_*, twitter_*, link_*
//! - **OCR**: psm, output_format
//! - **Tables**: table_count, table_rows, table_cols
//!
//! # Example
//!
//! ```ignore
//! use kreuzberg::core::formats::{KNOWN_FORMATS, is_valid_format_field};
//!
//! assert!(is_valid_format_field("title"));
//! assert!(!is_valid_format_field("invalid_field"));
//! assert_eq!(KNOWN_FORMATS.len(), 55);
//! ```
#[cfg(test)]
use ahash::AHashSet;
#[cfg(test)]
use std::sync::LazyLock;
/// All known format field names across all extractors.
///
/// This is a compile-time constant array of standardized field names used by document
/// extractors. Each binding (Python, TypeScript, Ruby, Java, Go) should reference this
/// as the single source of truth for format field validation.
///
/// Format fields are organized by document type:
/// - PDF/Office: title, author, creation_date, page_count, etc.
/// - Email: from_email, to_emails, cc_emails, bcc_emails, etc.
/// - Web: og_title, twitter_card, canonical, headers, links, etc.
/// - Images: width, height, format
/// - Archives: file_count, file_list, total_size, etc.
#[cfg(test)]
pub(crate) const KNOWN_FORMATS: &[&str] = &[
"title",
"author",
"keywords",
"creator",
"producer",
"creation_date",
"modification_date",
"page_count",
"from_email",
"from_name",
"to_emails",
"cc_emails",
"bcc_emails",
"message_id",
"attachments",
"description",
"summary",
"fonts",
"format",
"file_count",
"file_list",
"total_size",
"compressed_size",
"width",
"height",
"element_count",
"unique_elements",
"line_count",
"word_count",
"character_count",
"headers",
"links",
"code_blocks",
"canonical",
"base_href",
"og_title",
"og_description",
"og_image",
"og_url",
"og_type",
"og_site_name",
"twitter_card",
"twitter_title",
"twitter_description",
"twitter_image",
"twitter_site",
"twitter_creator",
"link_author",
"link_license",
"link_alternate",
"psm",
"output_format",
"table_count",
"table_rows",
"table_cols",
];
/// Cached format field set for fast O(1) lookups.
///
/// Uses AHashSet for its excellent cache locality and performance characteristics
/// with string keys. Built lazily on first use with minimal overhead.
#[cfg(test)]
static FORMAT_FIELD_SET: LazyLock<AHashSet<&'static str>> = LazyLock::new(|| KNOWN_FORMATS.iter().copied().collect());
/// Validates whether a field name is in the known formats registry.
///
/// This uses a pre-built hash set for O(1) lookups instead of linear search,
/// providing significant performance improvements for repeated validations.
///
/// # Arguments
///
/// * `field` - The field name to validate
///
/// # Returns
///
/// `true` if the field is in KNOWN_FORMATS, `false` otherwise.
///
/// # Example
///
/// ```rust
/// use kreuzberg::core::formats::is_valid_format_field;
///
/// assert!(is_valid_format_field("title"));
/// assert!(is_valid_format_field("creation_date"));
/// assert!(!is_valid_format_field("invalid_field"));
/// ```
#[cfg(test)]
#[inline]
pub(crate) fn is_valid_format_field(field: &str) -> bool {
FORMAT_FIELD_SET.contains(field)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_known_formats_count() {
assert_eq!(KNOWN_FORMATS.len(), 55, "Expected 55 known format fields");
}
#[test]
fn test_known_formats_no_duplicates() {
let mut seen = std::collections::HashSet::new();
for field in KNOWN_FORMATS {
assert!(seen.insert(field), "Duplicate format field found: {}", field);
}
}
#[test]
fn test_is_valid_format_field_true_cases() {
assert!(is_valid_format_field("title"));
assert!(is_valid_format_field("author"));
assert!(is_valid_format_field("creation_date"));
assert!(is_valid_format_field("page_count"));
assert!(is_valid_format_field("from_email"));
assert!(is_valid_format_field("og_title"));
assert!(is_valid_format_field("twitter_card"));
}
#[test]
fn test_is_valid_format_field_false_cases() {
assert!(!is_valid_format_field("invalid_field"));
assert!(!is_valid_format_field("unknown_metadata"));
assert!(!is_valid_format_field(""));
assert!(!is_valid_format_field("TITLE"));
assert!(!is_valid_format_field("title "));
}
#[test]
fn test_all_document_property_fields() {
let doc_fields = ["title", "author", "keywords", "creator", "producer"];
for field in &doc_fields {
assert!(is_valid_format_field(field), "Missing field: {}", field);
}
}
#[test]
fn test_all_email_fields() {
let email_fields = [
"from_email",
"from_name",
"to_emails",
"cc_emails",
"bcc_emails",
"message_id",
"attachments",
];
for field in &email_fields {
assert!(is_valid_format_field(field), "Missing email field: {}", field);
}
}
#[test]
fn test_all_web_meta_fields() {
let web_fields = [
"og_title",
"og_description",
"og_image",
"og_url",
"og_type",
"og_site_name",
"twitter_card",
"twitter_title",
"twitter_description",
"twitter_image",
"twitter_site",
"twitter_creator",
"canonical",
"base_href",
];
for field in &web_fields {
assert!(is_valid_format_field(field), "Missing web field: {}", field);
}
}
#[test]
fn test_all_table_fields() {
let table_fields = ["table_count", "table_rows", "table_cols"];
for field in &table_fields {
assert!(is_valid_format_field(field), "Missing table field: {}", field);
}
}
}

View File

@@ -0,0 +1,421 @@
//! File I/O utilities.
//!
//! This module provides async and sync file reading utilities with proper error handling.
//! For large files (> 1 MiB) on non-WASM platforms, memory-mapped I/O is used to avoid
//! heap-allocating the entire file contents, reducing memory pressure and syscall overhead.
use crate::{KreuzbergError, Result};
use std::path::Path;
/// Size threshold above which memory-mapped I/O is preferred over `read()`.
///
/// Files smaller than this are read with a regular `read()` call since the
/// mmap overhead (open, fstat, mmap syscalls + TLB pressure) outweighs the
/// benefit for small allocations.
#[cfg(not(target_arch = "wasm32"))]
const MMAP_THRESHOLD_BYTES: u64 = 1_048_576; // 1 MiB
/// An owned buffer of file bytes.
///
/// On non-WASM platforms this may be backed by a memory-mapped file (zero heap
/// allocation for the file contents) or by a `Vec<u8>` for small files.
/// On WASM it is always a `Vec<u8>`.
///
/// Implements `Deref<Target = [u8]>` so callers can pass `&FileBytes` as `&[u8]`
/// without any additional copy.
#[cfg_attr(alef, alef(skip))]
pub struct FileBytes {
inner: FileBytesInner,
}
enum FileBytesInner {
/// Regular heap-allocated buffer (small files or WASM).
Heap(Vec<u8>),
/// Memory-mapped file (large files on native platforms).
#[cfg(not(target_arch = "wasm32"))]
Mapped(memmap2::Mmap),
}
impl std::ops::Deref for FileBytes {
type Target = [u8];
fn deref(&self) -> &[u8] {
match &self.inner {
FileBytesInner::Heap(v) => v.as_slice(),
#[cfg(not(target_arch = "wasm32"))]
FileBytesInner::Mapped(m) => m.as_ref(),
}
}
}
impl AsRef<[u8]> for FileBytes {
fn as_ref(&self) -> &[u8] {
self
}
}
/// Open a file and return its bytes with zero-copy for large files.
///
/// On non-WASM targets, files larger than [`MMAP_THRESHOLD_BYTES`] are
/// memory-mapped so that the file contents are never copied to the heap.
/// The mapping is read-only; the file must not be modified while the returned
/// [`FileBytes`] is alive, which is safe for document extraction.
///
/// On WASM or for small files, falls back to a plain `std::fs::read`.
///
/// # Errors
///
/// Returns `KreuzbergError::Io` for any I/O failure.
#[allow(unsafe_code)]
pub(crate) fn open_file_bytes(path: &Path) -> Result<FileBytes> {
#[cfg(not(target_arch = "wasm32"))]
{
let metadata = std::fs::metadata(path).map_err(KreuzbergError::Io)?;
if metadata.len() > MMAP_THRESHOLD_BYTES {
let file = std::fs::File::open(path).map_err(KreuzbergError::Io)?;
// SAFETY: The file is opened read-only and we do not write to the
// mapped region. The `FileBytes` value owns the `Mmap` handle and
// the mapping is live for exactly as long as the bytes are accessed.
// External modification of the file while mapped is a documented
// TOCTOU risk inherent to mmap on all platforms; it is acceptable
// here because kreuzberg only reads user-supplied documents and
// makes no correctness guarantees about files modified concurrently.
let mmap = unsafe { memmap2::Mmap::map(&file) }.map_err(KreuzbergError::Io)?;
return Ok(FileBytes {
inner: FileBytesInner::Mapped(mmap),
});
}
}
// Small file or WASM: regular heap read.
let bytes = std::fs::read(path).map_err(KreuzbergError::Io)?;
Ok(FileBytes {
inner: FileBytesInner::Heap(bytes),
})
}
/// Read a file asynchronously.
///
/// # Arguments
///
/// * `path` - Path to the file to read
///
/// # Returns
///
/// The file contents as bytes.
///
/// # Errors
///
/// Returns `KreuzbergError::Io` for I/O errors (these always bubble up).
#[cfg(feature = "tokio-runtime")]
pub(crate) async fn read_file_async(path: impl AsRef<Path>) -> Result<Vec<u8>> {
tokio::fs::read(path.as_ref()).await.map_err(KreuzbergError::Io)
}
/// Read a file synchronously.
///
/// # Arguments
///
/// * `path` - Path to the file to read
///
/// # Returns
///
/// The file contents as bytes.
///
/// # Errors
///
/// Returns `KreuzbergError::Io` for I/O errors (these always bubble up).
#[cfg(test)]
pub(crate) fn read_file_sync(path: impl AsRef<Path>) -> Result<Vec<u8>> {
std::fs::read(path.as_ref()).map_err(KreuzbergError::Io)
}
/// Check if a file exists.
///
/// # Arguments
///
/// * `path` - Path to check
///
/// # Returns
///
/// `true` if the file exists, `false` otherwise.
pub(crate) fn file_exists(path: impl AsRef<Path>) -> bool {
path.as_ref().exists()
}
/// Validate that a file exists.
///
/// # Arguments
///
/// * `path` - Path to validate
///
/// # Errors
///
/// Returns `KreuzbergError::Io` if file doesn't exist.
pub(crate) fn validate_file_exists(path: impl AsRef<Path>) -> Result<()> {
if !file_exists(&path) {
return Err(KreuzbergError::from(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("File does not exist: {}", path.as_ref().display()),
)));
}
Ok(())
}
/// Traverse a directory and return all file paths matching a pattern.
///
/// # Arguments
///
/// * `dir` - Directory to traverse
/// * `recursive` - Whether to recursively traverse subdirectories
/// * `filter` - Optional filter function to match files
///
/// # Returns
///
/// Vector of file paths that match the criteria.
///
/// # Errors
///
/// Returns `KreuzbergError::Io` for I/O errors.
#[cfg(test)]
pub(crate) fn traverse_directory<F>(
dir: impl AsRef<Path>,
recursive: bool,
filter: Option<F>,
) -> Result<Vec<std::path::PathBuf>>
where
F: Fn(&Path) -> bool,
{
let dir = dir.as_ref();
let mut files = Vec::new();
if !dir.is_dir() {
return Err(KreuzbergError::from(std::io::Error::new(
std::io::ErrorKind::NotADirectory,
format!("Path is not a directory: {}", dir.display()),
)));
}
traverse_directory_impl(dir, recursive, &filter, &mut files)?;
Ok(files)
}
#[cfg(test)]
fn traverse_directory_impl<F>(
dir: &Path,
recursive: bool,
filter: &Option<F>,
files: &mut Vec<std::path::PathBuf>,
) -> Result<()>
where
F: Fn(&Path) -> bool,
{
let entries = std::fs::read_dir(dir).map_err(KreuzbergError::Io)?;
for entry in entries {
let entry = entry.map_err(KreuzbergError::Io)?;
let path = entry.path();
if path.is_file() {
let should_include = match filter {
Some(f) => f(&path),
None => true,
};
if should_include {
files.push(path);
}
} else if path.is_dir() && recursive {
traverse_directory_impl(&path, recursive, filter, files)?;
}
}
Ok(())
}
/// Get all files in a directory with a specific extension.
///
/// # Arguments
///
/// * `dir` - Directory to search
/// * `extension` - File extension to match (without the dot)
/// * `recursive` - Whether to recursively search subdirectories
///
/// # Returns
///
/// Vector of file paths with the specified extension.
///
/// # Errors
///
/// Returns `KreuzbergError::Io` for I/O errors.
#[cfg(test)]
pub(crate) fn find_files_by_extension(
dir: impl AsRef<Path>,
extension: &str,
recursive: bool,
) -> Result<Vec<std::path::PathBuf>> {
let ext = extension.to_lowercase();
traverse_directory(
dir,
recursive,
Some(|path: &Path| {
path.extension()
.and_then(|e| e.to_str())
.map(|e| e.to_lowercase() == ext)
.unwrap_or(false)
}),
)
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs::File;
use std::io::Write;
use tempfile::tempdir;
#[cfg(feature = "tokio-runtime")]
#[tokio::test]
async fn test_read_file_async() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
let mut file = File::create(&file_path).unwrap();
file.write_all(b"test content").unwrap();
let content = read_file_async(&file_path).await.unwrap();
assert_eq!(content, b"test content");
}
#[test]
fn test_read_file_sync() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
let mut file = File::create(&file_path).unwrap();
file.write_all(b"test content").unwrap();
let content = read_file_sync(&file_path).unwrap();
assert_eq!(content, b"test content");
}
#[test]
fn test_file_exists() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
File::create(&file_path).unwrap();
assert!(file_exists(&file_path));
assert!(!file_exists(dir.path().join("nonexistent.txt")));
}
#[test]
fn test_validate_file_exists() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
File::create(&file_path).unwrap();
assert!(validate_file_exists(&file_path).is_ok());
assert!(validate_file_exists(dir.path().join("nonexistent.txt")).is_err());
}
#[test]
fn test_traverse_directory_non_recursive() {
let dir = tempdir().unwrap();
File::create(dir.path().join("file1.txt")).unwrap();
File::create(dir.path().join("file2.pdf")).unwrap();
File::create(dir.path().join("file3.txt")).unwrap();
std::fs::create_dir(dir.path().join("subdir")).unwrap();
File::create(dir.path().join("subdir").join("file4.txt")).unwrap();
let files = traverse_directory(dir.path(), false, None::<fn(&Path) -> bool>).unwrap();
assert_eq!(files.len(), 3);
}
#[test]
fn test_traverse_directory_recursive() {
let dir = tempdir().unwrap();
File::create(dir.path().join("file1.txt")).unwrap();
File::create(dir.path().join("file2.pdf")).unwrap();
std::fs::create_dir(dir.path().join("subdir")).unwrap();
File::create(dir.path().join("subdir").join("file3.txt")).unwrap();
File::create(dir.path().join("subdir").join("file4.pdf")).unwrap();
let files = traverse_directory(dir.path(), true, None::<fn(&Path) -> bool>).unwrap();
assert_eq!(files.len(), 4);
}
#[test]
fn test_traverse_directory_with_filter() {
let dir = tempdir().unwrap();
File::create(dir.path().join("file1.txt")).unwrap();
File::create(dir.path().join("file2.pdf")).unwrap();
File::create(dir.path().join("file3.txt")).unwrap();
let files = traverse_directory(
dir.path(),
false,
Some(|path: &Path| {
path.extension()
.and_then(|e| e.to_str())
.map(|e| e == "txt")
.unwrap_or(false)
}),
)
.unwrap();
assert_eq!(files.len(), 2);
assert!(files.iter().all(|p| p.extension().unwrap() == "txt"));
}
#[test]
fn test_find_files_by_extension() {
let dir = tempdir().unwrap();
File::create(dir.path().join("file1.txt")).unwrap();
File::create(dir.path().join("file2.pdf")).unwrap();
File::create(dir.path().join("file3.TXT")).unwrap();
std::fs::create_dir(dir.path().join("subdir")).unwrap();
File::create(dir.path().join("subdir").join("file4.txt")).unwrap();
let files = find_files_by_extension(dir.path(), "txt", false).unwrap();
assert_eq!(files.len(), 2);
let files_recursive = find_files_by_extension(dir.path(), "txt", true).unwrap();
assert_eq!(files_recursive.len(), 3);
}
#[test]
fn test_traverse_directory_invalid_path() {
let result = traverse_directory("/nonexistent/directory", false, None::<fn(&Path) -> bool>);
assert!(result.is_err());
}
#[test]
fn test_traverse_directory_file_not_dir() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
File::create(&file_path).unwrap();
let result = traverse_directory(&file_path, false, None::<fn(&Path) -> bool>);
assert!(result.is_err());
}
#[cfg(feature = "tokio-runtime")]
#[tokio::test]
async fn test_read_file_async_io_error() {
let result = read_file_async("/nonexistent/file.txt").await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), KreuzbergError::Io(_)));
}
#[test]
fn test_read_file_sync_io_error() {
let result = read_file_sync("/nonexistent/file.txt");
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), KreuzbergError::Io(_)));
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,61 @@
//! Core extraction orchestration module.
//!
//! This module contains the main extraction logic and orchestration layer for Kreuzberg.
//! It provides the primary entry points for file and bytes extraction, manages the
//! extractor registry, MIME type detection, configuration, and post-processing pipeline.
//!
//! # Architecture
//!
//! The core module is responsible for:
//! - **Entry Points**: Main `extract_file()` and `extract_bytes()` functions
//! - **Registry**: Mapping MIME types to extractors with priority-based selection
//! - **MIME Detection**: Detecting and validating MIME types from files and extensions
//! - **Pipeline**: Orchestrating post-processing steps (chunking, quality, etc.)
//! - **Configuration**: Loading and managing extraction configuration
//! - **I/O**: File reading and validation utilities
//!
//! # Example
//!
//! ```rust,no_run
//! use kreuzberg::core::extractor::extract_file;
//! use kreuzberg::core::config::ExtractionConfig;
//!
//! # async fn example() -> kreuzberg::Result<()> {
//! let config = ExtractionConfig::default();
//! let result = extract_file("document.pdf", None, &config).await?;
//! println!("Extracted content: {}", result.content);
//! # Ok(())
//! # }
//! ```
#[cfg(feature = "tokio-runtime")]
pub mod batch_mode;
#[cfg(feature = "tokio-runtime")]
pub mod batch_optimizations;
pub mod config;
pub mod config_validation;
pub mod extractor;
pub mod formats;
pub mod io;
pub mod mime;
pub(crate) mod path_resolver;
pub mod pipeline;
#[cfg(feature = "api-types")]
pub mod server_config;
#[cfg(feature = "pdf")]
pub use config::HierarchyConfig;
pub use config::{
ChunkingConfig, EmbeddingConfig, EmbeddingModelType, ExtractionConfig, ImageExtractionConfig,
LanguageDetectionConfig, OcrConfig, OutputFormat, PageConfig, PostProcessorConfig, TokenReductionOptions,
};
#[cfg(feature = "api-types")]
pub use server_config::ServerConfig;
#[cfg(feature = "tokio-runtime")]
pub use batch_optimizations::{BatchProcessor, BatchProcessorConfig};
#[cfg(feature = "pdf")]
pub use config::PdfConfig;
#[cfg(feature = "tokio-runtime")]
pub use extractor::{batch_extract_bytes, batch_extract_files};
pub use extractor::{extract_bytes, extract_file};

View File

@@ -0,0 +1,282 @@
//! Image path resolution for markup extractors.
//!
//! Resolves relative image paths found in markup documents (Markdown, LaTeX, RST,
//! Org-mode, Typst, Djot, DocBook) to actual filesystem paths, reads the image data,
//! and attaches them to the extraction result.
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use bytes::Bytes;
use crate::core::config::ExtractionConfig;
use crate::types::ExtractedImage;
use crate::types::internal::InternalDocument;
use crate::types::uri::UriKind;
/// Maximum image file size: 50 MB.
const MAX_IMAGE_SIZE: u64 = 50 * 1024 * 1024;
/// Resolve a relative image reference against a base directory.
///
/// Returns `None` for URLs, absolute paths, and paths that escape `base_dir`
/// via traversal (`..`). Returns `Some(resolved)` for safe relative paths.
///
/// This function performs no filesystem access — it only validates the
/// structural safety of the path.
pub(crate) fn resolve_image_path(base_dir: &Path, image_ref: &str) -> Option<PathBuf> {
let trimmed = image_ref.trim();
// Reject URLs
if trimmed.starts_with("http://")
|| trimmed.starts_with("https://")
|| trimmed.starts_with("data:")
|| trimmed.starts_with("ftp://")
|| trimmed.starts_with("mailto:")
{
return None;
}
// Strip file:// or file: prefix (org-mode uses file: without //)
let path_str = if let Some(stripped) = trimmed.strip_prefix("file://") {
stripped
} else if let Some(stripped) = trimmed.strip_prefix("file:") {
stripped
} else {
trimmed
};
// Reject absolute paths (Unix or Windows drive letter)
if path_str.starts_with('/')
|| (path_str.len() >= 2 && path_str.as_bytes()[0].is_ascii_alphabetic() && path_str.as_bytes()[1] == b':')
{
return None;
}
let joined = base_dir.join(path_str);
let normalized = normalize_path(&joined);
// Path traversal prevention: resolved path must start with base_dir
let norm_base = normalize_path(base_dir);
if !normalized.starts_with(&norm_base) {
return None;
}
Some(normalized)
}
/// Read an image file and produce an `ExtractedImage`.
///
/// Returns `None` if the file does not exist, is not a regular file,
/// exceeds the size limit, or has an unrecognised extension.
pub(crate) fn read_image_file(path: &Path, image_index: u32) -> Option<ExtractedImage> {
let meta = std::fs::metadata(path).ok()?;
if !meta.is_file() {
return None;
}
if meta.len() > MAX_IMAGE_SIZE {
return None;
}
let ext = path
.extension()
.and_then(|e| e.to_str())
.map(|s| s.to_ascii_lowercase())?;
let format: Cow<'static, str> = match ext.as_str() {
"png" => Cow::Borrowed("png"),
"jpg" | "jpeg" => Cow::Borrowed("jpeg"),
"gif" => Cow::Borrowed("gif"),
"webp" => Cow::Borrowed("webp"),
"svg" => Cow::Borrowed("svg"),
"bmp" => Cow::Borrowed("bmp"),
"tiff" | "tif" => Cow::Borrowed("tiff"),
"avif" => Cow::Borrowed("avif"),
_ => return None,
};
let data = std::fs::read(path).ok()?;
let source_path = path.to_string_lossy().into_owned();
Some(ExtractedImage {
data: Bytes::from(data),
format,
image_index,
page_number: None,
width: None,
height: None,
colorspace: None,
bits_per_component: None,
is_mask: false,
description: None,
ocr_result: None,
bounding_box: None,
source_path: Some(source_path),
image_kind: None,
kind_confidence: None,
cluster_id: None,
})
}
/// Resolve image URIs in an `InternalDocument` to actual image data.
///
/// Iterates over all `UriKind::Image` entries, resolves them relative to
/// `base_dir`, reads the file, and appends the result to `doc.images`.
/// No-op when image extraction is disabled in `config`.
pub(crate) fn resolve_image_uris(doc: &mut InternalDocument, base_dir: &Path, config: &ExtractionConfig) {
let image_extraction_enabled = config.images.as_ref().is_some_and(|img| img.extract_images);
if !image_extraction_enabled {
return;
}
let mut image_index = doc.images.len() as u32;
// Collect URI indices first to avoid borrow conflict (doc.uris vs doc.images).
let image_uri_indices: Vec<usize> = doc
.uris
.iter()
.enumerate()
.filter(|(_, uri)| uri.kind == UriKind::Image)
.map(|(i, _)| i)
.collect();
for idx in image_uri_indices {
if let Some(resolved) = resolve_image_path(base_dir, &doc.uris[idx].url)
&& let Some(img) = read_image_file(&resolved, image_index)
{
doc.images.push(img);
image_index += 1;
}
}
}
/// Read a file, extract via `extract_bytes`, then resolve image URIs.
///
/// Shared helper for markup extractors (Markdown, LaTeX, RST, Org-mode, Typst,
/// Djot, DocBook, MDX) that need to resolve relative image paths after extraction.
pub(crate) async fn extract_file_with_image_resolution(
extractor: &(dyn crate::plugins::DocumentExtractor + Sync),
path: &Path,
mime_type: &str,
config: &ExtractionConfig,
) -> crate::Result<InternalDocument> {
let bytes = crate::core::io::open_file_bytes(path)?;
let mut doc = extractor.extract_bytes(&bytes, mime_type, config).await?;
if let Some(base_dir) = path.parent() {
resolve_image_uris(&mut doc, base_dir, config);
}
Ok(doc)
}
/// Normalize a path by resolving `.` and `..` components without filesystem access.
fn normalize_path(path: &Path) -> PathBuf {
let mut components = Vec::new();
for component in path.components() {
match component {
std::path::Component::ParentDir => {
components.pop();
}
std::path::Component::CurDir => {}
c => components.push(c),
}
}
components.iter().collect()
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::Path;
#[test]
fn test_resolve_relative_path() {
let base = Path::new("/home/user/docs");
let result = resolve_image_path(base, "images/photo.png");
assert_eq!(result, Some(PathBuf::from("/home/user/docs/images/photo.png")));
}
#[test]
fn test_resolve_nested_relative() {
let base = Path::new("/project/content");
let result = resolve_image_path(base, "images/subfolder/nested.png");
assert_eq!(
result,
Some(PathBuf::from("/project/content/images/subfolder/nested.png"))
);
}
#[test]
fn test_reject_absolute_path() {
let base = Path::new("/home/user/docs");
assert_eq!(resolve_image_path(base, "/etc/passwd"), None);
}
#[test]
fn test_reject_traversal() {
let base = Path::new("/home/user/docs");
assert_eq!(resolve_image_path(base, "../../etc/passwd"), None);
}
#[test]
fn test_reject_http_url() {
let base = Path::new("/home/user/docs");
assert_eq!(resolve_image_path(base, "https://example.com/img.png"), None);
}
#[test]
fn test_reject_data_uri() {
let base = Path::new("/home/user/docs");
assert_eq!(resolve_image_path(base, "data:image/png;base64,abc"), None);
}
#[test]
fn test_nonexistent_file_still_resolves() {
// resolve_image_path only checks structure, not filesystem
let base = Path::new("/nonexistent/base");
let result = resolve_image_path(base, "sub/image.jpg");
assert_eq!(result, Some(PathBuf::from("/nonexistent/base/sub/image.jpg")));
}
#[test]
fn test_path_with_spaces() {
let base = Path::new("/home/user/my docs");
let result = resolve_image_path(base, "my images/photo.png");
assert_eq!(result, Some(PathBuf::from("/home/user/my docs/my images/photo.png")));
}
#[test]
fn test_windows_backslash() {
// On all platforms, std::path::Path::join handles separators correctly.
// On Unix, backslash is a valid filename char so it stays as-is in the component.
// The key point: the function does not panic and produces a usable path.
let base = Path::new("/home/user/docs");
let result = resolve_image_path(base, "images/photo.png");
assert!(result.is_some());
}
#[test]
fn test_reject_ftp_url() {
let base = Path::new("/home/user/docs");
assert_eq!(resolve_image_path(base, "ftp://server/img.png"), None);
}
#[test]
fn test_reject_mailto() {
let base = Path::new("/home/user/docs");
assert_eq!(resolve_image_path(base, "mailto:user@example.com"), None);
}
#[test]
fn test_file_uri_stripped() {
let base = Path::new("/home/user/docs");
let result = resolve_image_path(base, "file://images/photo.png");
assert_eq!(result, Some(PathBuf::from("/home/user/docs/images/photo.png")));
}
#[test]
fn test_reject_windows_absolute() {
let base = Path::new("/home/user/docs");
assert_eq!(resolve_image_path(base, "C:\\Windows\\img.png"), None);
}
}

View File

@@ -0,0 +1,46 @@
//! Processor caching to reduce lock contention.
//!
//! This module manages the caching of post-processors by processing stage,
//! eliminating repeated registry lock acquisitions.
use crate::Result;
use crate::plugins::{PostProcessor, ProcessingStage};
use parking_lot::RwLock;
use std::sync::Arc;
use std::sync::LazyLock;
/// Cached post-processors for each stage to reduce lock contention.
///
/// This cache is populated once during the first pipeline run and reused
/// for all subsequent extractions, eliminating 3 of 4 registry lock acquisitions
/// per extraction.
pub(super) struct ProcessorCache {
pub(super) early: Arc<Vec<Arc<dyn PostProcessor>>>,
pub(super) middle: Arc<Vec<Arc<dyn PostProcessor>>>,
pub(super) late: Arc<Vec<Arc<dyn PostProcessor>>>,
}
impl ProcessorCache {
/// Create a new processor cache by fetching from the registry.
pub(super) fn new() -> Result<Self> {
let processor_registry = crate::plugins::registry::get_post_processor_registry();
let registry = processor_registry.read();
Ok(Self {
early: Arc::new(registry.get_for_stage(ProcessingStage::Early)),
middle: Arc::new(registry.get_for_stage(ProcessingStage::Middle)),
late: Arc::new(registry.get_for_stage(ProcessingStage::Late)),
})
}
}
/// Lazy processor cache - initialized on first use, then cached.
pub(super) static PROCESSOR_CACHE: LazyLock<RwLock<Option<ProcessorCache>>> = LazyLock::new(|| RwLock::new(None));
/// Clear the processor cache (primarily for testing when registry changes).
#[cfg_attr(alef, alef(skip))]
pub fn clear_processor_cache() -> Result<()> {
let mut cache = PROCESSOR_CACHE.write();
*cache = None;
Ok(())
}

View File

@@ -0,0 +1,128 @@
//! Core processor execution logic.
//!
//! This module handles the execution of post-processors and validators
//! in the correct order.
use crate::core::config::ExtractionConfig;
use crate::plugins::ProcessingStage;
use crate::types::{ExtractionResult, ProcessingWarning};
use crate::{KreuzbergError, Result};
use std::borrow::Cow;
#[cfg(feature = "otel")]
use std::time::Instant;
#[cfg(feature = "otel")]
use tracing::Instrument;
/// Execute all registered post-processors by stage.
pub(super) async fn execute_processors(
result: &mut ExtractionResult,
config: &ExtractionConfig,
pp_config: &Option<&crate::core::config::PostProcessorConfig>,
early_processors: std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
middle_processors: std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
late_processors: std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
) -> Result<()> {
for (_stage, processors_arc) in [
(ProcessingStage::Early, early_processors),
(ProcessingStage::Middle, middle_processors),
(ProcessingStage::Late, late_processors),
] {
#[cfg(feature = "otel")]
let stage_name = match _stage {
ProcessingStage::Early => crate::telemetry::conventions::stages::POST_PROCESSING_EARLY,
ProcessingStage::Middle => crate::telemetry::conventions::stages::POST_PROCESSING_MIDDLE,
ProcessingStage::Late => crate::telemetry::conventions::stages::POST_PROCESSING_LATE,
};
#[cfg(feature = "otel")]
let stage_span = crate::telemetry::spans::pipeline_stage_span(stage_name);
#[cfg(feature = "otel")]
let stage_start = Instant::now();
#[cfg(feature = "otel")]
let _stage_guard = stage_span.enter();
for processor in processors_arc.iter() {
let processor_name = processor.name();
let should_run = should_processor_run(pp_config, processor_name);
if should_run && processor.should_process(result, config) {
#[cfg(feature = "otel")]
let processor_span = crate::telemetry::spans::pipeline_processor_span(stage_name, processor_name);
#[cfg(feature = "otel")]
let process_result = processor.process(result, config).instrument(processor_span).await;
#[cfg(not(feature = "otel"))]
let process_result = processor.process(result, config).await;
match process_result {
Ok(_) => {}
Err(err @ KreuzbergError::Io(_))
| Err(err @ KreuzbergError::LockPoisoned(_))
| Err(err @ KreuzbergError::Plugin { .. }) => {
return Err(err);
}
Err(err) => {
let error_msg = err.to_string();
result.processing_warnings.push(ProcessingWarning {
source: Cow::Owned(processor_name.to_string()),
message: Cow::Owned(error_msg),
});
}
}
}
}
#[cfg(feature = "otel")]
{
let stage_ms = stage_start.elapsed().as_secs_f64() * 1000.0;
crate::telemetry::metrics::get_metrics().pipeline_duration_ms.record(
stage_ms,
&[opentelemetry::KeyValue::new(
crate::telemetry::conventions::PIPELINE_STAGE,
stage_name.to_string(),
)],
);
drop(_stage_guard);
drop(stage_span);
}
}
Ok(())
}
/// Determine if a processor should run based on configuration.
fn should_processor_run(pp_config: &Option<&crate::core::config::PostProcessorConfig>, processor_name: &str) -> bool {
if let Some(config) = pp_config {
if let Some(ref enabled_set) = config.enabled_set {
enabled_set.contains(processor_name)
} else if let Some(ref disabled_set) = config.disabled_set {
!disabled_set.contains(processor_name)
} else if let Some(ref enabled) = config.enabled_processors {
enabled.iter().any(|name| name == processor_name)
} else if let Some(ref disabled) = config.disabled_processors {
!disabled.iter().any(|name| name == processor_name)
} else {
true
}
} else {
true
}
}
/// Execute all registered validators.
pub(super) async fn execute_validators(result: &ExtractionResult, config: &ExtractionConfig) -> Result<()> {
let validator_registry = crate::plugins::registry::get_validator_registry();
let validators = {
let registry = validator_registry.read();
registry.get_all()
};
if !validators.is_empty() {
for validator in validators {
if validator.should_validate(result, config) {
validator.validate(result, config).await?;
}
}
}
Ok(())
}

View File

@@ -0,0 +1,461 @@
//! Feature processing logic.
//!
//! This module handles feature-specific processing like chunking,
//! embedding generation, and language detection.
use crate::Result;
use crate::core::config::ExtractionConfig;
#[cfg(feature = "chunking")]
use crate::types::PageBoundary;
use crate::types::{ExtractionResult, ProcessingWarning};
use std::borrow::Cow;
/// Recompute page boundaries against the rendered `content` string.
///
/// `PageBoundary` offsets produced during extraction are computed against raw
/// rendered/source text, but `result.content` is produced by `render_plain` which
/// trims trailing whitespace from each paragraph. The raw page text therefore has
/// different byte lengths for pages that contain trailing-space artifacts from PDF
/// rendering. This function re-derives the boundaries by locating each page's
/// **paragraph-normalised** content (each `"\n\n"`-separated segment trimmed, then
/// re-joined) inside the combined `content` string, so that the byte offsets passed
/// to the chunker are valid indices into `result.content`.
///
/// Pages whose content cannot be found are silently skipped (the chunker will
/// still produce output, just without page-range metadata for those pages).
#[cfg(feature = "chunking")]
fn recompute_boundaries_from_pages(content: &str, pages: &[crate::types::PageContent]) -> Vec<PageBoundary> {
let mut boundaries = Vec::with_capacity(pages.len());
let mut search_offset = 0usize;
for page in pages {
if page.content.trim().is_empty() {
boundaries.push(PageBoundary {
page_number: page.page_number,
byte_start: search_offset,
byte_end: search_offset,
});
continue;
}
// Normalise page content to match what render_plain produces: split on the
// paragraph separator, trim each segment (PDF pages often carry trailing
// spaces before "\n\n" that render_plain strips via paragraph.trim()), then
// re-join. Using the normalised form means exact-match succeeds and the
// resulting byte_end is correct — avoiding cascading search_offset
// over-advance that would push past subsequent pages.
let normalized: String = page
.content
.split("\n\n")
.map(str::trim)
.filter(|s| !s.is_empty())
.collect::<Vec<_>>()
.join("\n\n");
// Try normalised-exact match (primary path — handles trailing-space pages).
if let Some(pos) = content[search_offset..].find(normalized.as_str()) {
let byte_start = search_offset + pos;
let byte_end = content.floor_char_boundary(byte_start + normalized.len());
boundaries.push(PageBoundary {
page_number: page.page_number,
byte_start,
byte_end,
});
search_offset = byte_end;
continue;
}
// Fallback: search for first non-empty line of page content.
// Use normalized.len() for byte_end so search_offset advances correctly.
if let Some(line) = page.content.lines().find(|l| !l.trim().is_empty()).map(|l| l.trim())
&& let Some(pos) = content[search_offset..].find(line)
{
let byte_start = search_offset + pos;
let raw_end = (byte_start + normalized.len()).min(content.len());
let byte_end = content.floor_char_boundary(raw_end);
boundaries.push(PageBoundary {
page_number: page.page_number,
byte_start,
byte_end,
});
search_offset = byte_end;
continue;
}
// Last resort: skip this page
tracing::debug!(
page = page.page_number,
"Could not locate page content in rendered text — skipping page boundary"
);
}
boundaries
}
/// Map TSLP `CodeChunk`s directly to kreuzberg `Chunk`s, bypassing text-splitter.
///
/// When the extraction result contains code intelligence with non-empty chunks,
/// those chunks already represent semantically meaningful code boundaries produced
/// by tree-sitter. Using text-splitter would break these boundaries.
#[cfg(feature = "tree-sitter")]
fn try_code_chunks(result: &ExtractionResult) -> Option<Vec<crate::types::extraction::Chunk>> {
use crate::types::extraction::{Chunk, ChunkMetadata, ChunkType, HeadingContext, HeadingLevel};
let code_chunks = match &result.metadata.format {
Some(crate::types::metadata::FormatMetadata::Code(pr)) if !pr.chunks.is_empty() => &pr.chunks,
_ => return None,
};
let total_chunks = code_chunks.len();
let chunks: Vec<Chunk> = code_chunks
.iter()
.enumerate()
.map(|(i, cc)| {
// All code chunks are classified as CodeBlock regardless of node type.
let chunk_type = ChunkType::CodeBlock;
// Build heading context from context_path.
let heading_context = if cc.metadata.context_path.is_empty() {
None
} else {
Some(HeadingContext {
headings: cc
.metadata
.context_path
.iter()
.enumerate()
.map(|(depth, name)| HeadingLevel {
level: (depth as u8).saturating_add(2).min(6),
text: name.clone(),
})
.collect(),
})
};
Chunk {
content: cc.content.clone(),
chunk_type,
embedding: None,
metadata: ChunkMetadata {
byte_start: cc.start_byte,
byte_end: cc.end_byte,
token_count: None,
chunk_index: i,
total_chunks,
first_page: None,
last_page: None,
heading_context,
image_indices: Vec::new(),
},
}
})
.collect();
Some(chunks)
}
/// Execute chunking if configured.
pub(super) fn execute_chunking(result: &mut ExtractionResult, config: &ExtractionConfig) -> Result<()> {
#[cfg(feature = "chunking")]
if let Some(ref chunking_config) = config.chunking {
// For code extractions with TSLP chunks, bypass text-splitter and map directly.
#[cfg(feature = "tree-sitter")]
if let Some(code_chunks) = try_code_chunks(result) {
result.chunks = Some(code_chunks);
let resolved_config = chunking_config.resolve_preset();
#[cfg(feature = "embeddings")]
if let Some(ref embedding_config) = resolved_config.embedding
&& let Some(ref mut chunks) = result.chunks
&& let Err(e) = crate::embeddings::generate_embeddings_for_chunks(chunks, embedding_config)
{
tracing::warn!("Embedding generation failed: {e}. Check that ONNX Runtime is installed.");
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("embedding"),
message: Cow::Owned(e.to_string()),
});
}
#[cfg(not(feature = "embeddings"))]
if resolved_config.embedding.is_some() {
tracing::warn!(
"Embedding config provided but embeddings feature is not enabled. Recompile with --features embeddings."
);
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("embedding"),
message: Cow::Borrowed("Embeddings feature not enabled"),
});
}
return Ok(());
}
let resolved_config = chunking_config.resolve_preset();
let chunking_config = &resolved_config;
// Recompute page boundaries against `result.content` (rendered by `render_plain`)
// if per-page content is available. The boundaries stored in
// `result.metadata.pages.boundaries` were computed against the raw extractor text
// and may have different byte offsets than the rendered content.
let recomputed_boundaries: Option<Vec<PageBoundary>> = result
.pages
.as_deref()
.map(|pages| recompute_boundaries_from_pages(&result.content, pages));
let page_boundaries: Option<&[PageBoundary]> = recomputed_boundaries
.as_deref()
.filter(|s| !s.is_empty())
.or_else(|| result.metadata.pages.as_ref().and_then(|ps| ps.boundaries.as_deref()));
// Pass formatted_content (markdown) for heading context resolution when available.
// Plain-text rendering strips heading markers, but the markdown chunker needs them
// to build the heading hierarchy for chunk metadata.
let heading_source = result.formatted_content.as_deref();
match crate::chunking::chunk_text_with_heading_source(
&result.content,
chunking_config,
page_boundaries,
heading_source,
) {
Ok(chunking_result) => {
result.chunks = Some(chunking_result.chunks);
// Populate image_indices on each chunk: collect indices of images whose
// page_number falls within the chunk's [first_page, last_page] range.
if let Some(ref images) = result.images
&& let Some(ref mut chunks) = result.chunks
{
for chunk in chunks.iter_mut() {
if let (Some(first), Some(last)) = (chunk.metadata.first_page, chunk.metadata.last_page) {
chunk.metadata.image_indices = images
.iter()
.enumerate()
.filter_map(|(idx, img)| {
let pg = img.page_number?;
(pg >= first && pg <= last).then_some(idx as u32)
})
.collect();
}
}
}
#[cfg(feature = "embeddings")]
if let Some(ref embedding_config) = chunking_config.embedding
&& let Some(ref mut chunks) = result.chunks
&& let Err(e) = crate::embeddings::generate_embeddings_for_chunks(chunks, embedding_config)
{
tracing::warn!("Embedding generation failed: {e}. Check that ONNX Runtime is installed.");
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("embedding"),
message: Cow::Owned(e.to_string()),
});
}
#[cfg(not(feature = "embeddings"))]
if chunking_config.embedding.is_some() {
tracing::warn!(
"Embedding config provided but embeddings feature is not enabled. Recompile with --features embeddings."
);
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("embedding"),
message: Cow::Borrowed("Embeddings feature not enabled"),
});
}
}
Err(e) => {
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("chunking"),
message: Cow::Owned(e.to_string()),
});
}
}
}
#[cfg(not(feature = "chunking"))]
if config.chunking.is_some() {
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("chunking"),
message: Cow::Borrowed("Chunking feature not enabled"),
});
}
Ok(())
}
/// Execute language detection if configured.
pub(super) fn execute_language_detection(result: &mut ExtractionResult, config: &ExtractionConfig) -> Result<()> {
#[cfg(feature = "language-detection")]
if let Some(ref lang_config) = config.language_detection {
match crate::language_detection::detect_languages(&result.content, lang_config) {
Ok(detected) => {
result.detected_languages = detected;
}
Err(e) => {
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("language_detection"),
message: Cow::Owned(e.to_string()),
});
}
}
}
#[cfg(not(feature = "language-detection"))]
if config.language_detection.is_some() {
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("language_detection"),
message: Cow::Borrowed("Language detection feature not enabled"),
});
}
Ok(())
}
/// Execute token reduction if configured.
pub(super) fn execute_token_reduction(result: &mut ExtractionResult, config: &ExtractionConfig) -> Result<()> {
#[cfg(feature = "quality")]
if let Some(ref tr_config) = config.token_reduction {
let level = crate::text::token_reduction::ReductionLevel::from(tr_config.mode.as_str());
if !matches!(level, crate::text::token_reduction::ReductionLevel::Off) {
let impl_config = crate::text::token_reduction::TokenReductionConfig {
level,
..Default::default()
};
let lang_hint: Option<&str> = result
.detected_languages
.as_deref()
.and_then(|langs| langs.first().map(|s| s.as_str()));
match crate::text::token_reduction::reduce_tokens(&result.content, &impl_config, lang_hint) {
Ok(reduced) => {
result.content = reduced;
}
Err(e) => {
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("token_reduction"),
message: Cow::Owned(e.to_string()),
});
}
}
}
}
#[cfg(not(feature = "quality"))]
if config.token_reduction.is_some() {
result.processing_warnings.push(ProcessingWarning {
source: Cow::Borrowed("token_reduction"),
message: Cow::Borrowed("Token reduction requires the quality feature"),
});
}
Ok(())
}
#[cfg(all(test, feature = "chunking"))]
mod tests {
use super::*;
use crate::types::PageContent;
fn make_page(page_number: u32, content: impl Into<String>) -> PageContent {
PageContent {
page_number,
content: content.into(),
tables: vec![],
image_indices: vec![],
hierarchy: None,
is_blank: None,
layout_regions: None,
section_name: None,
speaker_notes: None,
sheet_name: None,
}
}
// When PageContent.content matches result.content exactly, all boundaries succeed.
#[test]
fn recompute_boundaries_exact_match_produces_full_boundary_set() {
let p1 = "Hello world";
let p2 = "Second page text";
let p3 = "Third page here";
let content = format!("{p1}\n\n{p2}\n\n{p3}");
let pages = vec![make_page(1, p1), make_page(2, p2), make_page(3, p3)];
let boundaries = recompute_boundaries_from_pages(&content, &pages);
assert_eq!(boundaries.len(), 3, "all pages should resolve to boundaries");
assert_eq!(&content[boundaries[0].byte_start..boundaries[0].byte_end], p1);
assert_eq!(&content[boundaries[1].byte_start..boundaries[1].byte_end], p2);
assert_eq!(&content[boundaries[2].byte_start..boundaries[2].byte_end], p3);
}
// When PageContent.content is raw (control char present) but result.content has the
// cleaned version, the affected page is silently skipped — leaving fewer boundaries
// than pages. Documents the pre-fix failure mode.
#[test]
fn recompute_boundaries_raw_content_causes_skipped_pages() {
// U+0001 between word chars → fix_pdf_control_chars replaces with '-'
let p1_clean = "Hello world";
let p2_raw = "ab\x01cd"; // raw page text — control char present
let p2_clean = "ab-cd"; // what result.content contains after cleanup
let p3_clean = "Third page";
let content = format!("{p1_clean}\n\n{p2_clean}\n\n{p3_clean}");
// Pre-fix scenario: page.content = raw, result.content = cleaned → mismatch
let pages = vec![
make_page(1, p1_clean),
make_page(2, p2_raw), // intentionally stale raw content
make_page(3, p3_clean),
];
let boundaries = recompute_boundaries_from_pages(&content, &pages);
// Page 2 is skipped: neither exact nor first-line search finds "ab\x01cd"
// inside content (which has "ab-cd"). Only pages 1 and 3 resolve.
assert_eq!(boundaries.len(), 2, "page with raw/cleaned mismatch should be skipped");
assert_eq!(boundaries[0].page_number, 1);
assert_eq!(boundaries[1].page_number, 3);
}
// When PageContent.content is the cleaned text (the fix), all pages resolve.
#[test]
fn recompute_boundaries_cleaned_content_resolves_all_pages() {
let p1_clean = "Hello world";
let p2_clean = "ab-cd"; // cleaned — matches result.content exactly
let p3_clean = "Third page";
let content = format!("{p1_clean}\n\n{p2_clean}\n\n{p3_clean}");
// Post-fix scenario: page.content = cleaned, result.content = cleaned → exact match
let pages = vec![make_page(1, p1_clean), make_page(2, p2_clean), make_page(3, p3_clean)];
let boundaries = recompute_boundaries_from_pages(&content, &pages);
assert_eq!(boundaries.len(), 3, "all pages should resolve after fix");
assert_eq!(&content[boundaries[1].byte_start..boundaries[1].byte_end], p2_clean);
}
// PDF pages often have trailing spaces before "\n\n" paragraph separators (PDF
// rendering artifact). render_plain trims each paragraph via paragraph.trim(),
// so result.content lacks those trailing spaces while page.content retains them.
// The normalised-exact match must succeed and produce correct byte_end so that
// subsequent pages are found without cascading search_offset over-advance.
#[test]
fn recompute_boundaries_trailing_space_pages_all_resolve() {
// Simulate PDF page content with trailing spaces before "\n\n".
let p1_raw = "Heading \n\nBody paragraph one. ";
let p2_raw = "Second heading \n\nBody paragraph two. ";
let p3_raw = "Conclusion. ";
// result.content as render_plain produces it (each paragraph trimmed).
let p1_norm = "Heading\n\nBody paragraph one.";
let p2_norm = "Second heading\n\nBody paragraph two.";
let p3_norm = "Conclusion.";
let content = format!("{p1_norm}\n\n{p2_norm}\n\n{p3_norm}");
let pages = vec![make_page(1, p1_raw), make_page(2, p2_raw), make_page(3, p3_raw)];
let boundaries = recompute_boundaries_from_pages(&content, &pages);
assert_eq!(boundaries.len(), 3, "all pages must resolve despite trailing spaces");
assert_eq!(&content[boundaries[0].byte_start..boundaries[0].byte_end], p1_norm);
assert_eq!(&content[boundaries[1].byte_start..boundaries[1].byte_end], p2_norm);
assert_eq!(&content[boundaries[2].byte_start..boundaries[2].byte_end], p3_norm);
}
}

View File

@@ -0,0 +1,207 @@
//! Output format conversion for extraction results.
//!
//! This module handles the final step of output format application: swapping
//! pre-rendered content into the result and recording format metadata.
//!
//! The heavy rendering work (Markdown, Djot, HTML) is now done earlier in the
//! pipeline inside `derive_extraction_result`, which populates
//! `ExtractionResult::formatted_content`. This function simply swaps that
//! pre-rendered content into the `content` field after post-processors have
//! operated on the plain-text version.
use crate::core::config::OutputFormat;
use crate::types::ExtractionResult;
#[cfg(test)]
use std::borrow::Cow;
/// Apply output format conversion to the extraction result.
///
/// Records the output format in metadata and swaps in pre-rendered content
/// (produced during `derive_extraction_result`) if available.
///
/// This runs as the final pipeline step, after post-processors have operated
/// on the plain-text `content` field.
///
/// # Arguments
///
/// * `result` - The extraction result to modify
/// * `output_format` - The desired output format
#[cfg_attr(alef, alef(skip))]
pub fn apply_output_format(result: ExtractionResult, output_format: OutputFormat) -> ExtractionResult {
let mut result = result;
let format_name = match output_format {
OutputFormat::Plain => "plain",
OutputFormat::Markdown => "markdown",
OutputFormat::Djot => "djot",
OutputFormat::Html => "html",
OutputFormat::Json => "json",
OutputFormat::Structured => "structured",
OutputFormat::Custom(ref name) => name.as_str(),
};
result.metadata.output_format = Some(format_name.to_string());
// Swap in pre-rendered content if available (populated by derive_extraction_result).
if let Some(formatted) = result.formatted_content.take() {
result.content = formatted;
}
result
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::Metadata;
#[test]
fn test_apply_output_format_plain() {
let result = ExtractionResult {
content: "Hello World".to_string(),
mime_type: Cow::Borrowed("text/plain"),
..Default::default()
};
let result = apply_output_format(result, OutputFormat::Plain);
assert_eq!(result.content, "Hello World");
assert_eq!(result.metadata.output_format, Some("plain".to_string()));
}
#[test]
fn test_apply_output_format_markdown_no_prerender() {
let result = ExtractionResult {
content: "Hello World".to_string(),
mime_type: Cow::Borrowed("text/plain"),
..Default::default()
};
let result = apply_output_format(result, OutputFormat::Markdown);
// Without pre-rendered content, content stays as-is
assert_eq!(result.content, "Hello World");
assert_eq!(result.metadata.output_format, Some("markdown".to_string()));
}
#[test]
fn test_apply_output_format_swaps_formatted_content() {
let result = ExtractionResult {
content: "plain text".to_string(),
mime_type: Cow::Borrowed("text/plain"),
formatted_content: Some("# Heading\n\nFormatted markdown".to_string()),
..Default::default()
};
let result = apply_output_format(result, OutputFormat::Markdown);
assert_eq!(result.content, "# Heading\n\nFormatted markdown");
assert!(result.formatted_content.is_none(), "formatted_content should be taken");
assert_eq!(result.metadata.output_format, Some("markdown".to_string()));
}
#[test]
fn test_apply_output_format_html_with_prerender() {
let result = ExtractionResult {
content: "plain text".to_string(),
mime_type: Cow::Borrowed("text/plain"),
formatted_content: Some("<p>Hello World</p>".to_string()),
..Default::default()
};
let result = apply_output_format(result, OutputFormat::Html);
assert_eq!(result.content, "<p>Hello World</p>");
assert_eq!(result.metadata.output_format, Some("html".to_string()));
}
#[test]
fn test_apply_output_format_djot_with_prerender() {
let result = ExtractionResult {
content: "plain text".to_string(),
mime_type: Cow::Borrowed("text/plain"),
formatted_content: Some("# Djot heading".to_string()),
..Default::default()
};
let result = apply_output_format(result, OutputFormat::Djot);
assert_eq!(result.content, "# Djot heading");
assert_eq!(result.metadata.output_format, Some("djot".to_string()));
}
#[test]
fn test_apply_output_format_structured() {
let result = ExtractionResult {
content: "Hello World".to_string(),
mime_type: Cow::Borrowed("text/plain"),
..Default::default()
};
let result = apply_output_format(result, OutputFormat::Structured);
assert_eq!(result.content, "Hello World");
assert_eq!(result.metadata.output_format, Some("structured".to_string()));
}
#[test]
fn test_apply_output_format_preserves_metadata() {
use ahash::AHashMap;
let mut additional = AHashMap::new();
additional.insert(Cow::Borrowed("custom_key"), serde_json::json!("custom_value"));
let metadata = Metadata {
title: Some("Test Title".to_string()),
additional,
..Default::default()
};
let result = ExtractionResult {
content: "Hello World".to_string(),
mime_type: Cow::Borrowed("text/plain"),
metadata,
..Default::default()
};
let result = apply_output_format(result, OutputFormat::Markdown);
assert_eq!(result.metadata.title, Some("Test Title".to_string()));
assert_eq!(
result.metadata.additional.get("custom_key"),
Some(&serde_json::json!("custom_value"))
);
}
#[test]
fn test_apply_output_format_preserves_tables() {
use crate::types::Table;
let table = Table {
cells: vec![vec!["A".to_string(), "B".to_string()]],
markdown: "| A | B |".to_string(),
page_number: 1,
bounding_box: None,
};
let result = ExtractionResult {
content: "Hello World".to_string(),
mime_type: Cow::Borrowed("text/plain"),
tables: vec![table],
..Default::default()
};
let result = apply_output_format(result, OutputFormat::Html);
assert_eq!(result.tables.len(), 1);
assert_eq!(result.tables[0].cells[0][0], "A");
}
#[test]
fn test_apply_output_format_sets_typed_field() {
let result = ExtractionResult {
content: "test".to_string(),
mime_type: Cow::Borrowed("text/plain"),
..Default::default()
};
let result = apply_output_format(result, OutputFormat::Djot);
assert_eq!(result.metadata.output_format, Some("djot".to_string()));
}
}

View File

@@ -0,0 +1,67 @@
//! Pipeline initialization and setup logic.
//!
//! This module handles the initialization of features and processor cache
//! required for pipeline execution.
use crate::Result;
#[cfg(feature = "quality")]
use std::sync::OnceLock;
use super::cache::{PROCESSOR_CACHE, ProcessorCache};
/// Type alias for processor stages tuple (Early, Middle, Late).
type ProcessorStages = (
std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
);
/// Initialize feature-specific systems that may be needed during pipeline execution.
pub(super) fn initialize_features() {
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
{
let _ = crate::keywords::ensure_initialized();
}
#[cfg(feature = "language-detection")]
{
let _ = crate::language_detection::ensure_initialized();
}
#[cfg(feature = "chunking")]
{
let _ = crate::chunking::ensure_initialized();
}
#[cfg(feature = "quality")]
{
static QUALITY_INIT: OnceLock<()> = OnceLock::new();
QUALITY_INIT.get_or_init(|| {
let registry = crate::plugins::registry::get_post_processor_registry();
let mut reg = registry.write();
let _ = reg.register(std::sync::Arc::new(crate::text::QualityProcessor));
});
}
}
/// Initialize the processor cache if not already initialized.
pub(super) fn initialize_processor_cache() -> Result<()> {
let mut cache_lock = PROCESSOR_CACHE.write();
if cache_lock.is_none() {
*cache_lock = Some(ProcessorCache::new()?);
}
Ok(())
}
/// Get processors from the cache, organized by stage.
pub(super) fn get_processors_from_cache() -> Result<ProcessorStages> {
let cache_lock = PROCESSOR_CACHE.read();
let cache = cache_lock
.as_ref()
.ok_or_else(|| crate::KreuzbergError::Other("Processor cache not initialized".to_string()))?;
Ok((
std::sync::Arc::clone(&cache.early),
std::sync::Arc::clone(&cache.middle),
std::sync::Arc::clone(&cache.late),
))
}

View File

@@ -0,0 +1,478 @@
//! Post-processing pipeline orchestration.
//!
//! This module orchestrates the post-processing pipeline, executing validators,
//! quality processing, chunking, and custom hooks in the correct order.
mod cache;
mod execution;
mod features;
mod format;
mod initialization;
#[cfg(test)]
mod tests;
pub use cache::clear_processor_cache;
pub use format::apply_output_format;
use crate::Result;
use crate::core::config::ExtractionConfig;
use crate::types::ExtractionResult;
use crate::types::internal::InternalDocument;
use execution::{execute_processors, execute_validators};
use features::{execute_chunking, execute_language_detection, execute_token_reduction};
use initialization::{get_processors_from_cache, initialize_features, initialize_processor_cache};
/// Run the post-processing pipeline on an `InternalDocument`.
///
/// Derives `ExtractionResult` from `InternalDocument` via the derivation pipeline,
/// then executes post-processing in the following order:
/// 1. Post-Processors - Execute by stage (Early, Middle, Late) to modify/enhance the result
/// 2. Quality Processing - Text cleaning and quality scoring
/// 3. Chunking - Text splitting if enabled
/// 4. Validators - Run validation hooks on the processed result (can fail fast)
///
/// # Arguments
///
/// * `doc` - The internal document produced by the extractor
/// * `config` - Extraction configuration
///
/// # Returns
///
/// The processed extraction result.
///
/// # Errors
///
/// - Validator errors bubble up immediately
/// - Post-processor errors are caught and recorded in metadata
/// - System errors (IO, RuntimeError equivalents) always bubble up
#[cfg_attr(feature = "otel", tracing::instrument(
skip(doc, config),
fields(
pipeline.stage = "post_processing",
content.element_count = doc.elements.len(),
)
))]
#[cfg_attr(alef, alef(skip))]
pub async fn run_pipeline(mut doc: InternalDocument, config: &ExtractionConfig) -> Result<ExtractionResult> {
// Propagate rendering preferences from config into the document.
doc.ocr_text_only = config.images.as_ref().map(|i| i.ocr_text_only).unwrap_or(false);
doc.append_ocr_text = config.images.as_ref().map(|i| i.append_ocr_text).unwrap_or(false);
// 1. Process extracted images with OCR if configured
#[cfg(all(feature = "ocr", feature = "tokio-runtime"))]
let image_ocr_enabled = config.images.as_ref().map(|i| i.run_ocr_on_images).unwrap_or(true);
#[cfg(all(feature = "ocr", feature = "tokio-runtime"))]
if image_ocr_enabled && config.ocr.is_some() && !doc.images.is_empty() {
let images_to_process = std::mem::take(&mut doc.images);
match crate::extraction::image_ocr::process_images_with_ocr(
images_to_process,
config,
&mut doc.processing_warnings,
)
.await
{
Ok(processed) => {
doc.images = processed;
}
Err(e) => {
doc.processing_warnings.push(crate::types::ProcessingWarning {
source: std::borrow::Cow::Borrowed("image_ocr"),
message: std::borrow::Cow::Owned(format!("Image OCR failed: {e}")),
});
}
}
}
replace_embedded_image_markdown_with_ocr(&mut doc);
append_embedded_image_ocr_text(&mut doc);
// Pre-render markdown for the chunker's heading context resolution when:
// - Markdown chunking is configured
// - Output format is not already Markdown (which would produce formatted_content anyway)
// Plain-text rendering strips heading markers, so the markdown chunker needs
// a separate markdown rendering to build the heading hierarchy for chunk metadata.
#[cfg(feature = "chunking")]
let chunker_heading_source = {
let needs_markdown = config.chunking.as_ref().is_some_and(|c| {
c.chunker_type == crate::core::config::ChunkerType::Markdown
|| c.resolve_preset().chunker_type == crate::core::config::ChunkerType::Markdown
}) && config.output_format == crate::core::config::OutputFormat::Plain;
if needs_markdown {
Some(crate::rendering::render_markdown(&doc))
} else {
None
}
};
// Pre-render styled HTML before `doc` is consumed by `derive_extraction_result`.
// When `html` is active and the caller has configured `html_output`, we
// render the document here and inject the result after derivation.
#[cfg(feature = "html")]
let styled_html_prerender: Option<String> = {
use crate::plugins::Renderer as _;
if config.output_format == crate::core::config::OutputFormat::Html {
config.html_output.as_ref().and_then(|html_cfg| {
match crate::rendering::StyledHtmlRenderer::new(html_cfg.clone()) {
Ok(renderer) => match renderer.render(&doc) {
Ok(html) => Some(html),
Err(e) => {
tracing::warn!("StyledHtmlRenderer render failed, falling back to default HTML: {e}");
None
}
},
Err(e) => {
tracing::warn!("StyledHtmlRenderer construction failed, falling back to default HTML: {e}");
None
}
}
})
} else {
None
}
};
// 2. Derive ExtractionResult from InternalDocument
let include_structure = config.include_document_structure;
let mut result =
crate::extraction::derive::derive_extraction_result(doc, include_structure, config.output_format.clone());
// Inject pre-rendered styled HTML (overrides the default render_html output).
#[cfg(feature = "html")]
if let Some(html) = styled_html_prerender {
result.formatted_content = Some(html);
}
// Temporarily store pre-rendered markdown for chunker heading context.
// Tracked separately so we can remove it after chunking — apply_output_format
// must not swap this into result.content when output_format is Plain.
#[cfg(feature = "chunking")]
let chunker_only_markdown = result.formatted_content.is_none();
#[cfg(feature = "chunking")]
if chunker_only_markdown && let Some(md) = chunker_heading_source {
result.formatted_content = Some(md);
}
// 2. Run post-processing pipeline
let pp_config = config.postprocessor.as_ref();
let postprocessing_enabled = pp_config.is_none_or(|c| c.enabled);
if postprocessing_enabled {
initialize_features();
initialize_processor_cache()?;
let (early_processors, middle_processors, late_processors) = get_processors_from_cache()?;
execute_processors(
&mut result,
config,
&pp_config,
early_processors,
middle_processors,
late_processors,
)
.await?;
}
execute_chunking(&mut result, config)?;
// Clear temporary markdown if it was only stored for chunker heading context.
// This prevents apply_output_format from swapping it into result.content.
#[cfg(feature = "chunking")]
if chunker_only_markdown {
result.formatted_content = None;
}
execute_language_detection(&mut result, config)?;
execute_token_reduction(&mut result, config)?;
execute_validators(&result, config).await?;
apply_element_transform(&mut result, config);
normalize_nfc(&mut result);
// Run LLM-based structured extraction BEFORE output formatting
// so extraction sees plain text, not markdown/HTML
// TODO(wasm-llm): hosted structured extraction should run on wasm through
// liter-llm's wasm-http backend once browser/runtime support is wired.
#[cfg(all(feature = "liter-llm", not(target_os = "windows"), not(target_arch = "wasm32")))]
if let Some(ref structured_config) = config.structured_extraction {
match crate::llm::structured::extract_structured(&result.content, structured_config).await {
Ok((output, usage)) => {
result.structured_output = Some(output);
crate::llm::usage::push_llm_usage(&mut result, usage);
}
Err(e) => {
tracing::warn!("Structured extraction failed: {e}");
result.processing_warnings.push(crate::types::ProcessingWarning {
source: std::borrow::Cow::Borrowed("structured_extraction"),
message: std::borrow::Cow::Owned(format!("Structured extraction failed: {e}")),
});
}
}
}
// TODO(wasm-llm): keep wasm in the fallback branch until structured
// extraction has an async wasm-compatible runtime path.
#[cfg(any(not(feature = "liter-llm"), target_os = "windows", target_arch = "wasm32"))]
if config.structured_extraction.is_some() {
result.processing_warnings.push(crate::types::ProcessingWarning {
source: std::borrow::Cow::Borrowed("structured_extraction"),
message: std::borrow::Cow::Borrowed("Structured extraction requires the 'liter-llm' feature"),
});
}
// Apply output format conversion as the final step
result = apply_output_format(result, config.output_format.clone());
Ok(result)
}
/// Run the post-processing pipeline synchronously (WASM-compatible version).
///
/// This is a synchronous implementation for WASM and non-async contexts.
/// It performs a subset of the full async pipeline, excluding async post-processors
/// and validators.
///
/// # Arguments
///
/// * `doc` - The internal document produced by the extractor
/// * `config` - Extraction configuration
///
/// # Returns
///
/// The processed extraction result.
///
/// # Notes
///
/// This function is only available when the `tokio-runtime` feature is disabled.
/// It handles:
/// - Quality processing (if enabled)
/// - Chunking (if enabled)
/// - Language detection (if enabled)
///
/// It does NOT handle:
/// - Async post-processors
/// - Async validators
#[cfg(not(feature = "tokio-runtime"))]
#[cfg_attr(alef, alef(skip))]
pub fn run_pipeline_sync(doc: InternalDocument, config: &ExtractionConfig) -> Result<ExtractionResult> {
// Pre-render markdown for chunker heading context (same logic as async path).
#[cfg(feature = "chunking")]
let chunker_heading_source = {
let needs_markdown = config.chunking.as_ref().is_some_and(|c| {
c.chunker_type == crate::core::config::ChunkerType::Markdown
|| c.resolve_preset().chunker_type == crate::core::config::ChunkerType::Markdown
}) && config.output_format == crate::core::config::OutputFormat::Plain;
if needs_markdown {
Some(crate::rendering::render_markdown(&doc))
} else {
None
}
};
// Pre-render styled HTML before `doc` is consumed (mirrors async path).
#[cfg(feature = "html")]
let styled_html_prerender: Option<String> = {
use crate::plugins::Renderer as _;
if config.output_format == crate::core::config::OutputFormat::Html {
config.html_output.as_ref().and_then(|html_cfg| {
match crate::rendering::StyledHtmlRenderer::new(html_cfg.clone()) {
Ok(renderer) => match renderer.render(&doc) {
Ok(html) => Some(html),
Err(e) => {
tracing::warn!("StyledHtmlRenderer render failed, falling back to default HTML: {e}");
None
}
},
Err(e) => {
tracing::warn!("StyledHtmlRenderer construction failed, falling back to default HTML: {e}");
None
}
}
})
} else {
None
}
};
// 1. Derive ExtractionResult from InternalDocument
let include_structure = config.include_document_structure;
let mut result =
crate::extraction::derive::derive_extraction_result(doc, include_structure, config.output_format.clone());
// Inject pre-rendered styled HTML.
#[cfg(feature = "html")]
if let Some(html) = styled_html_prerender {
result.formatted_content = Some(html);
}
#[cfg(feature = "chunking")]
let chunker_only_markdown = result.formatted_content.is_none();
#[cfg(feature = "chunking")]
if chunker_only_markdown && let Some(md) = chunker_heading_source {
result.formatted_content = Some(md);
}
// 2. Run synchronous post-processing
execute_chunking(&mut result, config)?;
#[cfg(feature = "chunking")]
if chunker_only_markdown {
result.formatted_content = None;
}
execute_language_detection(&mut result, config)?;
execute_token_reduction(&mut result, config)?;
apply_element_transform(&mut result, config);
normalize_nfc(&mut result);
// Apply output format conversion as the final step
result = apply_output_format(result, config.output_format.clone());
Ok(result)
}
/// Transform to element-based output if requested by the config.
fn apply_element_transform(result: &mut ExtractionResult, config: &ExtractionConfig) {
if config.result_format == crate::types::ResultFormat::ElementBased {
result.elements = Some(crate::extraction::transform::transform_extraction_result_to_elements(
result,
));
}
}
/// Replace inline markdown image references with OCR text for formats (e.g. PPTX)
/// that bake placeholders into paragraph text rather than using `ElementKind::Image`.
fn replace_embedded_image_markdown_with_ocr(doc: &mut InternalDocument) {
if !doc.ocr_text_only || doc.images.is_empty() {
return;
}
let mut image_idx = 0usize;
for elem in &mut doc.elements {
if !matches!(elem.kind, crate::types::internal::ElementKind::Paragraph) {
continue;
}
if !is_markdown_image_reference(&elem.text) {
continue;
}
if let Some(img) = doc.images.get(image_idx)
&& let Some(ocr) = &img.ocr_result
&& !ocr.content.is_empty()
{
elem.text = ocr.content.clone();
image_idx += 1;
continue;
}
image_idx += 1;
}
for table in &mut doc.tables {
for row in &mut table.cells {
for cell in row {
if !is_markdown_image_reference(cell) {
continue;
}
if let Some(img) = doc.images.get(image_idx)
&& let Some(ocr) = &img.ocr_result
&& !ocr.content.is_empty()
{
*cell = ocr.content.clone();
image_idx += 1;
continue;
}
image_idx += 1;
}
}
}
}
/// Append OCR text after inline markdown image references for formats (e.g. PPTX)
/// that bake placeholders into paragraph text. Only runs when `append_ocr_text` is
/// `true` and `ocr_text_only` is `false`.
fn append_embedded_image_ocr_text(doc: &mut InternalDocument) {
if doc.ocr_text_only || !doc.append_ocr_text || doc.images.is_empty() {
return;
}
let mut image_idx = 0usize;
let mut new_elements = Vec::with_capacity(doc.elements.len() * 2);
for elem in &doc.elements {
new_elements.push(elem.clone());
if matches!(elem.kind, crate::types::internal::ElementKind::Paragraph)
&& is_markdown_image_reference(&elem.text)
{
if let Some(img) = doc.images.get(image_idx)
&& let Some(ocr) = &img.ocr_result
&& !ocr.content.is_empty()
{
let ocr_elem = crate::types::internal::InternalElement::text(
crate::types::internal::ElementKind::Paragraph,
ocr.content.clone(),
0,
);
new_elements.push(ocr_elem);
}
image_idx += 1;
}
}
doc.elements = new_elements;
for table in &mut doc.tables {
for row in &mut table.cells {
for cell in row {
if !is_markdown_image_reference(cell) {
continue;
}
if let Some(img) = doc.images.get(image_idx)
&& let Some(ocr) = &img.ocr_result
&& !ocr.content.is_empty()
{
*cell = format!("{}\n\n{}", cell.trim(), ocr.content);
}
image_idx += 1;
}
}
}
}
/// Returns `true` if `text` is exactly a markdown image reference (`![alt](url)`).
fn is_markdown_image_reference(text: &str) -> bool {
let t = text.trim();
if !t.starts_with("![") {
return false;
}
let Some(bracket_end) = t.find("](") else {
return false;
};
if bracket_end < 2 {
return false;
}
let after = &t[bracket_end + 2..];
after.ends_with(')')
}
/// Apply NFC unicode normalization to all text content.
///
/// Ensures consistent representation of composed characters (e.g., é vs e+combining accent)
/// across all extraction backends (PDF, OCR, DOCX, HTML, etc.).
fn normalize_nfc(result: &mut ExtractionResult) {
#[cfg(feature = "quality")]
{
use unicode_normalization::UnicodeNormalization;
result.content = result.content.nfc().collect();
if let Some(pages) = result.pages.as_mut() {
for page in pages.iter_mut() {
page.content = page.content.nfc().collect();
}
}
}
// Suppress unused variable warning when quality feature is disabled
let _ = result;
}

View File

@@ -0,0 +1,993 @@
//! Pipeline orchestration tests.
use super::*;
use crate::core::config::OutputFormat;
use crate::types::Metadata;
use crate::types::internal::{ElementKind, InternalDocument, InternalElement};
use serial_test::serial;
use std::borrow::Cow;
/// Build an `InternalDocument` with a single paragraph element for pipeline tests.
fn make_doc(content: &str, mime: &str) -> InternalDocument {
let mut doc = InternalDocument::new("plain");
doc.mime_type = mime.to_string();
if !content.is_empty() {
doc.push_element(InternalElement::text(ElementKind::Paragraph, content, 0));
}
doc
}
/// Build an `InternalDocument` with content, mime, and custom metadata.
fn make_doc_with_metadata(content: &str, mime: &str, metadata: Metadata) -> InternalDocument {
let mut doc = make_doc(content, mime);
doc.metadata = metadata;
doc
}
const VALIDATION_MARKER_KEY: &str = "registry_validation_marker";
#[cfg(feature = "quality")]
const QUALITY_VALIDATION_MARKER: &str = "quality_validation_test";
const POSTPROCESSOR_VALIDATION_MARKER: &str = "postprocessor_validation_test";
const ORDER_VALIDATION_MARKER: &str = "order_validation_test";
/// Ensure the quality processor is registered and cache is fresh.
/// Needed because other tests may call `shutdown_all()` on the registry,
/// and the `OnceLock` in `initialize_features` prevents re-registration.
#[cfg(feature = "quality")]
fn ensure_quality_processor() {
let registry = crate::plugins::registry::get_post_processor_registry();
let mut reg = registry.write();
let _ = reg.register(std::sync::Arc::new(crate::text::QualityProcessor));
drop(reg);
let _ = clear_processor_cache();
}
#[tokio::test]
#[serial]
async fn test_run_pipeline_basic() {
let mut doc = make_doc("test", "text/plain");
doc.metadata.additional.insert(
Cow::Borrowed(VALIDATION_MARKER_KEY),
serde_json::json!(ORDER_VALIDATION_MARKER),
);
let config = ExtractionConfig {
postprocessor: Some(crate::core::config::PostProcessorConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert_eq!(processed.content, "test");
}
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_pipeline_with_quality_processing() {
ensure_quality_processor();
let doc = make_doc("This is a test document with some meaningful content.", "text/plain");
let config = ExtractionConfig {
enable_quality_processing: true,
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert!(processed.quality_score.is_some());
}
#[tokio::test]
#[serial]
async fn test_pipeline_without_quality_processing() {
let doc = make_doc("test", "text/plain");
let config = ExtractionConfig {
enable_quality_processing: false,
postprocessor: Some(crate::core::config::PostProcessorConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert!(processed.quality_score.is_none());
}
#[tokio::test]
#[serial]
#[cfg(feature = "chunking")]
async fn test_pipeline_with_chunking() {
let doc = make_doc(
&"This is a long text that should be chunked. ".repeat(100),
"text/plain",
);
let config = ExtractionConfig {
chunking: Some(crate::ChunkingConfig {
max_characters: 500,
overlap: 50,
trim: true,
chunker_type: crate::ChunkerType::Text,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
let chunks = processed.chunks.as_ref().expect("chunks should be present");
assert!(chunks.len() > 1);
}
#[tokio::test]
#[serial]
async fn test_pipeline_without_chunking() {
let doc = make_doc("test", "text/plain");
let config = ExtractionConfig {
chunking: None,
postprocessor: Some(crate::core::config::PostProcessorConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert!(processed.chunks.is_none());
}
#[tokio::test]
#[serial]
async fn test_pipeline_preserves_metadata() {
use ahash::AHashMap;
let mut additional = AHashMap::new();
additional.insert(Cow::Borrowed("source"), serde_json::json!("test"));
additional.insert(Cow::Borrowed("page"), serde_json::json!(1));
let doc = make_doc_with_metadata(
"test",
"text/plain",
Metadata {
additional,
..Default::default()
},
);
let config = ExtractionConfig {
postprocessor: Some(crate::core::config::PostProcessorConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert_eq!(
processed.metadata.additional.get("source").unwrap(),
&serde_json::json!("test")
);
assert_eq!(
processed.metadata.additional.get("page").unwrap(),
&serde_json::json!(1)
);
}
#[tokio::test]
#[serial]
async fn test_pipeline_preserves_tables() {
use crate::types::Table;
let table = Table {
cells: vec![vec!["A".to_string(), "B".to_string()]],
markdown: "| A | B |".to_string(),
page_number: 0,
bounding_box: None,
};
let mut doc = make_doc("test", "text/plain");
doc.tables.push(table);
let config = ExtractionConfig {
postprocessor: Some(crate::core::config::PostProcessorConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert_eq!(processed.tables.len(), 1);
assert_eq!(processed.tables[0].cells.len(), 1);
}
#[tokio::test]
#[serial]
async fn test_pipeline_empty_content() {
{
let registry = crate::plugins::registry::get_post_processor_registry();
registry.write().shutdown_all().unwrap();
}
{
let registry = crate::plugins::registry::get_validator_registry();
registry.write().shutdown_all().unwrap();
}
let doc = make_doc("", "text/plain");
let config = ExtractionConfig::default();
let processed = run_pipeline(doc, &config).await.unwrap();
assert_eq!(processed.content, "");
}
#[tokio::test]
#[serial]
#[cfg(feature = "chunking")]
async fn test_pipeline_with_all_features() {
#[cfg(feature = "quality")]
ensure_quality_processor();
let doc = make_doc(&"This is a comprehensive test document. ".repeat(50), "text/plain");
let config = ExtractionConfig {
enable_quality_processing: true,
chunking: Some(crate::ChunkingConfig {
max_characters: 500,
overlap: 50,
trim: true,
chunker_type: crate::ChunkerType::Text,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
#[cfg(feature = "quality")]
assert!(processed.quality_score.is_some());
assert!(processed.chunks.is_some());
}
#[tokio::test]
#[serial]
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
async fn test_pipeline_with_keyword_extraction() {
crate::plugins::registry::get_validator_registry()
.write()
.shutdown_all()
.unwrap();
crate::plugins::registry::get_post_processor_registry()
.write()
.shutdown_all()
.unwrap();
// Register keyword processor directly (bypasses Lazy which only runs once per process)
let _ = crate::keywords::register_keyword_processor();
clear_processor_cache().unwrap();
let doc = make_doc(
r#"
Machine learning is a branch of artificial intelligence that focuses on
building systems that can learn from data. Deep learning is a subset of
machine learning that uses neural networks with multiple layers.
Natural language processing enables computers to understand human language.
"#,
"text/plain",
);
#[cfg(feature = "keywords-yake")]
let keyword_config = crate::keywords::KeywordConfig::yake();
#[cfg(all(feature = "keywords-rake", not(feature = "keywords-yake")))]
let keyword_config = crate::keywords::KeywordConfig::rake();
let config = ExtractionConfig {
keywords: Some(keyword_config),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
let keywords = processed
.extracted_keywords
.as_ref()
.expect("Should have extracted keywords");
assert!(!keywords.is_empty(), "Should have extracted keywords");
let first_keyword = &keywords[0];
assert!(!first_keyword.text.is_empty());
assert!(first_keyword.score > 0.0);
}
#[tokio::test]
#[serial]
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
async fn test_pipeline_without_keyword_config() {
let doc = make_doc("Machine learning and artificial intelligence.", "text/plain");
let config = ExtractionConfig {
keywords: None,
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert!(!processed.metadata.additional.contains_key("keywords"));
}
#[tokio::test]
#[serial]
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
async fn test_pipeline_keyword_extraction_short_content() {
crate::plugins::registry::get_validator_registry()
.write()
.shutdown_all()
.unwrap();
crate::plugins::registry::get_post_processor_registry()
.write()
.shutdown_all()
.unwrap();
let doc = make_doc("Short text", "text/plain");
#[cfg(feature = "keywords-yake")]
let keyword_config = crate::keywords::KeywordConfig::yake();
#[cfg(all(feature = "keywords-rake", not(feature = "keywords-yake")))]
let keyword_config = crate::keywords::KeywordConfig::rake();
let config = ExtractionConfig {
keywords: Some(keyword_config),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert!(!processed.metadata.additional.contains_key("keywords"));
}
#[tokio::test]
#[serial]
async fn test_postprocessor_runs_before_validator() {
use crate::plugins::{Plugin, PostProcessor, ProcessingStage, Validator};
use async_trait::async_trait;
use std::sync::Arc;
struct TestPostProcessor;
impl Plugin for TestPostProcessor {
fn name(&self) -> &str {
"test-processor"
}
fn version(&self) -> String {
"1.0.0".to_string()
}
fn initialize(&self) -> Result<()> {
Ok(())
}
fn shutdown(&self) -> Result<()> {
Ok(())
}
}
#[async_trait]
impl PostProcessor for TestPostProcessor {
async fn process(&self, result: &mut ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
result
.metadata
.additional
.insert(Cow::Borrowed("processed"), serde_json::json!(true));
Ok(())
}
fn processing_stage(&self) -> ProcessingStage {
ProcessingStage::Middle
}
}
struct TestValidator;
impl Plugin for TestValidator {
fn name(&self) -> &str {
"test-validator"
}
fn version(&self) -> String {
"1.0.0".to_string()
}
fn initialize(&self) -> Result<()> {
Ok(())
}
fn shutdown(&self) -> Result<()> {
Ok(())
}
}
#[async_trait]
impl Validator for TestValidator {
async fn validate(&self, result: &ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
let should_validate = result
.metadata
.additional
.get(VALIDATION_MARKER_KEY)
.and_then(|v| v.as_str())
== Some(POSTPROCESSOR_VALIDATION_MARKER);
if !should_validate {
return Ok(());
}
let processed = result
.metadata
.additional
.get("processed")
.and_then(|v| v.as_bool())
.unwrap_or(false);
if !processed {
return Err(crate::KreuzbergError::Validation {
message: "Post-processor did not run before validator".to_string(),
source: None,
});
}
Ok(())
}
}
let pp_registry = crate::plugins::registry::get_post_processor_registry();
let val_registry = crate::plugins::registry::get_validator_registry();
clear_processor_cache().unwrap();
pp_registry.write().shutdown_all().unwrap();
val_registry.write().shutdown_all().unwrap();
clear_processor_cache().unwrap();
{
let mut registry = pp_registry.write();
registry.register(Arc::new(TestPostProcessor)).unwrap();
}
{
let mut registry = val_registry.write();
registry.register(Arc::new(TestValidator)).unwrap();
}
clear_processor_cache().unwrap();
let mut doc = make_doc("test", "text/plain");
doc.metadata.additional.insert(
Cow::Borrowed(VALIDATION_MARKER_KEY),
serde_json::json!(POSTPROCESSOR_VALIDATION_MARKER),
);
let config = ExtractionConfig {
postprocessor: Some(crate::core::config::PostProcessorConfig {
enabled: true,
enabled_set: None,
disabled_set: None,
enabled_processors: None,
disabled_processors: None,
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await;
pp_registry.write().shutdown_all().unwrap();
val_registry.write().shutdown_all().unwrap();
assert!(processed.is_ok(), "Validator should have seen post-processor metadata");
let processed = processed.unwrap();
assert_eq!(
processed.metadata.additional.get("processed"),
Some(&serde_json::json!(true)),
"Post-processor metadata should be present"
);
}
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_quality_processing_runs_before_validator() {
ensure_quality_processor();
use crate::plugins::{Plugin, Validator};
use async_trait::async_trait;
use std::sync::Arc;
struct QualityValidator;
impl Plugin for QualityValidator {
fn name(&self) -> &str {
"quality-validator"
}
fn version(&self) -> String {
"1.0.0".to_string()
}
fn initialize(&self) -> Result<()> {
Ok(())
}
fn shutdown(&self) -> Result<()> {
Ok(())
}
}
#[async_trait]
impl Validator for QualityValidator {
async fn validate(&self, result: &ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
let should_validate = result
.metadata
.additional
.get(VALIDATION_MARKER_KEY)
.and_then(|v| v.as_str())
== Some(QUALITY_VALIDATION_MARKER);
if !should_validate {
return Ok(());
}
if result.quality_score.is_none() {
return Err(crate::KreuzbergError::Validation {
message: "Quality processing did not run before validator".to_string(),
source: None,
});
}
Ok(())
}
}
let val_registry = crate::plugins::registry::get_validator_registry();
{
let mut registry = val_registry.write();
registry.register(Arc::new(QualityValidator)).unwrap();
}
let mut doc = make_doc("This is meaningful test content for quality scoring.", "text/plain");
doc.metadata.additional.insert(
Cow::Borrowed(VALIDATION_MARKER_KEY),
serde_json::json!(QUALITY_VALIDATION_MARKER),
);
let config = ExtractionConfig {
enable_quality_processing: true,
..Default::default()
};
let processed = run_pipeline(doc, &config).await;
{
let mut registry = val_registry.write();
registry.remove("quality-validator").unwrap();
}
assert!(processed.is_ok(), "Validator should have seen quality_score");
}
#[tokio::test]
#[serial]
async fn test_multiple_postprocessors_run_before_validator() {
use crate::plugins::{Plugin, PostProcessor, ProcessingStage, Validator};
use async_trait::async_trait;
use std::sync::Arc;
struct EarlyProcessor;
impl Plugin for EarlyProcessor {
fn name(&self) -> &str {
"early-proc"
}
fn version(&self) -> String {
"1.0.0".to_string()
}
fn initialize(&self) -> Result<()> {
Ok(())
}
fn shutdown(&self) -> Result<()> {
Ok(())
}
}
#[async_trait]
impl PostProcessor for EarlyProcessor {
async fn process(&self, result: &mut ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
let mut order = result
.metadata
.additional
.get("execution_order")
.and_then(|v| v.as_array())
.cloned()
.unwrap_or_default();
order.push(serde_json::json!("early"));
result
.metadata
.additional
.insert(Cow::Borrowed("execution_order"), serde_json::json!(order));
Ok(())
}
fn processing_stage(&self) -> ProcessingStage {
ProcessingStage::Early
}
}
struct LateProcessor;
impl Plugin for LateProcessor {
fn name(&self) -> &str {
"late-proc"
}
fn version(&self) -> String {
"1.0.0".to_string()
}
fn initialize(&self) -> Result<()> {
Ok(())
}
fn shutdown(&self) -> Result<()> {
Ok(())
}
}
#[async_trait]
impl PostProcessor for LateProcessor {
async fn process(&self, result: &mut ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
let mut order = result
.metadata
.additional
.get("execution_order")
.and_then(|v| v.as_array())
.cloned()
.unwrap_or_default();
order.push(serde_json::json!("late"));
result
.metadata
.additional
.insert(Cow::Borrowed("execution_order"), serde_json::json!(order));
Ok(())
}
fn processing_stage(&self) -> ProcessingStage {
ProcessingStage::Late
}
}
struct OrderValidator;
impl Plugin for OrderValidator {
fn name(&self) -> &str {
"order-validator"
}
fn version(&self) -> String {
"1.0.0".to_string()
}
fn initialize(&self) -> Result<()> {
Ok(())
}
fn shutdown(&self) -> Result<()> {
Ok(())
}
}
#[async_trait]
impl Validator for OrderValidator {
async fn validate(&self, result: &ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
let should_validate = result
.metadata
.additional
.get(VALIDATION_MARKER_KEY)
.and_then(|v| v.as_str())
== Some(ORDER_VALIDATION_MARKER);
if !should_validate {
return Ok(());
}
let order = result
.metadata
.additional
.get("execution_order")
.and_then(|v| v.as_array())
.ok_or_else(|| crate::KreuzbergError::Validation {
message: "No execution order found".to_string(),
source: None,
})?;
if order.len() != 2 {
return Err(crate::KreuzbergError::Validation {
message: format!("Expected 2 processors to run, got {}", order.len()),
source: None,
});
}
if order[0] != "early" || order[1] != "late" {
return Err(crate::KreuzbergError::Validation {
message: format!("Wrong execution order: {:?}", order),
source: None,
});
}
Ok(())
}
}
let pp_registry = crate::plugins::registry::get_post_processor_registry();
let val_registry = crate::plugins::registry::get_validator_registry();
pp_registry.write().shutdown_all().unwrap();
val_registry.write().shutdown_all().unwrap();
clear_processor_cache().unwrap();
{
let mut registry = pp_registry.write();
registry.register(Arc::new(EarlyProcessor)).unwrap();
registry.register(Arc::new(LateProcessor)).unwrap();
}
{
let mut registry = val_registry.write();
registry.register(Arc::new(OrderValidator)).unwrap();
}
// Clear the cache after registering new processors so it rebuilds with the test processors
clear_processor_cache().unwrap();
let doc = make_doc("test", "text/plain");
let config = ExtractionConfig::default();
let processed = run_pipeline(doc, &config).await;
pp_registry.write().shutdown_all().unwrap();
val_registry.write().shutdown_all().unwrap();
clear_processor_cache().unwrap();
assert!(processed.is_ok(), "All processors should run before validator");
}
#[tokio::test]
#[serial]
async fn test_run_pipeline_with_output_format_plain() {
let doc = make_doc("test content", "text/plain");
let config = crate::core::config::ExtractionConfig {
output_format: OutputFormat::Plain,
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert_eq!(processed.content, "test content");
assert_eq!(processed.metadata.output_format, Some("plain".to_string()));
}
#[tokio::test]
#[serial]
async fn test_run_pipeline_with_output_format_djot() {
let doc = make_doc("test content", "text/djot");
let config = crate::core::config::ExtractionConfig {
output_format: OutputFormat::Djot,
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
// The content should still be present
assert!(!processed.content.is_empty());
assert_eq!(processed.metadata.output_format, Some("djot".to_string()));
}
#[tokio::test]
#[serial]
async fn test_run_pipeline_with_output_format_html() {
let doc = make_doc("test content", "text/plain");
let config = crate::core::config::ExtractionConfig {
output_format: OutputFormat::Html,
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
// HTML renderer produces semantic tags from InternalDocument
assert!(processed.content.contains("test content"));
assert_eq!(processed.metadata.output_format, Some("html".to_string()));
}
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_nfc_normalization_decomposes_to_composed() {
// NFC normalization should convert decomposed characters to composed form.
// "e\u{0301}" (e + combining acute accent) → "\u{00e9}" (é precomposed)
let doc = make_doc("caf\u{0065}\u{0301}", "text/plain"); // "café" with decomposed é
let config = ExtractionConfig {
postprocessor: Some(crate::core::config::PostProcessorConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert_eq!(processed.content, "caf\u{00e9}"); // composed é
assert!(!processed.content.contains('\u{0301}')); // no combining accent
}
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_nfc_normalization_idempotent_on_ascii() {
// NFC on already-normalized/ASCII text should be a no-op.
let doc = make_doc("Hello, world! 123", "text/plain");
let config = ExtractionConfig {
postprocessor: Some(crate::core::config::PostProcessorConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
assert_eq!(processed.content, "Hello, world! 123");
}
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_nfc_normalization_applies_to_page_content() {
// Create a doc with a page-1 element containing decomposed characters
let mut doc = InternalDocument::new("plain");
doc.mime_type = "text/plain".to_string();
doc.push_element(InternalElement::text(ElementKind::Paragraph, "re\u{0301}sume\u{0301}", 0).with_page(1));
let config = ExtractionConfig {
postprocessor: Some(crate::core::config::PostProcessorConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
// Content derived from page element
assert!(processed.content.contains("r\u{00e9}sum\u{00e9}"));
let pages = processed.pages.unwrap();
assert_eq!(pages[0].content, "r\u{00e9}sum\u{00e9}");
}
#[tokio::test]
#[serial]
async fn test_run_pipeline_applies_output_format_last() {
// This test verifies that output format is applied after all other processing
let doc = make_doc("test", "text/plain");
let config = crate::core::config::ExtractionConfig {
output_format: OutputFormat::Djot,
// Disable other processing to ensure pipeline runs cleanly
enable_quality_processing: false,
..Default::default()
};
let processed = run_pipeline(doc, &config).await.unwrap();
// The result should have gone through the pipeline successfully
assert_eq!(processed.metadata.output_format, Some("djot".to_string()));
}
#[tokio::test]
#[serial]
#[cfg(all(feature = "pdf", feature = "chunking"))]
async fn test_chunking_populates_page_numbers_for_pdf() {
use crate::core::config::ChunkingConfig;
let pdf_path =
std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test_documents/pdf/issue-636-chunk-pages.pdf");
if !pdf_path.exists() {
// Skip if test document not available
return;
}
let pdf_bytes = std::fs::read(&pdf_path).unwrap();
// Configure chunking WITHOUT explicit pages config (the default user scenario)
let config = ExtractionConfig {
chunking: Some(ChunkingConfig {
max_characters: 500,
..Default::default()
}),
..Default::default()
};
let result = crate::core::extractor::extract_bytes(&pdf_bytes, "application/pdf", &config)
.await
.unwrap();
// Chunks should exist
assert!(result.chunks.is_some(), "Chunks should be produced");
let chunks = result.chunks.as_ref().unwrap();
assert!(!chunks.is_empty(), "Should have at least one chunk");
// At least some chunks should have page numbers
let chunks_with_pages = chunks.iter().filter(|c| c.metadata.first_page.is_some()).count();
assert!(
chunks_with_pages > 0,
"At least some chunks should have page numbers, but none do. Total chunks: {}",
chunks.len()
);
}
#[test]
fn test_append_ocr_text_for_pptx_images() {
use crate::types::ExtractedImage;
use crate::types::internal::{ElementKind, InternalDocument, InternalElement};
use std::borrow::Cow;
let mut doc = InternalDocument::new("pptx");
doc.append_ocr_text = true;
doc.elements
.push(InternalElement::text(ElementKind::Paragraph, "Before image.", 0));
doc.elements.push(InternalElement::text(
ElementKind::Paragraph,
"![img](../media/image-1.jpeg)",
0,
));
doc.elements
.push(InternalElement::text(ElementKind::Paragraph, "After image.", 0));
doc.images.push(ExtractedImage {
data: bytes::Bytes::new(),
format: Cow::Borrowed("jpeg"),
image_index: 0,
page_number: Some(1),
width: Some(100),
height: Some(100),
colorspace: None,
bits_per_component: None,
is_mask: false,
description: None,
ocr_result: Some(Box::new(crate::types::ExtractionResult {
content: "OCR text here".to_string(),
mime_type: Cow::Borrowed("text/plain"),
..Default::default()
})),
bounding_box: None,
source_path: None,
image_kind: None,
kind_confidence: None,
cluster_id: None,
});
super::append_embedded_image_ocr_text(&mut doc);
assert_eq!(
doc.elements.len(),
4,
"should have 4 elements (original 3 + 1 OCR paragraph)"
);
assert_eq!(doc.elements[2].text, "OCR text here");
let rendered = crate::rendering::render_markdown(&doc);
assert!(rendered.contains("OCR text here"));
}
#[tokio::test]
#[serial]
async fn test_pdf_run_fallback_not_suppressed_without_images_config() {
// When config.images is None, run_ocr_on_images must default to false so
// the PDF document-level OCR fallback is NOT silently suppressed for
// existing callers that never configured ImageExtractionConfig.
use crate::core::config::ImageExtractionConfig;
let default_no_images = crate::core::config::ExtractionConfig::default();
assert!(
default_no_images.images.is_none(),
"baseline: default config has no images section"
);
let skip_fallback = default_no_images
.images
.as_ref()
.map(|i| i.run_ocr_on_images)
.unwrap_or(false);
assert!(
!skip_fallback,
"RunFallback must NOT be suppressed when config.images is None"
);
let with_images_opted_in = crate::core::config::ExtractionConfig {
images: Some(ImageExtractionConfig {
run_ocr_on_images: true,
..Default::default()
}),
..Default::default()
};
let skip_fallback_opted_in = with_images_opted_in
.images
.as_ref()
.map(|i| i.run_ocr_on_images)
.unwrap_or(false);
assert!(
skip_fallback_opted_in,
"RunFallback must be suppressed when images.run_ocr_on_images=true"
);
}

View File

@@ -0,0 +1,76 @@
//! Environment variable overrides for server configuration.
//!
//! This module provides functionality to override server configuration values
//! using environment variables. All settings can be overridden at runtime.
use crate::{KreuzbergError, Result};
/// Apply environment variable overrides to a ServerConfig.
///
/// Reads the following environment variables and overrides config values if set:
///
/// - `KREUZBERG_HOST` - Server host address
/// - `KREUZBERG_PORT` - Server port number (parsed as u16)
/// - `KREUZBERG_CORS_ORIGINS` - Comma-separated list of allowed origins
/// - `KREUZBERG_MAX_REQUEST_BODY_BYTES` - Max request body size in bytes
/// - `KREUZBERG_MAX_MULTIPART_FIELD_BYTES` - Max multipart field size in bytes
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if:
/// - `KREUZBERG_PORT` cannot be parsed as u16
/// - `KREUZBERG_MAX_REQUEST_BODY_BYTES` cannot be parsed as usize
/// - `KREUZBERG_MAX_MULTIPART_FIELD_BYTES` cannot be parsed as usize
pub(crate) fn apply_env_overrides(
host: &mut String,
port: &mut u16,
cors_origins: &mut Vec<String>,
max_request_body_bytes: &mut usize,
max_multipart_field_bytes: &mut usize,
) -> Result<()> {
// Host override
if let Ok(env_host) = std::env::var("KREUZBERG_HOST") {
*host = env_host;
}
// Port override
if let Ok(port_str) = std::env::var("KREUZBERG_PORT") {
*port = port_str.parse::<u16>().map_err(|e| {
KreuzbergError::validation(format!(
"KREUZBERG_PORT must be a valid u16 number, got '{}': {}",
port_str, e
))
})?;
}
// CORS origins override (comma-separated)
if let Ok(origins_str) = std::env::var("KREUZBERG_CORS_ORIGINS") {
*cors_origins = origins_str
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
}
// Max request body bytes override
if let Ok(bytes_str) = std::env::var("KREUZBERG_MAX_REQUEST_BODY_BYTES") {
*max_request_body_bytes = bytes_str.parse::<usize>().map_err(|e| {
KreuzbergError::validation(format!(
"KREUZBERG_MAX_REQUEST_BODY_BYTES must be a valid usize, got '{}': {}",
bytes_str, e
))
})?;
}
// Max multipart field bytes override
if let Ok(bytes_str) = std::env::var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES") {
*max_multipart_field_bytes = bytes_str.parse::<usize>().map_err(|e| {
KreuzbergError::validation(format!(
"KREUZBERG_MAX_MULTIPART_FIELD_BYTES must be a valid usize, got '{}': {}",
bytes_str, e
))
})?;
}
Ok(())
}

View File

@@ -0,0 +1,193 @@
//! File loading logic for server configuration.
//!
//! This module provides functionality to load server configuration from various
//! file formats (TOML, YAML, JSON) with support for both flat and nested formats.
use crate::{KreuzbergError, Result};
use serde::Deserialize;
use std::path::Path;
use super::ServerConfig;
/// Load server configuration from a file.
///
/// Automatically detects the file format based on extension:
/// - `.toml` - TOML format
/// - `.yaml` or `.yml` - YAML format
/// - `.json` - JSON format
///
/// This function handles two config file formats:
/// 1. Flat format: Server config at root level
/// 2. Nested format: Server config under `[server]` section (combined with ExtractionConfig)
///
/// # Arguments
///
/// * `path` - Path to the configuration file
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if:
/// - File doesn't exist or cannot be read
/// - File extension is not recognized
/// - File content is invalid for the detected format
pub(crate) fn from_file(path: impl AsRef<Path>) -> Result<ServerConfig> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
let extension = path.extension().and_then(|ext| ext.to_str()).ok_or_else(|| {
KreuzbergError::validation(format!(
"Cannot determine file format: no extension found in {}",
path.display()
))
})?;
let config = match extension.to_lowercase().as_str() {
"toml" => from_toml_str(&content, path)?,
"yaml" | "yml" => from_yaml_str(&content, path)?,
"json" => from_json_str(&content, path)?,
_ => {
return Err(KreuzbergError::validation(format!(
"Unsupported config file format: .{}. Supported formats: .toml, .yaml, .yml, .json",
extension
)));
}
};
Ok(config)
}
/// Load server configuration from a TOML file.
///
/// # Arguments
///
/// * `path` - Path to the TOML file
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid TOML.
pub(crate) fn from_toml_file(path: impl AsRef<Path>) -> Result<ServerConfig> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
let config: ServerConfig = toml::from_str(&content)
.map_err(|e| KreuzbergError::validation(format!("Invalid TOML in {}: {}", path.display(), e)))?;
Ok(config)
}
/// Load server configuration from a YAML file.
///
/// # Arguments
///
/// * `path` - Path to the YAML file
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid YAML.
pub(crate) fn from_yaml_file(path: impl AsRef<Path>) -> Result<ServerConfig> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
let config: ServerConfig = serde_yaml_ng::from_str(&content)
.map_err(|e| KreuzbergError::validation(format!("Invalid YAML in {}: {}", path.display(), e)))?;
Ok(config)
}
/// Load server configuration from a JSON file.
///
/// # Arguments
///
/// * `path` - Path to the JSON file
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid JSON.
pub(crate) fn from_json_file(path: impl AsRef<Path>) -> Result<ServerConfig> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
let config: ServerConfig = serde_json::from_str(&content)
.map_err(|e| KreuzbergError::validation(format!("Invalid JSON in {}: {}", path.display(), e)))?;
Ok(config)
}
// Helper functions for parsing different formats
fn from_toml_str(content: &str, path: &Path) -> Result<ServerConfig> {
// Try nested format first (with [server] section)
#[derive(Deserialize)]
struct RootConfig {
#[serde(default)]
server: Option<ServerConfig>,
}
if let Ok(root) = toml::from_str::<RootConfig>(content) {
if let Some(server) = root.server {
return Ok(server);
} else {
// No [server] section, try flat format
return toml::from_str::<ServerConfig>(content)
.map_err(|e| KreuzbergError::validation(format!("Invalid TOML in {}: {}", path.display(), e)));
}
}
// Fall back to flat format
toml::from_str::<ServerConfig>(content)
.map_err(|e| KreuzbergError::validation(format!("Invalid TOML in {}: {}", path.display(), e)))
}
fn from_yaml_str(content: &str, path: &Path) -> Result<ServerConfig> {
// Try nested format first (with server: section)
#[derive(Deserialize)]
struct RootConfig {
#[serde(default)]
server: Option<ServerConfig>,
}
if let Ok(root) = serde_yaml_ng::from_str::<RootConfig>(content) {
if let Some(server) = root.server {
return Ok(server);
} else {
// No server section, try flat format
return serde_yaml_ng::from_str::<ServerConfig>(content)
.map_err(|e| KreuzbergError::validation(format!("Invalid YAML in {}: {}", path.display(), e)));
}
}
// Fall back to flat format
serde_yaml_ng::from_str::<ServerConfig>(content)
.map_err(|e| KreuzbergError::validation(format!("Invalid YAML in {}: {}", path.display(), e)))
}
fn from_json_str(content: &str, path: &Path) -> Result<ServerConfig> {
// Try nested format first (with "server" key)
#[derive(Deserialize)]
struct RootConfig {
#[serde(default)]
server: Option<ServerConfig>,
}
if let Ok(root) = serde_json::from_str::<RootConfig>(content) {
if let Some(server) = root.server {
return Ok(server);
} else {
// No server key, try flat format
return serde_json::from_str::<ServerConfig>(content)
.map_err(|e| KreuzbergError::validation(format!("Invalid JSON in {}: {}", path.display(), e)));
}
}
// Fall back to flat format
serde_json::from_str::<ServerConfig>(content)
.map_err(|e| KreuzbergError::validation(format!("Invalid JSON in {}: {}", path.display(), e)))
}

View File

@@ -0,0 +1,357 @@
//! Server configuration for the Kreuzberg API.
//!
//! This module provides the `ServerConfig` struct for managing API server settings
//! including host, port, CORS, and upload size limits. Configuration can be loaded
//! from TOML, YAML, or JSON files and can be overridden by environment variables.
//!
//! # Features
//!
//! - **Multi-format support**: Load configuration from TOML, YAML, or JSON files
//! - **Environment overrides**: All settings can be overridden via environment variables
//! - **Sensible defaults**: All fields have reasonable defaults matching current behavior
//! - **Flexible CORS**: Support for all origins (default) or specific origin lists
//!
//! # Example
//!
//! ```rust,no_run
//! use kreuzberg::core::ServerConfig;
//!
//! # fn example() -> kreuzberg::Result<()> {
//! // Create with defaults
//! let mut config = ServerConfig::default();
//!
//! // Or load from file
//! let mut config = ServerConfig::from_file("kreuzberg.toml")?;
//!
//! // Apply environment variable overrides
//! config.apply_env_overrides()?;
//!
//! # Ok(())
//! # }
//! ```
use crate::Result;
use serde::{Deserialize, Serialize};
use std::path::Path;
mod env;
mod loader;
mod validation;
#[cfg(test)]
mod tests;
/// Default host address for API server
const DEFAULT_HOST: &str = "127.0.0.1";
/// Default port for API server
const DEFAULT_PORT: u16 = 8000;
/// Default maximum request body size: 100 MB
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 104_857_600;
/// Default maximum multipart field size: 100 MB
const DEFAULT_MAX_MULTIPART_FIELD_BYTES: usize = 104_857_600;
/// API server configuration.
///
/// This struct holds all configuration options for the Kreuzberg API server,
/// including host/port settings, CORS configuration, and upload limits.
///
/// # Defaults
///
/// - `host`: "127.0.0.1" (localhost only)
/// - `port`: 8000
/// - `cors_origins`: empty vector (allows all origins)
/// - `max_request_body_bytes`: 104_857_600 (100 MB)
/// - `max_multipart_field_bytes`: 104_857_600 (100 MB)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ServerConfig {
/// Server host address (e.g., "127.0.0.1", "0.0.0.0")
#[serde(default = "default_host")]
pub host: String,
/// Server port number
#[serde(default = "default_port")]
pub port: u16,
/// CORS allowed origins. Empty vector means allow all origins.
///
/// If this is an empty vector, the server will accept requests from any origin.
/// If populated with specific origins (e.g., `"https://example.com"`), only
/// those origins will be allowed.
#[serde(default)]
pub cors_origins: Vec<String>,
/// Maximum size of request body in bytes (default: 100 MB)
#[serde(default = "default_max_request_body_bytes")]
pub max_request_body_bytes: usize,
/// Maximum size of multipart fields in bytes (default: 100 MB)
#[serde(default = "default_max_multipart_field_bytes")]
pub max_multipart_field_bytes: usize,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: default_host(),
port: default_port(),
cors_origins: Vec::new(),
max_request_body_bytes: default_max_request_body_bytes(),
max_multipart_field_bytes: default_max_multipart_field_bytes(),
}
}
}
// Default value functions for serde
fn default_host() -> String {
DEFAULT_HOST.to_string()
}
fn default_port() -> u16 {
DEFAULT_PORT
}
fn default_max_request_body_bytes() -> usize {
DEFAULT_MAX_REQUEST_BODY_BYTES
}
fn default_max_multipart_field_bytes() -> usize {
DEFAULT_MAX_MULTIPART_FIELD_BYTES
}
impl ServerConfig {
/// Create a new `ServerConfig` with default values.
pub fn new() -> Self {
Self::default()
}
/// Get the server listen address (host:port).
///
/// # Example
///
/// ```rust
/// use kreuzberg::core::ServerConfig;
///
/// let config = ServerConfig::default();
/// assert_eq!(config.listen_addr(), "127.0.0.1:8000");
/// ```
pub fn listen_addr(&self) -> String {
format!("{}:{}", self.host, self.port)
}
/// Check if CORS allows all origins.
///
/// Returns `true` if the `cors_origins` vector is empty, meaning all origins
/// are allowed. Returns `false` if specific origins are configured.
///
/// # Example
///
/// ```rust
/// use kreuzberg::core::ServerConfig;
///
/// let mut config = ServerConfig::default();
/// assert!(config.cors_allows_all());
///
/// config.cors_origins.push("https://example.com".to_string());
/// assert!(!config.cors_allows_all());
/// ```
pub fn cors_allows_all(&self) -> bool {
self.cors_origins.is_empty()
}
/// Check if a given origin is allowed by CORS configuration.
///
/// Returns `true` if:
/// - CORS allows all origins (empty origins list), or
/// - The given origin is in the allowed origins list
///
/// # Arguments
///
/// * `origin` - The origin to check (e.g., "https://example.com")
///
/// # Example
///
/// ```rust
/// use kreuzberg::core::ServerConfig;
///
/// let mut config = ServerConfig::default();
/// assert!(config.is_origin_allowed("https://example.com"));
///
/// config.cors_origins.push("https://allowed.com".to_string());
/// assert!(config.is_origin_allowed("https://allowed.com"));
/// assert!(!config.is_origin_allowed("https://denied.com"));
/// ```
pub fn is_origin_allowed(&self, origin: &str) -> bool {
self.cors_origins.is_empty() || self.cors_origins.contains(&origin.to_string())
}
/// Get maximum request body size in megabytes (rounded up).
///
/// # Example
///
/// ```rust
/// use kreuzberg::core::ServerConfig;
///
/// let mut config = ServerConfig::default();
/// assert_eq!(config.max_request_body_mb(), 100);
/// ```
pub fn max_request_body_mb(&self) -> usize {
self.max_request_body_bytes.div_ceil(1_048_576)
}
/// Get maximum multipart field size in megabytes (rounded up).
///
/// # Example
///
/// ```rust
/// use kreuzberg::core::ServerConfig;
///
/// let mut config = ServerConfig::default();
/// assert_eq!(config.max_multipart_field_mb(), 100);
/// ```
pub fn max_multipart_field_mb(&self) -> usize {
self.max_multipart_field_bytes.div_ceil(1_048_576)
}
/// Apply environment variable overrides to the configuration.
///
/// Reads the following environment variables and overrides config values if set:
///
/// - `KREUZBERG_HOST` - Server host address
/// - `KREUZBERG_PORT` - Server port number (parsed as u16)
/// - `KREUZBERG_CORS_ORIGINS` - Comma-separated list of allowed origins
/// - `KREUZBERG_MAX_REQUEST_BODY_BYTES` - Max request body size in bytes
/// - `KREUZBERG_MAX_MULTIPART_FIELD_BYTES` - Max multipart field size in bytes
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if:
/// - `KREUZBERG_PORT` cannot be parsed as u16
/// - `KREUZBERG_MAX_REQUEST_BODY_BYTES` cannot be parsed as usize
/// - `KREUZBERG_MAX_MULTIPART_FIELD_BYTES` cannot be parsed as usize
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::ServerConfig;
///
/// # fn example() -> kreuzberg::Result<()> {
/// unsafe {
/// std::env::set_var("KREUZBERG_HOST", "0.0.0.0");
/// std::env::set_var("KREUZBERG_PORT", "3000");
/// }
///
/// let mut config = ServerConfig::default();
/// config.apply_env_overrides()?;
///
/// assert_eq!(config.host, "0.0.0.0");
/// assert_eq!(config.port, 3000);
/// # Ok(())
/// # }
/// ```
#[cfg_attr(alef, alef(skip))]
pub fn apply_env_overrides(&mut self) -> Result<()> {
env::apply_env_overrides(
&mut self.host,
&mut self.port,
&mut self.cors_origins,
&mut self.max_request_body_bytes,
&mut self.max_multipart_field_bytes,
)?;
Ok(())
}
/// Load server configuration from a file.
///
/// Automatically detects the file format based on extension:
/// - `.toml` - TOML format
/// - `.yaml` or `.yml` - YAML format
/// - `.json` - JSON format
///
/// This function handles two config file formats:
/// 1. Flat format: Server config at root level
/// 2. Nested format: Server config under `[server]` section (combined with ExtractionConfig)
///
/// # Arguments
///
/// * `path` - Path to the configuration file
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if:
/// - File doesn't exist or cannot be read
/// - File extension is not recognized
/// - File content is invalid for the detected format
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::ServerConfig;
///
/// # fn example() -> kreuzberg::Result<()> {
/// let config = ServerConfig::from_file("kreuzberg.toml")?;
/// # Ok(())
/// # }
/// ```
#[cfg_attr(alef, alef(skip))]
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
loader::from_file(path)
}
/// Load server configuration from a TOML file.
///
/// # Arguments
///
/// * `path` - Path to the TOML file
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid TOML.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::ServerConfig;
///
/// # fn example() -> kreuzberg::Result<()> {
/// let config = ServerConfig::from_toml_file("kreuzberg.toml")?;
/// # Ok(())
/// # }
/// ```
#[cfg_attr(alef, alef(skip))]
pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
loader::from_toml_file(path)
}
/// Load server configuration from a YAML file.
///
/// # Arguments
///
/// * `path` - Path to the YAML file
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid YAML.
#[cfg_attr(alef, alef(skip))]
pub fn from_yaml_file(path: impl AsRef<Path>) -> Result<Self> {
loader::from_yaml_file(path)
}
/// Load server configuration from a JSON file.
///
/// # Arguments
///
/// * `path` - Path to the JSON file
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid JSON.
#[cfg_attr(alef, alef(skip))]
pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self> {
loader::from_json_file(path)
}
}

View File

@@ -0,0 +1,87 @@
//! Basic tests for ServerConfig functionality.
use crate::core::ServerConfig;
#[test]
fn test_default_config() {
let config = ServerConfig::default();
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.port, 8000);
assert!(config.cors_origins.is_empty());
assert_eq!(config.max_request_body_bytes, 104_857_600);
assert_eq!(config.max_multipart_field_bytes, 104_857_600);
}
#[test]
fn test_listen_addr() {
let config = ServerConfig::default();
assert_eq!(config.listen_addr(), "127.0.0.1:8000");
}
#[test]
fn test_listen_addr_custom() {
let config = ServerConfig {
host: "0.0.0.0".to_string(),
port: 3000,
..Default::default()
};
assert_eq!(config.listen_addr(), "0.0.0.0:3000");
}
#[test]
fn test_cors_allows_all() {
let mut config = ServerConfig::default();
assert!(config.cors_allows_all());
config.cors_origins.push("https://example.com".to_string());
assert!(!config.cors_allows_all());
}
#[test]
fn test_is_origin_allowed_all() {
let config = ServerConfig::default();
assert!(config.is_origin_allowed("https://example.com"));
assert!(config.is_origin_allowed("https://other.com"));
}
#[test]
fn test_is_origin_allowed_specific() {
let mut config = ServerConfig::default();
config.cors_origins.push("https://allowed.com".to_string());
assert!(config.is_origin_allowed("https://allowed.com"));
assert!(!config.is_origin_allowed("https://denied.com"));
}
#[test]
fn test_max_request_body_mb() {
let config = ServerConfig::default();
assert_eq!(config.max_request_body_mb(), 100);
}
#[test]
fn test_max_multipart_field_mb() {
let config = ServerConfig::default();
assert_eq!(config.max_multipart_field_mb(), 100);
}
#[test]
fn test_max_bytes_to_mb_rounding() {
let mut config = ServerConfig {
max_request_body_bytes: 1_048_576, // 1 MB
..Default::default()
};
assert_eq!(config.max_request_body_mb(), 1);
config.max_request_body_bytes = 1_048_577; // 1 MB + 1 byte
assert_eq!(config.max_request_body_mb(), 2); // Rounds up
}
#[test]
fn test_serde_default_serialization() {
let config = ServerConfig::default();
let json = serde_json::to_string(&config).unwrap();
assert!(json.contains("host"));
assert!(json.contains("port"));
}

View File

@@ -0,0 +1,192 @@
//! Tests for environment variable overrides.
#![allow(unsafe_code)]
use crate::core::ServerConfig;
#[serial_test::serial]
#[test]
fn test_apply_env_host_override() {
let original = std::env::var("KREUZBERG_HOST").ok();
unsafe {
std::env::set_var("KREUZBERG_HOST", "192.168.1.1");
}
let mut config = ServerConfig::default();
config.apply_env_overrides().unwrap();
assert_eq!(config.host, "192.168.1.1");
// Cleanup
unsafe {
if let Some(orig) = original {
std::env::set_var("KREUZBERG_HOST", orig);
} else {
std::env::remove_var("KREUZBERG_HOST");
}
}
}
#[serial_test::serial]
#[test]
fn test_apply_env_port_override() {
let original = std::env::var("KREUZBERG_PORT").ok();
unsafe {
std::env::set_var("KREUZBERG_PORT", "5000");
}
let mut config = ServerConfig::default();
config.apply_env_overrides().unwrap();
assert_eq!(config.port, 5000);
// Cleanup
unsafe {
if let Some(orig) = original {
std::env::set_var("KREUZBERG_PORT", orig);
} else {
std::env::remove_var("KREUZBERG_PORT");
}
}
}
#[serial_test::serial]
#[test]
fn test_apply_env_port_invalid() {
let original = std::env::var("KREUZBERG_PORT").ok();
unsafe {
std::env::set_var("KREUZBERG_PORT", "not_a_number");
}
let mut config = ServerConfig::default();
let result = config.apply_env_overrides();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("KREUZBERG_PORT must be a valid u16")
);
// Cleanup
unsafe {
if let Some(orig) = original {
std::env::set_var("KREUZBERG_PORT", orig);
} else {
std::env::remove_var("KREUZBERG_PORT");
}
}
}
#[serial_test::serial]
#[test]
fn test_apply_env_cors_origins_override() {
let original = std::env::var("KREUZBERG_CORS_ORIGINS").ok();
unsafe {
std::env::set_var("KREUZBERG_CORS_ORIGINS", "https://example.com, https://other.com");
}
let mut config = ServerConfig::default();
config.apply_env_overrides().unwrap();
assert_eq!(config.cors_origins.len(), 2);
assert!(config.cors_origins.contains(&"https://example.com".to_string()));
assert!(config.cors_origins.contains(&"https://other.com".to_string()));
// Cleanup
unsafe {
if let Some(orig) = original {
std::env::set_var("KREUZBERG_CORS_ORIGINS", orig);
} else {
std::env::remove_var("KREUZBERG_CORS_ORIGINS");
}
}
}
#[serial_test::serial]
#[test]
fn test_apply_env_max_request_body_bytes_override() {
let original = std::env::var("KREUZBERG_MAX_REQUEST_BODY_BYTES").ok();
unsafe {
std::env::set_var("KREUZBERG_MAX_REQUEST_BODY_BYTES", "52428800");
}
let mut config = ServerConfig::default();
config.apply_env_overrides().unwrap();
assert_eq!(config.max_request_body_bytes, 52_428_800);
// Cleanup
unsafe {
if let Some(orig) = original {
std::env::set_var("KREUZBERG_MAX_REQUEST_BODY_BYTES", orig);
} else {
std::env::remove_var("KREUZBERG_MAX_REQUEST_BODY_BYTES");
}
}
}
#[serial_test::serial]
#[test]
fn test_apply_env_max_multipart_field_bytes_override() {
let original = std::env::var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES").ok();
unsafe {
std::env::set_var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES", "78643200");
}
let mut config = ServerConfig::default();
config.apply_env_overrides().unwrap();
assert_eq!(config.max_multipart_field_bytes, 78_643_200);
// Cleanup
unsafe {
if let Some(orig) = original {
std::env::set_var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES", orig);
} else {
std::env::remove_var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES");
}
}
}
#[serial_test::serial]
#[test]
fn test_apply_env_multiple_overrides() {
let host_orig = std::env::var("KREUZBERG_HOST").ok();
let port_orig = std::env::var("KREUZBERG_PORT").ok();
let cors_orig = std::env::var("KREUZBERG_CORS_ORIGINS").ok();
unsafe {
std::env::set_var("KREUZBERG_HOST", "0.0.0.0");
std::env::set_var("KREUZBERG_PORT", "4000");
std::env::set_var("KREUZBERG_CORS_ORIGINS", "https://api.example.com");
}
let mut config = ServerConfig::default();
config.apply_env_overrides().unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 4000);
assert_eq!(config.cors_origins.len(), 1);
assert_eq!(config.cors_origins[0], "https://api.example.com");
// Cleanup
unsafe {
if let Some(orig) = host_orig {
std::env::set_var("KREUZBERG_HOST", orig);
} else {
std::env::remove_var("KREUZBERG_HOST");
}
if let Some(orig) = port_orig {
std::env::set_var("KREUZBERG_PORT", orig);
} else {
std::env::remove_var("KREUZBERG_PORT");
}
if let Some(orig) = cors_orig {
std::env::set_var("KREUZBERG_CORS_ORIGINS", orig);
} else {
std::env::remove_var("KREUZBERG_CORS_ORIGINS");
}
}
}

View File

@@ -0,0 +1,321 @@
//! Tests for file loading functionality.
use crate::core::ServerConfig;
use std::fs;
use tempfile::tempdir;
#[test]
fn test_from_toml_file() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.toml");
fs::write(
&config_path,
r#"
host = "0.0.0.0"
port = 3000
cors_origins = ["https://example.com", "https://other.com"]
max_request_body_bytes = 50000000
max_multipart_field_bytes = 75000000
"#,
)
.unwrap();
let config = ServerConfig::from_toml_file(&config_path).unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 3000);
assert_eq!(config.cors_origins.len(), 2);
assert_eq!(config.max_request_body_bytes, 50_000_000);
assert_eq!(config.max_multipart_field_bytes, 75_000_000);
}
#[test]
fn test_from_yaml_file() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.yaml");
fs::write(
&config_path,
r#"
host: 0.0.0.0
port: 3000
cors_origins:
- https://example.com
- https://other.com
max_request_body_bytes: 50000000
max_multipart_field_bytes: 75000000
"#,
)
.unwrap();
let config = ServerConfig::from_yaml_file(&config_path).unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 3000);
assert_eq!(config.cors_origins.len(), 2);
assert_eq!(config.max_request_body_bytes, 50_000_000);
assert_eq!(config.max_multipart_field_bytes, 75_000_000);
}
#[test]
fn test_from_json_file() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.json");
fs::write(
&config_path,
r#"{
"host": "0.0.0.0",
"port": 3000,
"cors_origins": ["https://example.com", "https://other.com"],
"max_request_body_bytes": 50000000,
"max_multipart_field_bytes": 75000000
}
"#,
)
.unwrap();
let config = ServerConfig::from_json_file(&config_path).unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 3000);
assert_eq!(config.cors_origins.len(), 2);
assert_eq!(config.max_request_body_bytes, 50_000_000);
assert_eq!(config.max_multipart_field_bytes, 75_000_000);
}
#[test]
fn test_from_file_auto_detects_toml() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.toml");
fs::write(
&config_path,
r#"
host = "0.0.0.0"
port = 3000
"#,
)
.unwrap();
let config = ServerConfig::from_file(&config_path).unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 3000);
}
#[test]
fn test_from_file_auto_detects_yaml() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.yaml");
fs::write(
&config_path,
r#"
host: 0.0.0.0
port: 3000
"#,
)
.unwrap();
let config = ServerConfig::from_file(&config_path).unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 3000);
}
#[test]
fn test_from_file_auto_detects_json() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.json");
fs::write(&config_path, r#"{"host": "0.0.0.0", "port": 3000}"#).unwrap();
let config = ServerConfig::from_file(&config_path).unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 3000);
}
#[test]
fn test_from_file_unsupported_extension() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.txt");
fs::write(&config_path, "host = 0.0.0.0").unwrap();
let result = ServerConfig::from_file(&config_path);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Unsupported config file format")
);
}
#[test]
fn test_from_file_no_extension() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server");
fs::write(&config_path, "host = 0.0.0.0").unwrap();
let result = ServerConfig::from_file(&config_path);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("no extension found"));
}
#[test]
fn test_cors_origins_empty_in_toml() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.toml");
fs::write(
&config_path,
r#"
host = "127.0.0.1"
port = 8000
"#,
)
.unwrap();
let config = ServerConfig::from_toml_file(&config_path).unwrap();
assert!(config.cors_origins.is_empty());
assert!(config.cors_allows_all());
}
#[test]
fn test_full_configuration_toml() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.toml");
fs::write(
&config_path,
r#"
host = "192.168.1.100"
port = 9000
cors_origins = ["https://app1.com", "https://app2.com", "https://app3.com"]
max_request_body_bytes = 200000000
max_multipart_field_bytes = 150000000
"#,
)
.unwrap();
let config = ServerConfig::from_toml_file(&config_path).unwrap();
assert_eq!(config.host, "192.168.1.100");
assert_eq!(config.port, 9000);
assert_eq!(config.listen_addr(), "192.168.1.100:9000");
assert_eq!(config.cors_origins.len(), 3);
assert!(!config.cors_allows_all());
assert!(config.is_origin_allowed("https://app1.com"));
assert!(!config.is_origin_allowed("https://app4.com"));
assert_eq!(config.max_request_body_bytes, 200_000_000);
assert_eq!(config.max_multipart_field_bytes, 150_000_000);
assert_eq!(config.max_request_body_mb(), 191);
assert_eq!(config.max_multipart_field_mb(), 144);
}
#[test]
fn test_from_file_with_nested_server_section_toml() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("kreuzberg.toml");
// Config file with [server] section and other sections (like ExtractionConfig)
fs::write(
&config_path,
r#"
[server]
host = "0.0.0.0"
port = 3000
cors_origins = ["https://example.com"]
[ocr]
backend = "tesseract"
language = "eng"
[extraction]
enabled = true
"#,
)
.unwrap();
let config = ServerConfig::from_file(&config_path).unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 3000);
assert_eq!(config.cors_origins.len(), 1);
assert_eq!(config.cors_origins[0], "https://example.com");
}
#[test]
fn test_from_file_with_nested_server_section_yaml() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("kreuzberg.yaml");
// Config file with server: section and other sections
fs::write(
&config_path,
r#"
server:
host: 0.0.0.0
port: 4000
cors_origins:
- https://example.com
ocr:
backend: tesseract
language: eng
"#,
)
.unwrap();
let config = ServerConfig::from_file(&config_path).unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 4000);
assert_eq!(config.cors_origins.len(), 1);
}
#[test]
fn test_from_file_with_nested_server_section_json() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("kreuzberg.json");
// Config file with "server" key and other sections
fs::write(
&config_path,
r#"
{
"server": {
"host": "0.0.0.0",
"port": 5000,
"cors_origins": ["https://example.com"]
},
"ocr": {
"backend": "tesseract",
"language": "eng"
}
}
"#,
)
.unwrap();
let config = ServerConfig::from_file(&config_path).unwrap();
assert_eq!(config.host, "0.0.0.0");
assert_eq!(config.port, 5000);
assert_eq!(config.cors_origins.len(), 1);
}
#[test]
fn test_from_file_flat_format_still_works() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("server.toml");
// Old flat format without [server] section
fs::write(
&config_path,
r#"
host = "192.168.1.1"
port = 6000
"#,
)
.unwrap();
let config = ServerConfig::from_file(&config_path).unwrap();
assert_eq!(config.host, "192.168.1.1");
assert_eq!(config.port, 6000);
}

View File

@@ -0,0 +1,5 @@
//! Tests for server configuration module.
mod basic_tests;
mod env_tests;
mod file_loading_tests;

View File

@@ -0,0 +1,4 @@
//! Validation and normalization for server configuration.
//!
//! This module provides functionality to validate and normalize server configuration
//! values.